
Functa

This repository contains code for the ICML 2022 paper "From data to functa: Your data point is a function and you can treat it like one" by Emilien Dupont*, Hyunjik Kim*, Ali Eslami, Danilo Rezende and Dan Rosenbaum. *Denotes joint first authorship.

The codebase contains the meta-learning experiment for CelebA-HQ-64 and SRN CARS, along with a Colab that creates a modulation dataset for CelebA-HQ-64.

Setup

To set up a Python virtual environment with the required dependencies, run:

# create virtual environment
python3 -m venv /tmp/functa_venv
source /tmp/functa_venv/bin/activate
# update pip, setuptools and wheel
pip3 install --upgrade pip setuptools wheel
# install all required packages
pip3 install -r requirements.txt

Note that the directory containing this repository must be included in the PYTHONPATH environment variable, e.g.:

export PYTHONPATH=DIR_CONTAINING_FUNCTA

Once you are done with the virtual environment, deactivate it with:

deactivate

then delete the venv with:

rm -r /tmp/functa_venv

Set up celeb_a_hq_custom dataset as a TensorFlow dataset (TFDS)

The publicly available celeb_a_hq dataset on TFDS (https://www.tensorflow.org/datasets/catalog/celeb_a_hq) requires manual preparation, for which there are some known issues: tensorflow/datasets#1496. Alternatively, there exist zip files that are publicly available for download. We convert the 128x128 resolution version into a TensorFlow dataset (TFDS) so that we can readily load the data into our jax/haiku models with the various data processing options that come with TFDS. Note that the resulting dataset has a different ordering from the TFDS version, so any train/test split further down the line may differ from the one used in our paper, and the downsampling algorithm may also differ: here we use tf.image.resize to resize to 64x64 resolution with the default bilinear interpolation.
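For reference, the downsampling step described above amounts to the following minimal sketch (the variable names and dummy image are illustrative, not taken from this codebase):

import tensorflow as tf

# Dummy 128x128 RGB image standing in for a CelebA-HQ sample.
image_128 = tf.zeros((128, 128, 3), dtype=tf.uint8)
# tf.image.resize defaults to bilinear interpolation and returns float32.
image_64 = tf.image.resize(image_128, (64, 64))
print(image_64.shape)  # (64, 64, 3)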

To set up the tfds, run:

cd celeb_a_hq_custom
tfds build --register_checksums

This should be quick to run (a few seconds).
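Once built, the dataset can be loaded by name like any other TFDS dataset. A minimal sketch (assuming the builder registers under the name celeb_a_hq_custom; inspect the printed element_spec to confirm the feature keys and shapes):

import tensorflow_datasets as tfds

# Load the freshly built custom dataset and inspect its structure.
ds = tfds.load('celeb_a_hq_custom', split='train')
print(ds.element_spec)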

Set up srn_cars dataset as a TensorFlow dataset (TFDS) (Optional)

The publicly available srn_cars dataset exists as a zip file in the official PixelNeRF codebase. We convert this into a TensorFlow dataset (TFDS) so that we can readily load the data into our jax/haiku models with the various data processing options that come with TFDS.

To set up the tfds, run:

cd srn_cars
tfds build --register_checksums

This can take a while to run (~1 hour), as we convert the views of each scene into an array of shape (num_views, H, W, C), so set it running and enjoy some ☕
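As with the CelebA-HQ dataset, it is worth inspecting the built dataset before training. A minimal sketch (the feature keys are assumptions; check the printed structure to find the array of views with shape (num_views, H, W, C)):

import tensorflow_datasets as tfds

# Load one scene and print the shape of each feature.
ds = tfds.load('srn_cars', split='train')
scene = next(iter(tfds.as_numpy(ds)))
print({k: getattr(v, 'shape', type(v)) for k, v in scene.items()})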

Run tests (Optional)

After setting up either dataset, check that you can successfully run a single step of the experiment by running the test for celeb_a_hq:

python3 -m test_celeb_a_hq

or for srn_cars:

python3 -m test_srn_cars

Run meta-learning experiment

Set the hyperparameters in experiment_meta_learning.py as desired by modifying the config values. Then, inside the virtual environment, run the JAXline experiment via:

python3 -m experiment_meta_learning --config=experiment_meta_learning.py
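JAXline experiments typically expose their config through ml_collections config flags, which usually also allow overriding individual values on the command line without editing the file. A hypothetical example (training_steps is an assumed key; use whatever keys the config in experiment_meta_learning.py actually defines):

python3 -m experiment_meta_learning \
  --config=experiment_meta_learning.py \
  --config.training_steps=100000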

Download pretrained weights

Download pretrained weights for the CelebA-HQ-64 meta-learning experiments with mod_dim=64, 128, 256, 512, 1024, and for srn_cars with mod_dim=128, here:

Dataset        Modulation Dimension   Link
CelebA-HQ-64   64                     .npz
CelebA-HQ-64   128                    .npz
CelebA-HQ-64   256                    .npz
CelebA-HQ-64   512                    .npz
CelebA-HQ-64   1024                   .npz
SRN CARS       128                    .npz

Note that the weights for CelebA-HQ-64 were obtained using the original TFDS dataset, so they may differ slightly from the ones resulting from running the above meta-learning experiment with the custom celeb_a_hq dataset.

How to load these weights into the model is shown in the demo Colab below.

Create or Download modulations for CelebA-HQ-64

modulation_dataset_writer.py creates the modulations for CelebA-HQ-64 and saves them as an npz file. Before running, make sure the pretrained weights for the desired modulation dimension have been downloaded. Then pass mod_dim and pretrained_weights_dir as input args to the Python script. Optionally, also specify save_to_dir to store the created modulations in a directory other than the one you are running from. Run via:

python3 -m modulation_dataset_writer \
  --mod_dim=64 \
  --pretrained_weights_dir=DIR_CONTAINING_PRETRAINED_WEIGHTS \
  --save_to_dir=DIR_TO_SAVE_MODULATION_DATASET

Alternatively, download the modulations here:

Modulation Dimension   Link
64                     .npz
128                    .npz
256                    .npz
512                    .npz
1024                   .npz

Again, note that these modulations were obtained using the original TFDS dataset, so they may differ slightly from the ones resulting from running the above script with the custom celeb_a_hq dataset.
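Whether created or downloaded, the modulations are a standard NumPy archive. A minimal loading sketch (the file name is illustrative; list data.files to see the actual array keys before indexing):

import numpy as np

# Open the modulation dataset and list the stored arrays.
data = np.load('modulations_64.npz')  # hypothetical file name
print(data.files)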

Demo Colab

We also include a Colab that shows how to visualize modulation reconstructions for CelebA-HQ-64.

Paper Figures

Figure 4

Meta-learned initialization + 4 gradient steps and target for a test scene.

Figure 7

Course of optimization for voxel imputation from partial observations, shown from the back, front, left, and from a lidar scan (panels: partial observation, imputation).

Figure 9

Uncurated samples from a DDPM (diffusion model) trained on 64-dim modulations of SRN cars.

Figure 10

Latent interpolation between two car scenes with moving pose.

Figure 11

Novel view synthesis from an occluded view (panels: occluded view, ground truth, inferred, no prior).

Figure 12

Uncurated samples from a flow trained on 256-dim modulations of ERA-5 temperature data.

Figure 26

Additional voxel imputation results (panels: partial observation, imputation).

Figure 28

Additional novel view synthesis results (panels: occluded view, ground truth, inferred, no prior).

Giving Credit

If you use this code in your work, please cite our paper:

@InProceedings{functa22,
  title = {From data to functa: Your data point is a function and you can treat it like one},
  author = {Dupont, Emilien and Kim, Hyunjik and Eslami, S. M. Ali and Rezende, Danilo Jimenez and Rosenbaum, Dan},
  booktitle = {39th International Conference on Machine Learning (ICML)},
  year = {2022},
}

Raising Issues

Please feel free to raise a GitHub issue.

License and disclaimer

Copyright 2022 DeepMind Technologies Limited

All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the Apache 2.0 license. You may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0

All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY). You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode

Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY licenses are distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses.

This is not an official Google product.


functa's Issues

Question about building the tfds

I have a question about building the new tfds. The commands "cd celeb_a_hq_custom" and "tfds build --register_checksums" are supposed to set up the celeb_a_hq_custom dataset as a tfds.
However, when I run test_celeb_a_hq.py, it throws the error:
"tensorflow_datasets.core.load.DatasetNotFoundError: Dataset celeb_a_hq_custom not found."
Why did celeb_a_hq_custom fail to be added as a new tfds?

experiment_meta_learning.py --config=None error

Hi.
test_celeb_a_hq works fine, but when running experiment_meta_learning with:

if __name__ == '__main__':
  flags.mark_flag_as_required('config')
  app.run(functools.partial(platform.main, Experiment))

flags.mark_flag_as_required('config') raises the following error: FATAL Flags parsing error: flag --config=None: Flag --config must have a value other than None.
I was wondering why this problem is happening.

Could you provide the code for the model that uses the 3D ShapeNet data?

Hello,

I am an undergraduate student in the Department of Electronic and Electrical Engineering at POSTECH (Pohang University of Science and Technology).

In the paper, I saw result images obtained using the 3D ShapeNet data. My current research concerns 3D generation and reconstruction, so I thought it would be great if the code for the meta-learning-based model that takes 3D ShapeNet data as input were provided on GitHub. It seems that the code for obtaining modulation vectors (part of the INR) from 3D ShapeNet data is not yet available on this GitHub page.

I was wondering whether you could provide the code related to this!

Thank you.

value of model.l2_weight and value of noise_std

I would like to know whether model.l2_weight and model.noise_std should be kept at 0 when training on both CelebA and SRN cars in order to achieve the best performance from the model. I tried to find these values in the paper but couldn't.

experiment_meta_learning.py errors

Hi,
while running experiment_meta_learning.py, I get the following error:

absl.flags._exceptions.UnparsedFlagAccessError: Trying to access flag --config before flags were parsed.
During handling of the above exception, another exception occurred:
absl.flags._exceptions.UnparsedFlagAccessError: Trying to access flag --jaxline_post_mortem before flags were parsed.

If I run test_celeb_a_hq.py or test_srn_cars.py, the code works smoothly.

Modulation vector for SRN Cars

Hi, thank you for sharing the code for this great work.

Could you release the modulation vectors for SRN cars? It would be helpful for examining the implementation of the NeRF functionality and for future work.

Thank you in advance.

Setting batch modulations to zero

Hi,

Thanks for making the code available! I was not able to find the implementation of setting the batch modulations to zero (defined in the 4th line of Algorithm 1 in the paper) in the repo. Would you mind pointing us to that part?

Thanks in advance,
Best regards

Uneven GPU usage

Hi,
Thank you for providing your code. I'm very new to jax and I'm not sure whether this is common to jax or specific to your codebase, but I'm seeing strange GPU RAM usage on my GPUs. I have a single node with 4 GPUs (each with 24GB of RAM); the first GPU is using 23895MiB (according to nvidia-smi), while the rest are using 1487MiB each.
Is that expected, or is something wrong with my environment? (BTW, I also set XLA_PYTHON_CLIENT_PREALLOCATE to false and it didn't make a difference.)

Latent Modulation Implementation

Hi,

I have a question about latent modulation. The paper says that the shift modulation at each layer comes from a linear map on the latent vector. However, the code seems to map the latent vector through a ReLU MLP and extract the intermediate layer outputs as the shift modulations for the original network.

Is this a typo in the paper? Or is it a new type of modulation that works better?

Code for the next steps

Hi, and congratulations on the nice work!
I was wondering whether you also intended to give access to the code that performs operations on the latent modulations, i.e., classification or generation tasks once the INR is trained. Besides, do you have recommendations for efficiently training a classifier or generator based on the modulations?
Thanks for your help!

Code for "Spatial Functa"

Hello!
Recently, I found the new version of functa, "Spatial Functa", on arXiv!
I think it's very interesting and nice work!!
May I ask whether the official code for Spatial Functa will be released?

Thank you :)
