Comments (7)
Thanks for reaching out. Adding these datasets sounds interesting, and we're happy to support you.
FedJAX uses sqlite files instead of tfrecord files. One way to add support for these datasets would be as follows:
- Create a script to download the datasets from TFF and postprocess them using the scripts you mentioned.
- Use SQLiteFederatedDataBuilder to build a SQLite dataset file. Note that SQLiteFederatedDataBuilder takes as input client_ids and their corresponding client examples, so as long as we are able to read and parse the client examples from TFF's tfrecords, it can be used in conjunction with SQLiteFederatedDataBuilder.
- Cache the sqlite file for further use.
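The storage pattern behind those steps can be sketched with the standard library alone. This is not FedJAX's actual schema or API (SQLiteFederatedDataBuilder handles the real layout); it is just a minimal illustration of "one row per client, keyed by client_id, with serialized example arrays as the value":

```python
import io
import sqlite3

import numpy as np


def serialize_examples(examples):
    """Serializes a dict of numpy arrays (one client's examples) into bytes."""
    buf = io.BytesIO()
    np.savez(buf, **examples)
    return buf.getvalue()


def build_sqlite_dataset(path, client_examples):
    """Writes one row per client: (client_id, serialized examples).

    `client_examples` maps client_id -> dict of numpy arrays, mirroring the
    (client_id, examples) input that SQLiteFederatedDataBuilder expects.
    """
    conn = sqlite3.connect(path)
    conn.execute(
        'CREATE TABLE IF NOT EXISTS federated_data '
        '(client_id TEXT PRIMARY KEY, data BLOB)')
    for client_id, examples in client_examples.items():
        conn.execute(
            'INSERT INTO federated_data VALUES (?, ?)',
            (client_id, serialize_examples(examples)))
    conn.commit()
    conn.close()
```

Reading a client back is then a single indexed lookup by client_id, which is what makes the cached SQLite file cheap to sample from during training.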
Some of the utilities in #216 might be useful.
Please let us know if you have any questions.
from fedjax.
Hi @stheertha and thanks for the support!
I'll have a look at the SQLite stuff and get back to you.
I went a different way for the moment (just to play around) by creating a TFRecordFederatedDataset class that extends InMemoryFederatedData. Basically, I just read the TFRecords (one per client) and map them to client_data.Examples, just to see if that would work.
I don't think that's the way to go, mainly because the images have different shapes and I can't create the numpy objects this way. I'm still looking into it, but do you think that would be an issue with SQLiteFederatedData too?
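A quick numpy illustration of the shape problem (toy shapes, not taken from the actual datasets): ragged image shapes cannot be stacked into one dense array, and the object-array workaround loses the dense batch layout downstream code expects.

```python
import numpy as np

# Two images from the same client with different heights/widths.
img_a = np.zeros((350, 500, 3), dtype=np.uint8)
img_b = np.zeros((480, 640, 3), dtype=np.uint8)

# Stacking ragged shapes into one dense array fails.
try:
    np.stack([img_a, img_b])
except ValueError as e:
    print('cannot stack:', e)

# A workaround is an object array holding per-image arrays, but that
# loses the dense (batch, H, W, C) layout that numpy/JAX code expects.
ragged = np.empty(2, dtype=object)
ragged[0], ragged[1] = img_a, img_b
```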
Hi @marcociccone! I am not very familiar with these two datasets. By "images have different shapes" do you mean images in these datasets are not already transformed into a uniform height/width? How about images belonging to the same client? Can they be different in height/width too?
In JAX in general, we need to keep the possible input shapes to a small set to avoid repetitive XLA compilations (each unique input shape configuration will require one XLA compilation), so padding or some other types of transformation is needed to ensure uniform input shapes.
The main problem with deciding what to do with images of different shapes is first deciding how models consume them, so that we can choose an appropriate storage format (i.e. either padding or resizing). I am not very knowledgeable about image models. How do they deal with a batch of images that are in different shapes? Based on my limited understanding of a Conv layer, won't different input shapes lead to different output shapes after a Conv layer? What will an output layer do in that case?
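The per-shape compilation cost is easy to see with a toy jitted function (a generic JAX illustration, not FedJAX-specific): the traced Python body runs once per unique input shape, so the counter below only advances when a new shape arrives.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def total(x):
    global trace_count
    trace_count += 1  # runs only while JAX traces/compiles for a new shape
    return jnp.sum(x)

total(jnp.ones((2, 3)))   # first shape -> trace + compile
total(jnp.zeros((2, 3)))  # same shape -> compilation cache hit, no new trace
total(jnp.ones((4, 3)))   # new shape -> traces and compiles again
print(trace_count)  # 2
```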
Hi @kho! By looking at the images in the tfrecord of a randomly sampled client, I see that the images have different heights/widths.
Images should then be randomly cropped to 299x299 (iNaturalist) or 224x224 (GLDv2) before being batched together and consumed by the neural network (see the ECCV 2020 paper proposing these two datasets, sec. 6.5, for more details).
This is a standard data augmentation practice when dealing with image datasets to increase the variability of the dataset.
See also this input data pipeline for imagenet as a reference.
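A random crop of this kind is a few lines of numpy (a generic sketch of the augmentation described above, not the pipeline from the paper; the 299 target matches the iNaturalist setting):

```python
import numpy as np

def random_crop(image, size, rng):
    """Randomly crops an (H, W, C) image to (size, size, C)."""
    h, w = image.shape[:2]
    top = rng.integers(0, h - size + 1)
    left = rng.integers(0, w - size + 1)
    return image[top:top + size, left:left + size]

rng = np.random.default_rng(0)
image = np.zeros((375, 500, 3), dtype=np.uint8)
crop = random_crop(image, 299, rng)
print(crop.shape)  # (299, 299, 3)
```

Because every crop has the same output shape, the crops can be batched together regardless of the original image sizes.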
Do you think that doing something like that would be possible with the current fedjax data pipeline?
Thanks for the clarification. This is supported by the FedJAX pipeline and can be done in two ways:
Option 1: When converting the tfrecords with SQLiteFederatedDataBuilder, do some offline processing to bring all images to the same shape, for example via zero padding. This ensures everything can be stored in numpy format. Then, during training and evaluation, use BatchPreprocessor to do the preprocessing (e.g., random cropping) as required for both the train and test datasets.
Option 2: When converting the tfrecords with SQLiteFederatedDataBuilder, do the entire preprocessing offline, including random cropping, and then store the processed images, all of which have the shape (299, 299). This way the stored dataset is already preprocessed and can be used directly without additional preprocessing.
Thanks for your answer!
I think option 1 is the way to go to ensure enough data variability (multiple crops for each image).
However, gldv2 image sizes span from 300 to 800 pixels, and some clients have up to 1K images, so zero-padding to the max shape (800x800) and storing the np.array in memory would require around 14 GB. I should also check that heavily padded images aren't mostly empty when cropped.
I still need to check the codebase carefully, but what if we create a TfDataClientDataset class that iterates over the tf.data object rather than the np.array? This would be more efficient in terms of memory and would allow us to take advantage of the tf.data input pipeline.
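The lazy-iteration idea can be sketched in plain Python without tf.data (TfDataClientDataset is only a proposal here, so this is a hypothetical stand-in: batches are assembled on the fly from a reader callable instead of a materialized array):

```python
import numpy as np

class LazyClientDataset:
    """Hypothetical sketch: yields batches lazily instead of holding one big np.array."""

    def __init__(self, record_reader, batch_size):
        self._record_reader = record_reader  # callable returning an iterator of examples
        self._batch_size = batch_size

    def __iter__(self):
        batch = []
        for example in self._record_reader():
            batch.append(example)
            if len(batch) == self._batch_size:
                yield np.stack(batch)
                batch = []
        if batch:  # final partial batch
            yield np.stack(batch)

def fake_reader():
    # Stand-in for reading one client's TFRecord file.
    for i in range(5):
        yield np.full((2, 2), i)

batches = list(LazyClientDataset(fake_reader, batch_size=2))
print(len(batches))  # 3
```

Peak memory is then one batch rather than one client, at the cost of re-reading records on every pass.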
Sorry about the confusion, I didn't know the datasets were this big (should have read the READMEs more carefully). Could you help me run some quick stats on gldv2? That will help me figure out if putting everything inside a SQLite database is feasible.
- The actual {max, median, average} total number of pixels of all images from the same client (I am wondering if 800x800x1000 is too pessimistic).
- The total number of clients in the training set.
- The total number of pixels of all images in the training set.
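Once per-image sizes are extracted from the tfrecords (the part that needs the TFF tooling), the aggregation itself is trivial; a hypothetical sketch with made-up sizes:

```python
import numpy as np

def client_pixel_stats(sizes_per_client):
    """sizes_per_client: dict mapping client_id -> list of (height, width) pairs."""
    totals = np.array(
        [sum(h * w for h, w in sizes) for sizes in sizes_per_client.values()])
    return {
        'num_clients': len(totals),
        'max_pixels_per_client': int(totals.max()),
        'median_pixels_per_client': float(np.median(totals)),
        'mean_pixels_per_client': float(totals.mean()),
        'total_pixels': int(totals.sum()),
    }

# Made-up example input, not real gldv2 numbers.
stats = client_pixel_stats({
    'c0': [(800, 800), (300, 400)],
    'c1': [(600, 500)],
})
```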
Regarding your proposal of wrapping tf.data: there is a very significant overhead in iterator creation in tf.data.Dataset, which becomes problematic in federated learning since we need to create many dataset iterators during a single experiment. However, the calculus might be different if an individual client is big enough. The stats above will also help in evaluating that.
I also have one question about how people outside Google usually work with such big datasets. Are the files actually stored on local disks, some NFS volume, or some other distributed file system (e.g. GCS or S3)?