

Latent 3D Graph Diffusion

PyTorch implementation for Latent 3D Graph Diffusion

Yuning You, Ruida Zhou, Jiwoong Park, Haotian Xu, Chao Tian, Zhangyang Wang, Yang Shen

In ICLR 2024.

License: GPL-3.0 (If you are interested in a different license, for example, for commercial use, please contact us.)

Overview

A pipeline that compresses 3D graphs into a latent space, where a (vectorial) diffusion model is trained to capture the distribution of the latent embeddings.
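As a rough illustration of what a vectorial diffusion model operates on, the sketch below applies the standard DDPM forward (noising) process to a single latent vector. It is not the repository's code: the linear beta schedule, the 1000 steps, and the 4-dimensional stand-in latent are assumptions chosen for illustration, and only the Python standard library is used.

```python
import math
import random


def linear_beta_schedule(num_steps, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced noise variances beta_1..beta_T (an assumed schedule)."""
    if num_steps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def cumulative_alphas(betas):
    """alpha_bar_t = product over s <= t of (1 - beta_s)."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def noise_latent(z0, t, alpha_bars, rng):
    """Sample z_t ~ N(sqrt(alpha_bar_t) * z0, (1 - alpha_bar_t) * I)."""
    scale = math.sqrt(alpha_bars[t])
    sigma = math.sqrt(1.0 - alpha_bars[t])
    eps = [rng.gauss(0.0, 1.0) for _ in z0]
    zt = [scale * z + sigma * e for z, e in zip(z0, eps)]
    return zt, eps


rng = random.Random(0)
alpha_bars = cumulative_alphas(linear_beta_schedule(1000))
z0 = [0.5, -1.2, 0.3, 2.0]  # hypothetical stand-in for an autoencoder latent
zt, eps = noise_latent(z0, t=999, alpha_bars=alpha_bars, rng=rng)
```

At the last step alpha_bar is close to zero, so z_t is nearly pure Gaussian noise. In DDPM-style training, a denoising network learns to predict eps from (z_t, t).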

1. Unconditional Generation

Training Topological AE

Set up the environment following https://github.com/wengong-jin/hgraph2graph#installation, or use my conda environment file https://github.com/Shen-Lab/LDM-3DG/blob/main/environment_topo_ae.yml. To use my trained checkpoints, install rdkit=2019.03.4 specifically (related issue: wengong-jin/hgraph2graph#44); rdkit-pypi-2020.9.5.2 might also work.

Download data and trained model weights from https://zenodo.org/records/11005382.

cd ./AE_Topology

# get vocabulary for molecular graphs
python get_vocab.py --ncpu 40 < ../AE_topo_weights_and_data/smiles_chembl_mol3d_qm9_drugs.txt > ../AE_topo_weights_and_data/vocab.txt

# preprocess data for more efficient loading
python preprocess.py --train ../AE_topo_weights_and_data/smiles_mol3d_chembl_train.txt --vocab ../AE_topo_weights_and_data/vocab.txt --ncpu 40 --mode single --out_path ../AE_topo_weights_and_data/processed_data_train/
python preprocess.py --train ../AE_topo_weights_and_data/smiles_chembl_mol3d_qm9_drugs.txt --vocab ../AE_topo_weights_and_data/vocab.txt --ncpu 40 --mode single --out_path ../AE_topo_weights_and_data/processed_data/

# train ae
python train_generator_ptl.py  --ddp_num_nodes 1 --ddp_device 1 --train ../AE_topo_weights_and_data/processed_data_train --vocab ../AE_topo_weights_and_data/vocab.txt --save_dir ../AE_topo_weights_and_data/pretrained
# alternatively, train ae with gssl
python train_generator_gssl_ptl.py  --ddp_num_nodes 1 --ddp_device 1 --train ../AE_topo_weights_and_data/processed_data_train --vocab ../AE_topo_weights_and_data/vocab.txt --save_dir ../AE_topo_weights_and_data/pretrained_gssl

# generate smiles to emb dictionary
python generate_embedding.py --train ../AE_topo_weights_and_data/processed_data --vocab ../AE_topo_weights_and_data/vocab.txt --ckpt ../AE_topo_weights_and_data/pretrained/last.ckpt --save_fn ../AE_topo_weights_and_data/smiles2emb_dict.pt
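Because the steps above pass the same data directory and vocabulary file to several scripts, it can help to assemble the command lines from one place. The sketch below only builds and prints the two preprocess.py invocations. It uses the flags shown above, but the helper name preprocess_cmd is hypothetical, and passing the vocabulary by its full path under the data directory (where get_vocab.py writes it above) is an assumption.

```python
import shlex
from pathlib import PurePosixPath

# Data directory used throughout the commands above.
DATA_ROOT = PurePosixPath("../AE_topo_weights_and_data")


def preprocess_cmd(smiles_file, out_dir, vocab_file="vocab.txt", ncpu=40):
    """Build the argv list for one preprocess.py run (hypothetical helper)."""
    return [
        "python", "preprocess.py",
        "--train", str(DATA_ROOT / smiles_file),
        "--vocab", str(DATA_ROOT / vocab_file),
        "--ncpu", str(ncpu),
        "--mode", "single",
        "--out_path", str(DATA_ROOT / out_dir),
    ]


commands = [
    preprocess_cmd("smiles_mol3d_chembl_train.txt", "processed_data_train"),
    preprocess_cmd("smiles_chembl_mol3d_qm9_drugs.txt", "processed_data"),
]

# Print shell-ready command lines for inspection before running anything.
for cmd in commands:
    print(shlex.join(cmd))
```

To execute a command rather than print it, each list can be passed to subprocess.run(cmd, check=True) from inside ./AE_Topology.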

Training Geometric AE

A generic environment should work for the following programs; my conda environment file https://github.com/Shen-Lab/LDM-3DG/blob/main/environment.yml is provided for reference.

Download trained model weights and generated samples from https://zenodo.org/records/11005378.

cd ./AE_Geometry_and_Unconditional_Latent_Diffusion
python main_2dto3d_encoder_decoder.py --ddp_num_nodes 1 --ddp_device 1 --log_dir ../AE_geom_uncond_weights_and_data/job16_decoder_2d_to_3d
# alternatively, train ae with gssl
python main_2dto3d_encoder_decoder_gssl.py --ddp_num_nodes 1 --ddp_device 1 --log_dir ../AE_geom_uncond_weights_and_data/job19_decoder_2d_to_3d_gssl

Generating embedding for qm9 and drugs

Download code and generated qm9 latent embeddings from https://zenodo.org/records/11005421.

cd ./e3_diffusion_for_molecules
# qm9
python process_qm9.py

# Drug: following steps 1-3 in https://github.com/ehoogeboom/e3_diffusion_for_molecules/tree/main/data/geom#how-to-build-geom-drugs to download data first
python build_geom_dataset.py

Training Diffusion Model

cd ./AE_Geometry_and_Unconditional_Latent_Diffusion
# qm9
python main_latent_ddpm_qm9.py --ddp_num_nodes 1 --ddp_device 1 --data_dir ../e3_diffusion_for_molecules/qm9/latent_diffusion/emb_2d_3d --log_dir ../AE_geom_uncond_weights_and_data/job17_latent_ddpm_qm9
# qm9 conditional generation: alpha, Cv, gap, homo, lumo, mu
python main_latent_ddpm_qm9_conditional.py --condition $condition --ddp_num_nodes 1 --ddp_device 1 --data_dir ../e3_diffusion_for_molecules/qm9/latent_diffusion/emb_2d_3d --log_dir ../AE_geom_uncond_weights_and_data/job21_latent_ddpm_qm9_condition_${condition}
# drug
python main_latent_ddpm_drug.py --ddp_num_nodes 1 --ddp_device 1 --data_dir ../e3_diffusion_for_molecules/data/geom --log_dir ../AE_geom_uncond_weights_and_data/job18_latent_ddpm_drug
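The conditional QM9 command above takes one property at a time through $condition. A minimal sketch of generating one training command per property is shown below. The property list, flags, and directory names are copied from the commands above, while the helper name conditional_train_cmd is hypothetical.

```python
import shlex

# Properties listed in the comment above the conditional training command.
CONDITIONS = ["alpha", "Cv", "gap", "homo", "lumo", "mu"]
DATA_DIR = "../e3_diffusion_for_molecules/qm9/latent_diffusion/emb_2d_3d"
LOG_ROOT = "../AE_geom_uncond_weights_and_data"


def conditional_train_cmd(condition):
    """Build the argv list for one conditional training run (hypothetical helper)."""
    if condition not in CONDITIONS:
        raise ValueError(f"unknown condition: {condition!r}")
    return [
        "python", "main_latent_ddpm_qm9_conditional.py",
        "--condition", condition,
        "--ddp_num_nodes", "1",
        "--ddp_device", "1",
        "--data_dir", DATA_DIR,
        "--log_dir", f"{LOG_ROOT}/job21_latent_ddpm_qm9_condition_{condition}",
    ]


all_cmds = [conditional_train_cmd(c) for c in CONDITIONS]
for cmd in all_cmds:
    print(shlex.join(cmd))
```

Each property gets its own log directory, so the six runs do not overwrite one another.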

Sampling

cd ./AE_Geometry_and_Unconditional_Latent_Diffusion
# step 1: sample latent embeddings
python sample1_latent_ddpm_qm9_latent.py --log_dir $log_dir --sample_number $sample_number
# step 1, for conditional generation
python sample1_latent_ddpm_qm9_z_conditional.py --log_dir $log_dir --condition $condition

# step 2: reconstruct 2d
python sample2_latent_ddpm_qm9_2d.py --log_dir $log_dir
# step 3: reconstruct 3d
python sample3_latent_ddpm_qm9_3d.py --log_dir $log_dir

Evaluating

cd ./AE_Geometry_and_Unconditional_Latent_Diffusion
# jupyter notebooks
evaluate_unconditional.ipynb
evaluate_conditional.ipynb

2. Conditional Generation on Geometric Object

Download data, trained model weights and generated samples from https://zenodo.org/records/11005227 and https://zenodo.org/records/11005419.

You might need the following packages for Vina Docking:

pip install meeko==0.1.dev3 scipy pdb2pqr vina==1.2.2 
python -m pip install git+https://github.com/Valdes-Tresanco-MS/AutoDockTools_py3
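After installing, a quick way to see which docking dependencies are still missing is to ask Python whether each module can be located. The sketch below uses importlib.util.find_spec from the standard library. The import names on the right-hand side of the mapping are assumptions, since a distribution name does not always match the module it installs.

```python
import importlib.util

# Distribution name -> assumed import name (the import names are assumptions).
DOCKING_MODULES = {
    "meeko": "meeko",
    "scipy": "scipy",
    "pdb2pqr": "pdb2pqr",
    "vina": "vina",
    "AutoDockTools_py3": "AutoDockTools",
}


def missing_modules(modules):
    """Return the distribution names whose module cannot be located."""
    return [
        dist
        for dist, module_name in modules.items()
        if importlib.util.find_spec(module_name) is None
    ]


print("missing:", missing_modules(DOCKING_MODULES))
```

An empty list means every module was found. It does not check the pinned versions (meeko==0.1.dev3, vina==1.2.2).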

Training Topological AE

cd ./AE_Topology

# get vocabulary for molecular graphs
python get_vocab.py --ncpu 40 < ../AE_topo_weights_and_data/smiles_plus.txt > ../AE_topo_weights_and_data/vocab_pocket_aware.txt

# preprocess data for more efficient loading
python preprocess.py --train ../AE_topo_weights_and_data/smiles_mol3d_chembl_train.txt --vocab ../AE_topo_weights_and_data/vocab_pocket_aware.txt --ncpu 40 --mode single --out_path ../AE_topo_weights_and_data/processed_data_pocket_train/
python preprocess.py --train ../AE_topo_weights_and_data/smiles_plus.txt --vocab ../AE_topo_weights_and_data/vocab_pocket_aware.txt --ncpu 40 --mode single --out_path ../AE_topo_weights_and_data/processed_data_pocket/

# train ae
python train_generator_ptl.py  --ddp_num_nodes 1 --ddp_device 1 --train ../AE_topo_weights_and_data/processed_data_pocket_train --vocab ../AE_topo_weights_and_data/vocab_pocket_aware.txt --save_dir ../AE_topo_weights_and_data/pocket_pretrained
# alternatively, train ae with gssl
python train_generator_gssl_ptl.py  --ddp_num_nodes 1 --ddp_device 1 --train ../AE_topo_weights_and_data/processed_data_pocket_train --vocab ../AE_topo_weights_and_data/vocab_pocket_aware.txt --save_dir ../AE_topo_weights_and_data/pocket_pretrained_gssl

# generate smiles to emb dictionary
python generate_embedding.py --train ../AE_topo_weights_and_data/processed_data_pocket --vocab ../AE_topo_weights_and_data/vocab_pocket_aware.txt --ckpt ../AE_topo_weights_and_data/pocket_pretrained/last.ckpt --save_fn ../AE_topo_weights_and_data/smiles2emb_dict_pocket.pt

Training Geometric AE

Download the data following https://github.com/guanjq/targetdiff#data.

cd ./AE_Geometry_and_Conditional_Latent_Diffusion

# train ae
python -m scripts.train_ae configs/training.yml

# generate 2d and 3d embeddings
python -m scripts.generate_embedding configs/sampling.yml

Training Diffusion Model

cd ./AE_Geometry_and_Conditional_Latent_Diffusion

python -m scripts.train_latent_diffusion configs/training.yml

Sampling and evaluating

cd ./AE_Geometry_and_Conditional_Latent_Diffusion

# sample latent embeddings
python -m scripts.sample_z configs/training.yml

# reconstruct 2d
python -m scripts.sample_2d

# reconstruct 3d and evaluate ($data_id in {0, 1, ..., 99})
python -m scripts.sample_3d configs/sampling.yml --data_id $data_id
python -m scripts.evaluate outputs --docking_mode vina_score --protein_root data/test_set --data_id $data_id
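Since the last two commands are repeated for each $data_id from 0 to 99, a small driver can loop over the targets. The sketch below reuses the exact arguments shown above. The function names and the injectable run callback (used here for a dry run) are hypothetical conveniences, not part of the repository.

```python
import subprocess


def per_target_cmds(data_id):
    """Return the sample_3d and evaluate argv lists for one test target."""
    if not 0 <= data_id <= 99:
        raise ValueError("data_id must be in 0..99")
    sample = [
        "python", "-m", "scripts.sample_3d", "configs/sampling.yml",
        "--data_id", str(data_id),
    ]
    evaluate = [
        "python", "-m", "scripts.evaluate", "outputs",
        "--docking_mode", "vina_score",
        "--protein_root", "data/test_set",
        "--data_id", str(data_id),
    ]
    return [sample, evaluate]


def run_targets(data_ids, run=None):
    """Run both steps per target, in order; `run` can be swapped for a dry run."""
    if run is None:
        # check=True stops the loop at the first failing command.
        run = lambda cmd: subprocess.run(cmd, check=True)
    for data_id in data_ids:
        for cmd in per_target_cmds(data_id):
            run(cmd)


# Dry run: collect the planned commands instead of executing them.
planned = []
run_targets(range(3), run=planned.append)
```

Calling run_targets(range(100)) from inside ./AE_Geometry_and_Conditional_Latent_Diffusion would execute the full sweep sequentially.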

Acknowledgements

This implementation would not have been possible without referencing https://github.com/ehoogeboom/e3_diffusion_for_molecules, https://github.com/wengong-jin/hgraph2graph, https://github.com/directmolecularconfgen/dmcg, https://github.com/JeongJiHeon/ScoreDiffusionModel, https://github.com/chaitjo/geometric-gnn-dojo, and https://github.com/guanjq/targetdiff.

Citation

If you use this code for your research, please cite our paper.

@inproceedings{you2024latent,
  title={Latent 3D Graph Diffusion},
  author={You, Yuning and Zhou, Ruida and Park, Jiwoong and Xu, Haotian and Tian, Chao and Wang, Zhangyang and Shen, Yang},
  booktitle={International Conference on Learning Representations},
  year = {2024}
}

ldm-3dg's Issues

Question about the GPU environment used for training

Hi Authors,

Congrats on the great work!

While reading your team's paper on 3DG, I decided to try reproducing your excellent work. However, I noticed that the paper does not specify the GPU used to train the model. I also observed that the default batch size is 512 and the default number of epochs is 10,000, both larger than those of other models I have encountered, and I couldn't find the exact values used to train the unconditional QM9 model in the paper. So I am wondering: which GPU did you use to train the unconditional QM9 model, and are the batch size and number of epochs really 512 and 10,000?

Thanks!

Error in evaluation step for Conditional Generation on Geometric Object

Hi Authors,

Congrats on the great work! I am trying to develop a flow-matching version of your work, but I encountered some challenges when reproducing your results on Conditional Generation on Geometric Object.

When I ran the sampling and evaluating commands, I encountered the following error for every data_id I tried.

[2024-05-08 23:53:49,959::evaluate::INFO] Vina Score:  Mean: nan Median: nan
[2024-05-08 23:53:49,959::evaluate::INFO] Vina Min  :  Mean: nan Median: nan
Traceback (most recent call last):
  File "/global/cfs/cdirs/mp54/jsliang/MLFF/conda_envs/eval/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/global/cfs/cdirs/mp54/jsliang/MLFF/conda_envs/eval/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/global/cfs/cdirs/mp54/jsliang/LFM-3DG/AE_Geometry_and_Conditional_Latent_Diffusion/scripts/evaluate.py", line 154, in <module>
    print_ring_ratio([r['chem_results']['ring_size'] for r in results], logger)
  File "/global/cfs/cdirs/mp54/jsliang/LFM-3DG/AE_Geometry_and_Conditional_Latent_Diffusion/scripts/evaluate.py", line 34, in print_ring_ratio
    logger.info(f'ring size: {ring_size} ratio: {n_mol / len(all_ring_sizes):.3f}')
ZeroDivisionError: division by zero

What I changed:

  1. The model loaded in sample_z, because I do not have the folder ldm_2023_11_16__18_01_30:
    model.load_state_dict(torch.load('logs_diffusion/ldm_2023_11_16__18_01_30/checkpoints/30000.pt')['model'])
    to
    model.load_state_dict(torch.load('../AE_geom_cond_weights_and_data/weight_diffusion.pt')['model'])
  2. The version name, from 'final-001' to 'final' in pl_pair_dataset.py, to match the file name in https://github.com/guanjq/targetdiff#data:
    def __init__(self, raw_path, transform=None, version='final'):

Can you help check what's wrong?
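The traceback shows that all_ring_sizes was empty when print_ring_ratio ran, which is consistent with the nan Vina scores logged just before it: no molecule appears to have reached the evaluation statistics. The sketch below shows one way to guard the division. Only the function name, its two arguments, and the log format come from the traceback. The ring-size range 3 to 9 and the per-molecule Counter layout are assumptions.

```python
import logging
from collections import Counter


def ring_ratios(all_ring_sizes, ring_sizes=range(3, 10)):
    """Fraction of molecules containing each ring size; {} if nothing was evaluated."""
    if not all_ring_sizes:
        return {}
    total = len(all_ring_sizes)
    return {
        size: sum(1 for counts in all_ring_sizes if size in counts) / total
        for size in ring_sizes
    }


def print_ring_ratio(all_ring_sizes, logger):
    """Log ring-size ratios, skipping the report when the input is empty."""
    ratios = ring_ratios(all_ring_sizes)
    if not ratios:
        logger.warning("no evaluated molecules; skipping ring statistics")
        return
    for size, ratio in ratios.items():
        logger.info(f"ring size: {size} ratio: {ratio:.3f}")


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("evaluate")
print_ring_ratio([], logger)  # warns instead of raising ZeroDivisionError
print_ring_ratio([Counter({5: 1}), Counter({6: 2})], logger)
```

The guard only removes the crash. If the result list is empty, the underlying cause (for example, sampling or docking producing no valid molecules) still needs to be tracked down.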

Question about reproducing results

Hi Yuning,

I am very interested in your great work, and I'm attempting to reproduce your results using the Jupyter notebook. However, I've encountered some challenges with the 2D distribution results while using the provided sampled SMILES data.

Specifically, I directly evaluated AE_geom_uncond_weights_and_data/job17_latent_ddpm_qm9_spatial_graphs/sample_smiles.pt and sample_conformer.pt. For the test SMILES, I loaded e3_diffusion_for_molecules/data/smiles_qm9.txt. All the files were sourced from your comprehensive Zenodo repository.

I've attached a figure showing the results I obtained. Could you kindly advise if I'm using the correct data or provide any guidance on the appropriate files to use for reproducing the 2D distribution results?

