

GeoLDM: Geometric Latent Diffusion Models for 3D Molecule Generation

License: MIT ArXiv


Official code release for the paper "Geometric Latent Diffusion Models for 3D Molecule Generation", accepted at the International Conference on Machine Learning (ICML), 2023.

Environment

Install the required packages from requirements.txt. A simplified version of the requirements can be found here.

Note: If you want to set up an RDKit environment, it may be easiest to install conda, run conda create -c conda-forge -n my-rdkit-env rdkit, and then install the other required packages. The code should still run without RDKit installed, though.

Train the GeoLDM

For QM9

python main_qm9.py --n_epochs 3000 --n_stability_samples 1000 --diffusion_noise_schedule polynomial_2 --diffusion_noise_precision 1e-5 --diffusion_steps 1000 --diffusion_loss_type l2 --batch_size 64 --nf 256 --n_layers 9 --lr 1e-4 --normalize_factors [1,4,10] --test_epochs 20 --ema_decay 0.9999 --train_diffusion --trainable_ae --latent_nf 1 --exp_name geoldm_qm9
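The --diffusion_noise_schedule polynomial_2 and --diffusion_noise_precision 1e-5 flags are not explained above. Below is a minimal NumPy sketch of a polynomial noise schedule of the kind used in EDM, which this repo builds on. It assumes that polynomial_2 means a power of 2 and that the precision value s squashes alpha_t^2 into [s, 1 - s]; the function names are illustrative, not the repo's API.

```python
import numpy as np


def clip_noise_schedule(alphas2: np.ndarray, clip_value: float = 0.001) -> np.ndarray:
    """Clip the per-step ratio alpha_t^2 / alpha_{t-1}^2 to avoid near-zero steps."""
    # Prepend alpha^2 = 1 so the first ratio is well defined.
    alphas2 = np.concatenate([np.ones(1), alphas2], axis=0)
    alphas_step = np.clip(alphas2[1:] / alphas2[:-1], a_min=clip_value, a_max=1.0)
    # Rebuild the cumulative schedule from the clipped ratios.
    return np.cumprod(alphas_step, axis=0)


def polynomial_schedule(timesteps: int, s: float = 1e-5, power: float = 2.0) -> np.ndarray:
    """Return alpha_t^2 for t = 0..timesteps (assumed polynomial form)."""
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas2 = (1 - np.power(x / steps, power)) ** 2
    alphas2 = clip_noise_schedule(alphas2, clip_value=0.001)
    # Squash into [s, 1 - s] so alpha^2 never reaches exactly 0 or 1.
    precision = 1 - 2 * s
    return precision * alphas2 + s


alphas2 = polynomial_schedule(1000, s=1e-5, power=2.0)
print(alphas2[0], alphas2[-1])
```

With --diffusion_steps 1000 this yields 1001 values that start just below 1 and decrease monotonically toward s.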

For Drugs

First follow the instructions in data/geom/README.md to set up the data.

python main_geom_drugs.py --n_epochs 3000 --n_stability_samples 500 --diffusion_noise_schedule polynomial_2 --diffusion_steps 1000 --diffusion_noise_precision 1e-5 --diffusion_loss_type l2 --batch_size 32 --nf 256 --n_layers 4 --lr 1e-4 --normalize_factors [1,4,10] --test_epochs 1 --ema_decay 0.9999 --normalization_factor 1 --model egnn_dynamics --visualize_every_batch 10000 --train_diffusion --trainable_ae --latent_nf 2 --exp_name geoldm_drugs
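Both commands pass --normalize_factors [1,4,10]. In EDM, which this repo builds on, the three values divide the coordinates, the one-hot atom types, and the integer charges, respectively. The sketch below assumes that behaviour with zero biases for the feature terms, and treats charges as atomic numbers; normalize_features is a hypothetical name, not a function from the repo.

```python
import numpy as np


def normalize_features(x, one_hot, charges, node_mask,
                       norm_values=(1.0, 4.0, 10.0), norm_biases=(0.0, 0.0)):
    """Scale coordinates, one-hot types and charges by the three normalize factors."""
    # Coordinates are divided by the first factor.
    x_norm = x / norm_values[0]
    # One-hot atom types: subtract bias, divide by the second factor, zero padding nodes.
    h_cat = (one_hot.astype(float) - norm_biases[0]) / norm_values[1] * node_mask
    # Integer charges: subtract bias, divide by the third factor, zero padding nodes.
    h_int = (charges.astype(float) - norm_biases[1]) / norm_values[2] * node_mask
    return x_norm, h_cat, h_int


# Hypothetical example: a C and an H atom plus one padding node.
x = np.array([[0.0, 0.0, 0.0], [1.09, 0.0, 0.0], [0.0, 0.0, 0.0]])
one_hot = np.array([[0, 1], [1, 0], [0, 0]])  # columns: H, C (hypothetical order)
charges = np.array([[6], [1], [0]])
node_mask = np.array([[1.0], [1.0], [0.0]])
x_n, h_cat, h_int = normalize_features(x, one_hot, charges, node_mask)
print(h_cat, h_int, sep="\n")
```

Under this reading, a larger second factor shrinks the one-hot features relative to the coordinates, which changes how much the diffusion noise perturbs atom types versus positions.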

Note: In the paper, we present an encoder early-stopping strategy for training the autoencoder. However, in later experiments, we found that we can even keep the encoder untrained and train only the decoder, which is faster and leads to similar results. The released version uses this strategy. This phenomenon is quite interesting, and we are still actively investigating it.
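As a toy illustration of why training only the decoder can suffice, suppose the untrained encoder were a fixed, random, invertible linear map on 3D coordinates: a decoder fitted with a least-squares reconstruction loss then recovers the inputs essentially exactly. This is a linear stand-in, not GeoLDM's model (the real encoder and decoder are equivariant networks), and it says nothing about what those networks actually learn.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical "atoms": 500 random 3D coordinates.
X = rng.normal(size=(500, 3))

# Untrained encoder: a fixed random linear map that is never updated.
W_enc = rng.normal(size=(3, 3))
Z = X @ W_enc

# Train only the decoder: minimize ||Z @ W_dec - X||^2 in closed form.
W_dec, *_ = np.linalg.lstsq(Z, X, rcond=None)
X_rec = Z @ W_dec

recon_error = float(np.max(np.abs(X_rec - X)))
print(f"max reconstruction error: {recon_error:.2e}")
```

Because the random map is almost surely invertible, the optimal decoder is its inverse, so the encoder never needs gradient updates in this toy setting.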

Pretrained models

We also provide pretrained models for both QM9 and Drugs; you can download them from here. The pretrained models are trained with the same hyperparameters as the commands above, except that the latent dimension --latent_nf is set to 2 (the results should be roughly the same with 1). To use them for the evaluations below, put them in the outputs folder and set --model_path to the path of the pretrained model, outputs/$exp_name.
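The evaluation scripts read the training arguments from args.pickle inside --model_path (the traceback in the issues section shows eval_analyze.py opening that file), so a quick check before launching an evaluation can save a failed run. load_run_args is a hypothetical helper, not part of the repo.

```python
import os
import pickle


def load_run_args(model_path: str):
    """Return the pickled training arguments stored in model_path/args.pickle."""
    args_file = os.path.join(model_path, "args.pickle")
    if not os.path.isfile(args_file):
        # Fail early with a clearer message than the evaluation script's traceback.
        raise FileNotFoundError(
            f"{args_file} is missing; check that --model_path points at the "
            "folder that directly contains args.pickle"
        )
    with open(args_file, "rb") as f:
        # Only unpickle files you trust; unpickling can execute arbitrary code.
        return pickle.load(f)
```

For example, load_run_args("outputs/geoldm_qm9") should return the saved argument namespace, whose latent_nf lets you confirm which latent dimension a downloaded model was trained with.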

Evaluate the GeoLDM

To analyze the sample quality of molecules:

python eval_analyze.py --model_path outputs/$exp_name --n_samples 10_000

To visualize some molecules:

python eval_sample.py --model_path outputs/$exp_name --n_samples 10_000

Small note: The GPUs used for these experiments were pretty large. If you run out of GPU memory, try running at a smaller size, e.g. with a smaller batch size.

Conditional Generation

Train the Conditional GeoLDM

python main_qm9.py --exp_name exp_cond_alpha --model egnn_dynamics --lr 1e-4 --nf 192 --n_layers 9 --save_model True --diffusion_steps 1000 --sin_embedding False --n_epochs 3000 --n_stability_samples 500 --diffusion_noise_schedule polynomial_2 --diffusion_noise_precision 1e-5 --dequantization deterministic --include_charges False --diffusion_loss_type l2 --batch_size 64 --normalize_factors [1,8,1] --conditioning alpha --dataset qm9_second_half --train_diffusion --trainable_ae --latent_nf 1

The argument --conditioning alpha can be set to any of the following properties: alpha, gap, homo, lumo, mu, Cv. The same choice applies to the commands below, which also use alpha.

Generate samples for different property values

python eval_conditional_qm9.py --generators_path outputs/exp_cond_alpha --property alpha --n_sweeps 10 --task qualitative

Evaluate the Conditional GeoLDM with property classifiers

Train a property classifier

cd qm9/property_prediction
python main_qm9_prop.py --num_workers 2 --lr 5e-4 --property alpha --exp_name exp_class_alpha --model_name egnn

Additionally, you can replace --model_name egnn with --model_name numnodes to train a baseline classifier that predicts the property based only on the number of nodes.
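To make the idea of the numnodes baseline concrete, here is a lookup-table stand-in: predict, for each molecule size, the mean training property of molecules with that many atoms. This is not the repo's classifier, and the data below are made up; it only shows how much a property can be "predicted" from atom count alone.

```python
from collections import defaultdict
from statistics import mean


def fit_numnodes_baseline(num_atoms, values):
    """Return a predictor that maps atom count to the mean training property."""
    groups = defaultdict(list)
    for n, v in zip(num_atoms, values):
        groups[n].append(v)
    table = {n: mean(vs) for n, vs in groups.items()}
    # Sizes never seen in training fall back to the global mean.
    fallback = mean(values)
    return lambda n: table.get(n, fallback)


# Hypothetical training data: (number of atoms, property value).
train_n = [9, 9, 12, 12, 15]
train_y = [50.0, 54.0, 70.0, 74.0, 90.0]
predict = fit_numnodes_baseline(train_n, train_y)

preds = [predict(n) for n in [9, 12, 20]]
print(preds)
```

A conditional generator is only informative about a property if its samples are scored better by the EGNN classifier than this kind of size-only baseline would achieve.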

Evaluate the generated samples

Evaluate the trained property classifier on the samples generated by the trained conditional GeoLDM model:

python eval_conditional_qm9.py --generators_path outputs/exp_cond_alpha --classifiers_path qm9/property_prediction/outputs/exp_class_alpha --property alpha --iterations 100 --batch_size 100 --task edm
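The score in this kind of evaluation is, as in EDM, the mean absolute error between the property value each molecule was conditioned on and the value the property classifier predicts for the generated molecule. The sketch assumes that metric; conditional_mae and the numbers are hypothetical.

```python
import numpy as np


def conditional_mae(target_values, predicted_values):
    """Mean absolute error between conditioning targets and classifier predictions."""
    target = np.asarray(target_values, dtype=float)
    pred = np.asarray(predicted_values, dtype=float)
    if target.shape != pred.shape:
        raise ValueError("targets and predictions must have the same shape")
    return float(np.mean(np.abs(target - pred)))


# Hypothetical batch: alpha values the generator was conditioned on,
# and the classifier's predictions for the corresponding generated molecules.
conditioned_alpha = [70.0, 75.5, 80.0, 85.0]
classifier_alpha = [72.0, 74.5, 83.0, 85.0]
print(conditional_mae(conditioned_alpha, classifier_alpha))
```

Lower is better: a small MAE means the generated molecules actually have (according to the classifier) the property values they were asked for.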

Citation

Please consider citing our paper if you find it helpful. Thank you!

@inproceedings{xu2023geometric,
  title={Geometric Latent Diffusion Models for 3D Molecule Generation},
  author={Minkai Xu and Alexander Powers and Ron Dror and Stefano Ermon and Jure Leskovec},
  booktitle={International Conference on Machine Learning},
  year={2023},
  organization={PMLR}
}

Acknowledgements

This repo is built upon the previous work EDM. Thanks to the authors for their great work!


geoldm's Issues

Other output format

Hey Minkai,

thanks for sharing your promising work. Is there a way to convert the output to another format, such as PDBQT/PDB, an RDKit Chem.Mol, or a SMILES/SELFIES string? (I can only find the .txt files, which contain only the positions of the atoms.)

Thanks,
Lennart Jaretzki
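One low-tech route from raw positions to standard tooling is the plain-text XYZ format, which RDKit, Open Babel and ASE can read. The sketch assumes you already have per-atom element symbols and coordinates in ångströms; it does not parse the repo's .txt files, whose exact layout is not documented here, and to_xyz_block is a hypothetical helper.

```python
def to_xyz_block(symbols, coords, comment=""):
    """Format atoms as XYZ: atom count, a comment line, then 'El x y z' per atom."""
    if len(symbols) != len(coords):
        raise ValueError("need one coordinate triple per atom")
    lines = [str(len(symbols)), comment]
    for sym, (x, y, z) in zip(symbols, coords):
        lines.append(f"{sym} {x:.6f} {y:.6f} {z:.6f}")
    return "\n".join(lines) + "\n"


# Hypothetical example: a water molecule.
block = to_xyz_block(
    ["O", "H", "H"],
    [(0.0, 0.0, 0.0), (0.9572, 0.0, 0.0), (-0.2400, 0.9266, 0.0)],
    comment="sample 0",
)
print(block)
```

In recent RDKit versions (2022.09 or later), Chem.MolFromXYZBlock(block) followed by rdDetermineBonds.DetermineBonds(mol, charge=0) should infer bonds, after which Chem.MolToSmiles gives a SMILES string; bond perception from coordinates can fail for distorted samples.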

Question about datasets_config.py script

Hello Minkai! My team and I have a question about the datasets_config.py script: what do the parameters 'distances' and 'radius_dic' mean? Thank you so much; we would really appreciate your response!

z_h data formatting issue in EnLatentDiffusion model

Hi Minkai!

Thank you for sharing the code! Just one quick question regarding the format of h data throughout the training and sampling process.

At first, h is defined as {'categorical': one_hot, 'integer': charges}, and the data is concatenated with the categorical part in front of the integer part. However, at line 1310 of the EnLatentDiffusion model, z_h is formatted as z_h = {'categorical': torch.zeros(0).to(z_h), 'integer': z_h}, meaning that the charges part is placed before the categorical part.

Take sampling, for instance: here z0[:, :, -1:] is used as the charges, meaning z0 has a different format from z_h in the diffusion model.

Should z_h = {'categorical': torch.zeros(0).to(z_h), 'integer': z_h} be changed to z_h = {'categorical': z_h, 'integer': torch.zeros(0).to(z_h)} instead?

Thanks!
Tianyi
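The layout described in this question can be checked with a small NumPy sketch: concatenating the one-hot types first and the charges last puts the charge column at index -1, which is exactly what slicing [:, :, -1:] picks up. Shapes and values are hypothetical, and this does not reproduce the model code.

```python
import numpy as np

# Hypothetical batch: 1 molecule, 3 atoms, 5 atom types, 1 charge column.
one_hot = np.eye(5)[[1, 0, 0]][None]          # shape (1, 3, 5)
charges = np.array([[[6.0], [1.0], [1.0]]])   # shape (1, 3, 1)

# Categorical part first, integer part last (the order described above).
h = np.concatenate([one_hot, charges], axis=-1)  # shape (1, 3, 6)

# Slicing as in sampling: the last column is the charge, the rest is one-hot.
h_charges = h[:, :, -1:]
h_one_hot = h[:, :, :-1]
print(h.shape, h_charges.ravel())
```

Whether z_h follows the same layout then depends only on which slot of the dictionary the model concatenates first.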

Drug data split

Hi,
I tried to run main_geom_drugs.py, but it raises an error.

I also tried to solve it. The problem may be in build_geom_dataset.py: after line 101, data_list = [data_list[i] for i in perm], data_list contains subarrays of varying shapes, but np.split at line 107 needs arrays of the same shape.
So how can I solve this problem?

Best regards,
Zhongyu
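The failure described above is consistent with recent NumPy versions (1.24 and later) refusing to build an array from a list of differently shaped arrays, which np.split does internally when handed a Python list. A workaround, sketched under the assumption that the goal is a shuffled train/val/test split, is to split the index permutation instead and gather the ragged items by index. split_ragged and the split fractions are hypothetical, not the values used by build_geom_dataset.py.

```python
import numpy as np


def split_ragged(items, fractions=(0.8, 0.1, 0.1), seed=0):
    """Shuffle and split a list of arbitrarily shaped arrays without stacking them."""
    if not np.isclose(sum(fractions), 1.0):
        raise ValueError("fractions must sum to 1")
    perm = np.random.default_rng(seed).permutation(len(items))
    n_train = int(fractions[0] * len(items))
    n_val = int(fractions[1] * len(items))
    # np.split is safe here: perm is a plain 1-D integer array.
    train_idx, val_idx, test_idx = np.split(perm, [n_train, n_train + n_val])
    # Gather by index; the ragged list itself is never converted to an array.
    return ([items[i] for i in train_idx],
            [items[i] for i in val_idx],
            [items[i] for i in test_idx])


# Hypothetical ragged data: one (num_atoms, 4) array per conformer.
rng = np.random.default_rng(1)
data_list = [rng.normal(size=(int(n), 4)) for n in rng.integers(5, 40, size=20)]
train, val, test = split_ragged(data_list)
print(len(train), len(val), len(test))
```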

Question about autoencoder training stage

Hello, regarding the autoencoder phase, I have a question. Is the latent invariant feature dimension k mentioned in the paper referring to the dimension of the encoder's output μh, while the dimension of μx remains 3?

How to use for non-QM9 and non-Drugs datasets?

Hi MinkaiXu,

The paper looks very exciting. I have a small input dataset with SMILES and an associated experimental property. Is there a simplified documentation on how I can try your algorithm to read such input and try training-testing along with my choice of property prediction performance, prediction evaluation and structure generation?

As a non-expert, I am not sure how to go about trying your algorithm on my own input dataset. Thanks so much. For my input, I cannot use your QM9 and Drugs training datasets.

Any guidance/help/code is appreciated. Thanks,
JL

Question about loading pretrained models

Thanks a lot for your code, @MinkaiXu :), it's cool.
I ran into a problem loading the pretrained models. The error message is as follows:
Traceback (most recent call last):
  File "eval_analyze.py", line 198, in <module>
    main()
  File "eval_analyze.py", line 127, in main
    with open(join(eval_args.model_path, 'args.pickle'), 'rb') as f:
FileNotFoundError: [Errno 2] No such file or directory: 'outputs/qm9_p/args.pickle'
I'm not sure how to solve it. I hope you can help me with it!

Joint Training

Hi Minkai,

Thanks for the amazing work! I am wondering whether, in the code, the AE and the diffusion model are trained jointly by default (with --train_diffusion and --trainable_ae) rather than separately.

Training time for QM9 dataset

How long does it take to train the model on QM9 for 3000 epochs with a batch size of 64? On my machine, even one epoch seems to take 5 hours with a batch size of 64. Are the hyperparameters I am using correct?
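A back-of-the-envelope estimate puts this question in perspective. The sketch assumes a QM9 training split of about 100,000 molecules (the split used by EDM); n_train is that assumption, and 5 hours per epoch is the figure reported above.

```python
import math


def training_estimate(n_train, batch_size, epochs, hours_per_epoch):
    """Return (optimizer steps per epoch, total steps, total wall-clock hours)."""
    steps_per_epoch = math.ceil(n_train / batch_size)
    return steps_per_epoch, steps_per_epoch * epochs, hours_per_epoch * epochs


steps, total_steps, hours = training_estimate(
    n_train=100_000, batch_size=64, epochs=3000, hours_per_epoch=5
)
print(steps, total_steps, hours, hours / 24)
```

At 5 hours per epoch, 3000 epochs would take 15,000 hours (625 days); with about 1,563 steps per epoch, that corresponds to roughly 11.5 seconds per training step, which is the number worth profiling.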

Autoencoder is identity function on atom coordinates? Equivalence to EDM

Hi Minkai,

Thank you for sharing this work! When I analyzed the sampling results of GeoLDM, I found that the latent variable z_x is almost equal to the decoded atom positions. Below are molecules I reconstructed from the decoded atom positions and decoded atom types (left) and from z_x and the decoded atom types (right). They are almost the same.

[Figure panels: "z_x + recon atom type" and "recon atom pos + recon atom type"]

A further analysis of the autoencoder's reconstruction results in GeoLDM indicates that both the encoder and the decoder are almost identity functions on atom coordinates. If so, can I consider GeoLDM to be essentially equivalent to 3D-space diffusion (i.e., EDM), since the number of latent variables equals the number of atoms and both the encoder and decoder are identity functions on atom coordinates, apart from the autoencoder on atom types?

If this is correct, I'm also wondering how you trained the autoencoder in the published version. I can understand that training with only the reconstruction loss would lead to identity functions, but you mentioned in the repo that the encoder remains untrained. If so, why is the encoder an identity function rather than a random mapping?

Thanks!
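One way to quantify "almost equal" in the comparison above is an RMSD between z_x and the decoded positions after removing each set's center of mass. The sketch deliberately applies no rotational alignment, since the question is whether the map is the identity rather than merely an isometry. The arrays are made up; in practice you would substitute the tensors from the sampling loop.

```python
import numpy as np


def centered_rmsd(a, b):
    """RMSD between two (N, 3) coordinate sets after removing each one's centroid."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    a = a - a.mean(axis=0)
    b = b - b.mean(axis=0)
    return float(np.sqrt(np.mean(np.sum((a - b) ** 2, axis=1))))


# Hypothetical latent coordinates and decoded positions for a 4-atom sample.
z_x = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
x_dec = z_x + 0.01  # a pure translation, which centering removes
print(centered_rmsd(z_x, x_dec))
```

An RMSD close to zero across many samples would support the observation; a value comparable to typical bond lengths would not.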
