
graphium's Introduction

Scaling molecular GNNs to infinity



A deep learning library focused on graph representation learning for real-world chemical tasks.

  • ✅ State-of-the-art GNN architectures.
  • 🐍 Extensible API: build your own GNN model and train it with ease.
  • ⚗️ Rich featurization: powerful and flexible built-in molecular featurization.
  • 🧠 Pretrained models: for fast and easy inference or transfer learning.
  • ⮔ Ready-to-use training loop based on PyTorch Lightning.
  • 🔌 Have a new dataset? Graphium provides a simple plug-and-play interface. Change the path, the names of the columns to predict, and the atomic featurization, and you're ready to play!

Documentation

Visit https://graphium-docs.datamol.io/.

Installation for developers

For CPU and GPU developers

Use mamba, a faster and better alternative to conda.

If you are using a GPU, we recommend enforcing the CUDA version that you need with CONDA_OVERRIDE_CUDA=XX.X.

# Install Graphium's dependencies in a new environment named `graphium`
mamba env create -f env.yml -n graphium

# To force the CUDA version to 11.2, or any other version you prefer, use the following command:
# CONDA_OVERRIDE_CUDA=11.2 mamba env create -f env.yml -n graphium

# Install Graphium in dev mode
mamba activate graphium
pip install --no-deps -e .

For IPU developers

# Install Graphcore's SDK and Graphium dependencies in a new environment called `.graphium_ipu`
./install_ipu.sh .graphium_ipu

The above step needs to be done once. After that, enable the SDK and the environment as follows:

source enable_ipu.sh .graphium_ipu

Training a model

To learn how to train a model, we invite you to look at the documentation, or the Jupyter notebooks available here.

If you are not familiar with PyTorch or PyTorch Lightning, we highly recommend going through their tutorials first.

Running an experiment

We have set up Graphium with Hydra for managing config files. To run an experiment, go to the expts/ folder. For example, to benchmark a GCN on the ToyMix dataset, run

graphium-train architecture=toymix tasks=toymix training=toymix model=gcn

To change parameters specific to this experiment like switching from fp16 to fp32 precision, you can either override them directly in the CLI via

graphium-train architecture=toymix tasks=toymix training=toymix model=gcn trainer.trainer.precision=32

or change them permanently in the dedicated experiment config under expts/hydra-configs/toymix_gcn.yaml. Integrating Hydra also allows you to quickly switch between accelerators. For example, running

graphium-train architecture=toymix tasks=toymix training=toymix model=gcn accelerator=gpu

automatically selects the correct configs to run the experiment on GPU. Finally, you can also run a fine-tuning loop:

graphium-train +finetuning=admet

To use a config file you built from scratch you can run

graphium-train --config-path [PATH] --config-name [CONFIG]

Thanks to the modular nature of hydra you can reuse many of our config settings for your own experiments with Graphium.

Preparing the data in advance

Data preparation, including featurization (e.g., converting molecules from SMILES to a PyG-compatible format), is embedded in the pipeline and is performed when executing graphium-train [...].

However, when working with larger datasets, it is recommended to prepare the data in advance on a machine with sufficient memory (e.g., ~400GB in the case of LargeMix). Preparing the data in advance is also beneficial when running many concurrent jobs with identical molecular featurization, so that resources aren't wasted and processes don't conflict while reading/writing in the same directory.

The following commands prepare the data and cache it, then use it to train a model.

# First prepare the data and cache it in `path_to_cached_data`
graphium data prepare ++datamodule.args.processed_graph_data_path=[path_to_cached_data]

# Then train the model on the prepared data
graphium-train [...] datamodule.args.processed_graph_data_path=[path_to_cached_data]

Note that datamodule.args.processed_graph_data_path can also be specified in the configs under expts/hydra-configs/.

Also note that every time the datamodule.args.featurization config changes, a new data preparation is required; its output is automatically saved in a separate directory named with a hash unique to that config.

License

Under the Apache-2.0 license. See LICENSE.

Documentation

  • Diagram for data processing in Graphium.

Data Processing Chart

  • Diagram for the multi-task network in Graphium

Full Graph Multi-task Network

graphium's People

Contributors

alip67, blazejba, callumm-graphcore, cwognum, dominvivo, eltociear, engmubarak48, graphriela, hadim, hannesstark, hatemhelal, joao-alex-cunha, jstlaurent, kerstink-gc, kiddozhu, luis-mueller, lujiarui, maciej-sypetkowski, mercuryseries, michalkoziarski, odymov, s-maddrellmander, shenyanghuang, therence1, wenkelf, zhiyil1230


graphium's Issues

Create base class for the datamodule

The time when we won't want to use SMILES as input data is coming very soon. Let's create a parent class for DGLFromSmilesDataModule before it's too late.
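
A minimal sketch of what such a parent could look like, assuming PyTorch Lightning as in the rest of the library; the class layout below is hypothetical, not the actual goli API:

from abc import abstractmethod

import pytorch_lightning as pl


class BaseDataModule(pl.LightningDataModule):
    # Input-format-agnostic parent: subclasses decide how raw inputs
    # become featurized graphs.

    @abstractmethod
    def featurize(self, raw_inputs):
        """Convert raw inputs (SMILES, files, ...) into graph objects."""


class DGLFromSmilesDataModule(BaseDataModule):
    def featurize(self, raw_inputs):
        ...  # SMILES-specific featurization would live here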

Interesting repository

This is a repository by a previous student who worked on a first iteration of the library. The student was a good programmer, but did not implement the layers and architectures in DGL, so the repository became outdated as the models were inefficient.

However, he implemented a lot of interesting things, such as a molecular property extractor with Gaussian normalization for predictive tasks, and atom/edge embeddings in this link.

I also implemented in another repo an atom_property_getter_rescaled with the goal of using it as atom embeddings.

Load model checkpoint and run test set

Implement the test-set run in the main_run file by loading the saved checkpoints, generating a report with the test-set metrics (CSV or YAML file), and saving one CSV file with the raw predictions and another with the raw targets.

The next steps for the ModelWrapper

In this thread, I want to discuss the next steps and the known problems for the ModelWrapper class, to make sure that our objectives are aligned and that we address most of the problems.

First step

  • Clean-up molecular stuff using datamol
  • Change the name of the class to PredictorModelWrapper to better highlight that it is used for predictive models
  • Split the data loading from the model wrapper using LightningDataModule
  • Make sure the data is transferred to the right device after loading, when the data is loaded on the CPU but the model runs on the GPU
  • Fix the metrics

Second step

  • Improve how the collate is done. Right now, a function custom_collate is called manually, but I think we can pass a collate_fn to the data loader instead? (See the sketch after this list.)
  • Implement test_step and test_dataloader in the LightningDataModule
  • Make sure that the model wrapper makes it easy to save and reload checkpoints
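
On the collate point above: PyTorch's DataLoader does accept a collate_fn directly, so custom_collate would not need to be called manually. A minimal sketch, with a stand-in dataset:

from torch.utils.data import DataLoader

def custom_collate(batch):
    # `batch` is a list of individual samples; the real logic would
    # stack tensors / batch the graphs here.
    return batch

dataset = list(range(10))  # stand-in for the real molecular dataset
loader = DataLoader(dataset, batch_size=4, collate_fn=custom_collate)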

Third step

  • Check whether the class BestEpochFromSummary is still needed, or whether something similar now exists in PyTorch Lightning
  • Check whether the class TrainingProgressFromSummary is still needed, or whether something similar now exists in PyTorch Lightning
  • Check whether the class HyperparamsMetricsTensorBoardLogger is still needed, or whether something similar now exists in PyTorch Lightning
  • Right now, the number of workers for data loaders is fixed to 0 when using the GPU due to some bugs. Fix it to allow parallel data loading.
  • Implement a better way to ignore NaNs in the target when computing the metrics and the loss, especially in multitask settings (see the sketch after this list)
  • Save the final train/val/test performance and hparams in a YAML file, to avoid needing to open TensorBoard to see them
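
For the NaN-masking point above, a minimal sketch of the idea (the function name is illustrative): compute the loss only over non-NaN targets, which matters in multitask settings where not every molecule has every label.

import torch
import torch.nn.functional as F

def masked_mse_loss(preds, targets):
    # Keep only the entries whose target is present (non-NaN).
    mask = ~torch.isnan(targets)
    return F.mse_loss(preds[mask], targets[mask])

preds = torch.tensor([0.5, 1.0, 2.0])
targets = torch.tensor([0.4, float("nan"), 2.5])
print(masked_mse_loss(preds, targets))  # averaged over the 2 labeled entries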

Delete the expts and config folder

Being picky here but I would delete the expts folder. Move main_micro_ZINC.py to goli/cli and move config_gnns.yaml into a new goli/config folder.

As for accessing data files that are in a Python module, it's not recommended to use things such as MAIN_DIR = os.path.dirname(goli.__path__._path[0]) (considered hacky); instead, you should use importlib_resources: https://importlib-resources.readthedocs.io/en/latest/using.html#example (or pkg_resources).

Originally posted by @hadim in #11 (comment)

Fix errors in the documentation

Some block directives do not work with MkDocs, but used to work with Sphinx:

  • ..note::
  • ..math::
  • ..code-block::
  • ..warning::

Some sections also do not work

The first example below doesn't work, but the second seems to work. The same is true for Math and See Also.

Example:
    This is an example that doesn't work.
Example
-----------
   This is an example that works, but is the older syntax.

@classproperty does not appear in the docs

Methods defined as class properties with the @classproperty decorator do not appear in the rendered docs.

Too much GPU memory?

When running the original DGN and PNA models on the ogb-molpcba dataset, I was able to fit up to 6.6M parameters with batch size 2048 on an AWS G4 instance with 16GB of GPU RAM.

However, in the Goli implementation, I am limited to around 5.6M parameters. I have done some experiments where I exceeded that amount: it runs correctly for a few minutes, but later crashes with an out-of-memory error. For example, with 5.9M parameters, it runs for 3 full epochs but crashes at ~25% completion of the 4th epoch.

Here are possible causes that could be explored:

  • I am not sure what is happening. Is there a memory leak?
  • A problem related to Pytorch-lightning?
  • Is it due to the number of workers that is too high?
  • Or the fact that workers are persistent and don't release memory?

This is an important problem to fix: the ogb-molpcba dataset is parameter-hungry, and I need more params to win the leaderboard. Otherwise, we can implement the virtual batch size #48.

[Meta] Next

@DomInvivo @Therence1 here is what I have in mind for the release:

  • Finish current PRs (dom)
  • Logic to load a pretrained model (hadrien, dom)
  • Create a public S3 bucket and put in it two folders: pretrained-models and datasets. (hadrien)
  • Move micro_ZINC and ZINC_bench_gnn into the S3 bucket + add a CLI tool for easy download: goli download micro_ZINC /my/path. (hadrien)
  • Restructure README: add plenty of fancy icons. (hadrien)
  • Better doc: make the index doc page (what is goli, what you can do with it, etc) (hadrien)
  • Rework tutorials: keep the code as it is but add a more detailed description of what is happening (eventually add relevant equations + images). (hadrien and dom)
  • Find a logo (hadrien, therence, dom)
  • Draft a blog post (dom want to give it a try?)

Saving the featurizer?

When saving the model, the featurizer is not included since it is part of the datamodule. If we want to use the model on another dataset for transfer learning or property prediction, we need the featurizer.

Right now, I manually save the configuration file and reload the key datamodule: args: featurization, but if the goli module is updated, the featurization could still change.

I propose saving both the featurization arguments in a YAML file and the featurization in a pickle file, although we need to discuss if one approach is better than the other.

Implement multi-task manager

We need to modify the structure to allow for training on multiple tasks simultaneously.

Task: Any property we want to predict, or other objective we want to accomplish.

Goal of the implementation

Some properties cannot be combined easily in a single dataset, or cannot use the same loss. We need to decouple them into different tasks, each with its own loss, so that the model can be trained on all tasks concurrently.

How to do it

It is a bit complicated and requires many changes. In the configuration file, remove the part associated with data loading from the datamodule, and remove the loss and metrics from the predictor. Instead, create a dict and an associated class called task_manager. An example of such a config is below.

multi_task_manager:
  task_batching: [[task_qm1, task_qm2], [task_prot]]
  tasks:
    task_qm1:
      weight: 1.0
      df_path: "data/qm_properties.csv"
      smiles_col: SMILES
      label_cols: [homo, lumo, energy]
      loss: mse
      metrics_list: ["mae", "pearsonr", "f1 > 0.5"]
      target_nan_mask: ignore-flatten
      task_nn:
        out_dim: 3
        hidden_dims: *middle_dim
        depth: 0
        activation: relu
        last_activation: none
        dropout: 0.2
        normalization: batch_norm
        last_normalization: "none"
        residual_type: simple
    task_qm2:
      weight: 0.5
      df_path: "data/qm_properties.csv"
      smiles_col: SMILES
      label_cols: [is_stable]
      loss: bce
      metrics_list: ["auc", "f1 > 0.5"]
      target_nan_mask: ignore-flatten
      task_nn: ....
    task_prot:
        weight: 3.0
        df_path: "data/protein_binding.csv"
        smiles_col: smile
        label_cols: assayID-*
        loss: mse
        metrics_list: ["mae", "pearsonr", "f1 < 0.5"]
        target_nan_mask: ignore-flatten
        task_nn: ....

The predictor's handling of the metrics and the loss will have to change to accommodate the multiple tasks.
Some tasks, like task_qm1 and task_qm2, will be concatenated together, while task_prot will be trained on a different set of molecules concurrently.

An example of how the task classes will look:

The Task Manager Classes

MultiTaskManager

Inputs

  • metrics_configs: The configs of all metrics used across the tasks
  • multitask_configs: The configs of all the tasks

Role

  • Read the data efficiently: When SMILES are similar across datasets, only featurize them once
  • Prepare the metrics and loss functions
  • Initialize all the TaskManagers
  • Add all the weighted losses together (sketched after this list)
  • Generate a TensorBoard and YAML report of the metrics for each task
  • Control how the DataModule loads molecules from each task
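
For the weighted-loss point above, a minimal sketch of how the per-task losses could be combined with the weight fields from the config (names are illustrative):

import torch

def combine_task_losses(task_losses, task_weights):
    # Weighted sum of the per-task losses, matching the `weight`
    # field of each task in the config.
    return sum(task_weights[name] * loss for name, loss in task_losses.items())

losses = {"task_qm1": torch.tensor(0.8), "task_qm2": torch.tensor(0.3)}
weights = {"task_qm1": 1.0, "task_qm2": 0.5}
total = combine_task_losses(losses, weights)  # 0.8 * 1.0 + 0.3 * 0.5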

TaskManager

Inputs

  • metrics
  • featurized molecules
  • task_weight

Role

  • Compute the metrics and loss associated with a specific task

Remove batch_size warning

Since PyTorch Lightning 1.5, there's a new warning that the batch size is ambiguous, asking to pass the parameter batch_size into self.log.

However, we do not use self.log explicitly, so we cannot remove the warning.

The warning arises because PL takes the "smiles" key from the batch dictionary, but instead of using the length of that sequence as the batch size, it looks at the length of each SMILES string and finds inconsistencies.
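
If we ever do log explicitly, PyTorch Lightning (>= 1.5) accepts a batch_size argument in self.log, which avoids the inference altogether. A sketch, with a hypothetical module standing in for the real predictor:

import torch
import pytorch_lightning as pl

class Predictor(pl.LightningModule):  # hypothetical stand-in
    def training_step(self, batch, batch_idx):
        loss = torch.tensor(0.0, requires_grad=True)  # placeholder loss
        # Passing batch_size explicitly stops PL from guessing it
        # from the "smiles" entry of the batch dict.
        self.log("train/loss", loss, batch_size=len(batch["smiles"]))
        return loss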

Increase model flexibility

Modify the class goli.dgl.networks.DGLGraphNetwork to have a flexible input MLP and a flexible output MLP. Right now, the model only takes an output MLP using the class FeedForwardNN.
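
A sketch of a generic builder that could back both MLPs; make_mlp is a hypothetical helper, not an existing goli function:

import torch.nn as nn

def make_mlp(in_dim, hidden_dims, out_dim, activation=nn.ReLU):
    # Builds an MLP of arbitrary depth, so both the input and output
    # heads of the network can be configured independently.
    dims = [in_dim, *hidden_dims, out_dim]
    layers = []
    for i in range(len(dims) - 1):
        layers.append(nn.Linear(dims[i], dims[i + 1]))
        if i < len(dims) - 2:
            layers.append(activation())
    return nn.Sequential(*layers)

mlp = make_mlp(in_dim=16, hidden_dims=[64, 64], out_dim=8)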

Meta: goli inception

The below mostly focuses on what is outside the business logic of goli, meaning good coding practices and guidelines to make the future integration into the platform easy and seamless. It also prepares the lib for a possible open-source release.


What I recently did on the default branch:

  • change the default branch from main to master (this is because one tool we use to manage the release process called rever does not support main yet).
  • add CI for testing and linting the code (with black)
  • add a dummy test (so we have a test folder)
  • start a doc with mkdocs (similar to sphinx but simpler to set up, with fewer files): private access to the doc at https://invivoai--goli.github.privpage.net/
  • tweak the env.yml file
  • use datamol dep: datamol is the numpy of molecules, built on top of rdkit; it provides convenient functions such as dm.to_mol. We use it extensively on the platform in order to standardize the way we work with molecules.
  • tweak readme to reflect changes
  • move part of the readme doc to CONTRIBUTE.md
  • add __init__.py files in various Python modules.

Infra tasks (todo by dom or/and hadrien):

  • create a private S3 bucket to manage goli's dataset and experiment results

What should be done + some guidelines (it also includes Dom's suggestions from the original README file):

  • replace functions such as to_mol by datamol.
    • consider contributing to datamol when relevant.
  • create a notebooks folder to highlight an interesting part of the code (take izanagi as an example: https://github.com/invivoai-platform/izanagi/tree/master/tutorials)
  • cleanup existing GNN models and keep only the relevant ones
  • replace MolecularTransformer with something function-based, simpler to use and more flexible.
  • update ModelWrapper to the latest PL version and use their metrics API (https://pytorch-lightning.readthedocs.io/en/latest/metrics.html)
  • consider optuna as an hparams search engine: on this, I know this is very opinionated. I find optuna to be simpler and lighter than hydra. If you like hydra or already are comfortable with it, feel free to ignore optuna.
  • for any file IO in goli, use fsspec so you can transparently read/write on a local, S3 or GS filesystem without the need for a different code for each.
  • use a private S3 bucket to deal with the dataset and results of the experiments that are going to be performed.
    • make a python command using fsspec to download the dataset instead of using a GDrive (or use the S3 cli tool).
    • clearly separate the different files/folders in this repo: dataset, pretrained-model, experiments, etc
  • use omegaconf instead of hydra to read/load config files.
  • Replace all the print statements with logger.info or equivalent using the loguru library and from loguru import logger.
  • use click for all CLI commands.
  • clearly separate experiments from re-usable source code.

Feel free to discuss high-level things in this ticket. To discuss/implement a specific point from above, please open a new ticket.

ping @DomInvivo

Ignore molecules that fail featurization

When a single molecule fails featurization, the whole parallel loop hangs with no error: the parallel process stops, but the Python code remains in a running state.

I would like the parallelization to catch the error, issue a warning, and return the list of indices that failed. Otherwise, we cannot process any dataset where a single molecule raises an error.
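
A sketch of the desired behavior, using dm.to_fp as a stand-in for the real featurizer:

import warnings
import datamol as dm

def safe_featurize(smiles):
    # One bad molecule should warn and return None instead of hanging
    # or killing the whole parallel loop.
    try:
        mol = dm.to_mol(smiles)
        if mol is None:
            raise ValueError("could not parse SMILES")
        return dm.to_fp(mol)  # stand-in for the real featurizer
    except Exception as err:
        warnings.warn(f"Featurization failed for {smiles!r}: {err}")
        return None

smiles_list = ["CCO", "not-a-smiles", "c1ccccc1"]
features = dm.utils.parallelized(safe_featurize, smiles_list, n_jobs=1)
failed_idx = [i for i, f in enumerate(features) if f is None]  # report these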

https://github.com/valence-discovery/goli/blob/b89c8d8dcac6c59fb81a2bac0f75caedd986cf55/goli/data/datamodule.py#L366

When using n_jobs=0, the code throws an error. For example, on BindingDB, I get this:

[16:30:26] non-ring atom 28 marked aromatic
Traceback (most recent call last):
  File "/home/dominique_valencediscovery_com/goli/expts/main_run_predict.py", line 73, in <module>
    main(cfg)
  File "/home/dominique_valencediscovery_com/goli/expts/main_run_predict.py", line 50, in main
    preds = trainer.predict(model=predictor, datamodule=datamodule)
  File "/home/dominique_valencediscovery_com/miniconda3/envs/goli/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 631, in predict
    results = self._run(model)
  File "/home/dominique_valencediscovery_com/miniconda3/envs/goli/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 704, in _run
    self.data_connector.prepare_data(model)
  File "/home/dominique_valencediscovery_com/miniconda3/envs/goli/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 56, in prepare_data
    self.trainer.datamodule.prepare_data()
  File "/home/dominique_valencediscovery_com/miniconda3/envs/goli/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py", line 385, in wrapped_fn
    return fn(*args, **kwargs)
  File "/home/dominique_valencediscovery_com/miniconda3/envs/goli/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py", line 49, in wrapped_fn
    return fn(*args, **kwargs)
  File "/home/dominique_valencediscovery_com/goli/goli/data/datamodule.py", line 366, in prepare_data
    features = dm.utils.parallelized(
  File "/home/dominique_valencediscovery_com/miniconda3/envs/goli/lib/python3.8/site-packages/datamol/utils/jobs.py", line 202, in parallelized
    return runner(fn, inputs_list, arg_type=arg_type)
  File "/home/dominique_valencediscovery_com/miniconda3/envs/goli/lib/python3.8/site-packages/datamol/utils/jobs.py", line 133, in __call__
    return self.sequential(*args, **kwargs)
  File "/home/dominique_valencediscovery_com/miniconda3/envs/goli/lib/python3.8/site-packages/datamol/utils/jobs.py", line 99, in sequential
    res = [
  File "/home/dominique_valencediscovery_com/miniconda3/envs/goli/lib/python3.8/site-packages/datamol/utils/jobs.py", line 100, in <listcomp>
    JobRunner.wrap_fn(callable_fn, arg_type, **fn_kwargs)(dt)
  File "/home/dominique_valencediscovery_com/miniconda3/envs/goli/lib/python3.8/site-packages/datamol/utils/jobs.py", line 67, in _run
    return fn(args, **fn_kwargs)
  File "/home/dominique_valencediscovery_com/goli/goli/features/featurizer.py", line 622, in mol_to_dglgraph
    mol = Chem.RemoveHs(mol)
Boost.Python.ArgumentError: Python argument types in
    rdkit.Chem.rdmolops.RemoveHs(NoneType)
did not match C++ signature:
    RemoveHs(RDKit::ROMol mol, RDKit::MolOps::RemoveHsParameters params, bool sanitize=True)
    RemoveHs(RDKit::ROMol mol, bool implicitOnly=False, bool updateExplicitCount=False, bool sanitize=True)
  1%|โ–‹                                                                                                                     | 14034/2222073 [03:59<10:27:12, 58.67it/s]

I tracked down the error to this beautiful molecule:
CN[C@@H](C)C(=O)N[C@H]1CN(CC[C@H]2CC[C@H](N2C1=O)c1nc2c(cccc2[nH]1)-c(:c:c):c:c)C(=O)CCCCCCCCCCC(=O)N1CC[C@H]2CC[C@H](N2C(=O)[C@H](C1)NC(=O)[C@H](C)NC)c1nc2c(cccc2[nH]1)-c1ccccc1


Replace @classproperty in Python 3.9

When Python 3.9 is released, it will support stacking @property with @classmethod, so we can remove the classproperty decorator and maybe solve the issue of class properties not appearing in the docs.
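
A sketch of what that would look like in Python 3.9 (note: this chaining was later deprecated again in 3.11):

class Model:
    _version = "1.0"

    # Python 3.9 allows @classmethod to wrap @property directly,
    # replacing the custom classproperty decorator.
    @classmethod
    @property
    def version(cls):
        return cls._version

print(Model.version)  # "1.0"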

Deal with NaN featurization

Deal with molecules that fail a part of the featurization, such as [He][He].

Add a parameter nan_mask to the featurization of both nodes and edges, with different possible values (see the sketch after this list):

  • "raise": DEFAULT. Raise an error when there is a nan in the featurization
  • "warn": Raise a warning when there is a nan in the featurization
  • None: Don't do anything
  • Floating value: Replace nans by the specified value
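
A minimal sketch of these options on a NumPy feature array (the function name is illustrative):

import warnings
import numpy as np

def apply_nan_mask(feats, nan_mask="raise"):
    nans = np.isnan(feats)
    if nan_mask is None or not nans.any():
        return feats  # nothing to do
    if nan_mask == "raise":
        raise ValueError("NaN found in featurization")
    if nan_mask == "warn":
        warnings.warn("NaN found in featurization")
        return feats
    return np.where(nans, float(nan_mask), feats)  # fill with the given value

feats = np.array([1.0, np.nan, 3.0])
print(apply_nan_mask(feats, nan_mask=0.0))  # [1. 0. 3.]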

Change batch_norm naming?

For GNNs, multiple types of normalization have been developed, and batch norm is not necessarily the best. We also have graph_norm, layer_norm, and instance_norm.

Should we change the naming convention of NN layers? Instead of batch_norm=BOOL, we would have norm=STR, where the string can be batch_norm, instance_norm, or graph_norm, together with a norm parser?

Making this change after the release would require deprecation warnings and the like, so even if none of the other norms are implemented yet, I think it's a good idea to make the move now. What do you think @hadim?
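
A sketch of such a parser (graph_norm would be added once implemented):

import torch.nn as nn

def parse_norm(norm, dim):
    # Maps the proposed `norm=STR` convention to the corresponding layer.
    if norm in (None, "none"):
        return nn.Identity()
    if norm == "batch_norm":
        return nn.BatchNorm1d(dim)
    if norm == "layer_norm":
        return nn.LayerNorm(dim)
    if norm == "instance_norm":
        return nn.InstanceNorm1d(dim)
    raise ValueError(f"Unknown normalization: {norm}")

norm_layer = parse_norm("batch_norm", dim=64)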

Update to work with PL metrics

The code was initially developed for pytorch-lightning 0.8. With the new releases, they have completely changed how the metrics work, and I have only done a few quick patches. We need to update the code to make it work with the newest metric system.

  • In the file goli.trainer.metrics, update everything (Thresholder, pearsonr, spearmanr, METRICS_DICT, MetricWithThreshold, MetricFunctionToClass)
  • Update the class goli.trainer.model_wrapper.ModelWrapper accordingly
  • Update the function goli.commons.config_loader.config_load_metrics to work with the new implementation
  • Use the new MetricCollection class from Pytorch-Lightning 1.2

Fix the docs search

The search bar doesn't work when visiting the online website, although it works when generating the docs locally.

Implement bond-length without conformer

Implement a bond-length function that uses the covalent radii of the 2 connected atoms, since conformer generation fails for weird molecules present in mol-pcba.
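
A sketch using RDKit's periodic table; note that it ignores bond order, which shortens double and triple bonds:

from rdkit import Chem

def approx_bond_length(mol, bond_idx):
    # Approximate the bond length as the sum of the covalent radii of
    # the two atoms, so no conformer is needed.
    pt = Chem.GetPeriodicTable()
    bond = mol.GetBondWithIdx(bond_idx)
    a1, a2 = bond.GetBeginAtom(), bond.GetEndAtom()
    return pt.GetRcovalent(a1.GetAtomicNum()) + pt.GetRcovalent(a2.GetAtomicNum())

mol = Chem.MolFromSmiles("CCO")
print(approx_bond_length(mol, 0))  # C-C, roughly 1.5 angstroms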

GPU not used automatically

For some reason, PyTorch Lightning doesn't seem to run automatically on the GPU, despite the model being able to run on the GPU.
@hadim Maybe this has to do with DGL?


Pretrained models

  • Pretrain some models with goli and upload them to a public S3.
  • Add a goli function to reload a model from S3. This is probably already possible with PredictorModule.load_from_checkpoint("path/to/s3").
  • Tutorials: inference only with goli
  • Tutorials: fine-tuning with goli

Training crashes when reaches a plateau

When running the configuration config_ZINC_bench_gnn.yaml, a plateau is reached after 32 epochs. However, instead of reducing the learning rate, training starts to use only the first batch of each epoch and ignores the rest, until it reaches the minimum number of epochs required to stop.

Add virtual batch sizes

I'm not sure of the technical term, but add the option of running multiple batches before updating the gradients. This is particularly useful for datasets that perform better with large batches, such as Mol-PCBA.
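
This is usually called gradient accumulation, and PyTorch Lightning supports it out of the box via the Trainer, so no custom loop should be needed:

import pytorch_lightning as pl

# Accumulate gradients over 8 batches before each optimizer step,
# i.e. an effective batch size of 8 x batch_size.
trainer = pl.Trainer(accumulate_grad_batches=8)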

Make BaseDGLLayer inheritance more robust

https://github.com/invivoai/goli/blob/85b78866035fdba30f17d36596d900b573d85d4d/goli/dgl/dgl_layers/base_dgl_layer.py#L15

Replace the method _parse_layer_args with multiple abstract properties, using abc.abstractproperty:

  • get_layer_specific_kwargs_of_lists: kwargs arguments to be used separately at each layer
  • get_true_out_dims: the true output dimension of the layer, which can vary due to the number of heads or number of aggregators
  • get_kwargs_keys_to_remove: kwargs arguments to remove from the list, due to conflicts or to being used by the layer-specific kwargs

Since these methods will be abstract, all new layers will have to implement them, which ensures they are used correctly.

Should we also add some static functions to BaseDGLLayer, such as get_collate_fn and get_featurization, so that they are defined in the layer? Then any newly added layer would simply need to redefine them, without requiring changes to the dataloader and trainer code.

Add residual to FeedForwardNN?

Not sure about this issue, but we could add a boolean specifying whether we want a residual connection in FeedForwardNN, and maybe replace the ResNet architecture?

Improve molecular transformations

Refactor the whole file goli.mol_utils.transformers.py.

  • Replace to_mol by calling the same function from datamol
  • Replace the MoleculeTransformer and AdjGraphTransformer with something much simpler (to discuss)
  • Implement a featurizer similar to dgl-lifesci

Generalize PNA into a DGN model

The directional graph networks (DGN) paper generalizes PNA, with the added possibility of defining directions based on the eigenvectors.

Implement the DGN layer, and make sure that the eigenvectors are pre-computed in the Dataset.
See the implementation here, but clean it to make it more similar to the implementation of PNA in the current repository.

Save a compressed dataset file

https://github.com/valence-discovery/goli/blob/cb3bbce8c92967bd5e47f531a5eb5a0f337e6379/goli/data/datamodule.py#L535

Some large datasets require >30GB of storage, which is far too much, and makes them too slow and large to download from an online file. For example, the processed htsfp-pcba dataset takes about 15GB, but compressed it takes 860MB: a very significant difference in both storage and download time.

Allow saving and loading to use a compressed format such as .cache.gz.
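
A minimal sketch: torch.save and torch.load accept file objects, so wrapping them in gzip gives a .cache.gz-style compressed cache:

import gzip
import torch

def save_compressed(obj, path):
    # Write the serialized object through a gzip stream.
    with gzip.open(path, "wb") as f:
        torch.save(obj, f)

def load_compressed(path):
    # Read it back transparently.
    with gzip.open(path, "rb") as f:
        return torch.load(f)

save_compressed({"features": torch.randn(4, 8)}, "dataset.cache.gz")
data = load_compressed("dataset.cache.gz")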

Improve naming of models

In TensorBoard, the names are "version_0", "version_1", etc., while in the model checkpoints, the names are "model.ckpt", "model-v1.ckpt", "model-v2.ckpt".

Adopt a better naming convention that includes at least the date and time, so that a specific TensorBoard run can be matched to a checkpoint.
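
A sketch of such a convention, sharing one timestamped name between PyTorch Lightning's logger and checkpoint callback:

from datetime import datetime
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# One timestamped name shared by the logger and the checkpoints, so a
# TensorBoard run can be matched to its .ckpt file.
run_name = f"gcn_{datetime.now():%Y-%m-%d_%H-%M-%S}"
logger = TensorBoardLogger("logs", name="goli", version=run_name)
checkpoint = ModelCheckpoint(dirpath=f"models/{run_name}", filename=run_name)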

Implement PNA with towers

Implement the PNA model with towers, PNATowersLayer. It is similar to PNAComplexLayer, but uses consecutive towers before the message passing. See this link for the towers implementation of the DGN.

Deal with class imbalance

Deal with class imbalance, either by having a loss weight, or by super-sampling the under-represented classes. It is unclear which method should be prioritized.

  • Super-sampling
    This can work very well for single-task property prediction. However, in multi-label settings, it becomes ambiguous to define how we want to sample, since it is difficult to guarantee that all classes are represented more often. For each sample, we can attribute a balanced weight to all its labels, and each sample gets selected with a probability proportional to the sum of its weights.

  • Balanced-BCELoss
    At first glance, this seems easy, since we can simply add a larger weight to the under-represented labels. However, for multi-label predictions it is harder, since the weights taken by the class BCELoss are attributed to each sample, not to each label of each sample. Therefore, we need to compute the BCELoss without reduction, multiply by a weight matrix, then apply the reduction (see the sketch after this list).
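
A sketch of the Balanced-BCELoss described above: unreduced BCE, multiplied by per-label weights, then reduced:

import torch
import torch.nn.functional as F

def balanced_bce(logits, targets, label_weights):
    # Per-element loss of shape (n_samples, n_labels), weighted per label.
    loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    return (loss * label_weights).mean()

logits = torch.randn(4, 3)
targets = torch.randint(0, 2, (4, 3)).float()
weights = torch.tensor([1.0, 5.0, 2.0])  # larger weight for rarer labels
print(balanced_bce(logits, targets, weights))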

Implement graph norm

Graph norm is a proposed approach to stabilize the training of GNNs and improve their results. The authors argue that batch_norm doesn't work well because there's too much variability between batches.

See paper here: https://arxiv.org/pdf/2009.03294.pdf

The graph norm should be implemented as a new string option "graph_norm" for the normalization parameter of the class BaseDGLLayer.
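
A simplified sketch of the idea: standardize node features within each graph rather than across the whole batch (the paper additionally learns a shift weight and affine parameters):

import torch

def graph_norm(x, batch, eps=1e-5):
    # x: (num_nodes, num_feats); batch: (num_nodes,) node-to-graph index.
    num_graphs = int(batch.max()) + 1
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    counts = counts.float().unsqueeze(1)
    mean = torch.zeros(num_graphs, x.size(1)).index_add_(0, batch, x) / counts
    centered = x - mean[batch]
    var = torch.zeros(num_graphs, x.size(1)).index_add_(0, batch, centered**2) / counts
    return centered / (var[batch] + eps).sqrt()

x = torch.randn(5, 8)                  # 5 nodes, 8 features
batch = torch.tensor([0, 0, 0, 1, 1])  # 2 graphs in the batch
out = graph_norm(x, batch)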

Incompatible with pytorch 1.9 and Python 3.9

PyTorch 1.9 changed a few things, including how Tensor objects are imported. Goli and many other packages are no longer compatible with PyTorch 1.9. Either restrict the version in the env.yml file, or fix the problem.

Some packages are also incompatible with Python 3.9; amend the env.yml file accordingly.

Support for multi-GPUs

The code crashes when using multiple GPUs, with some of the tensors remaining on the CPU and causing an error. To investigate further.

  • Comment out every call to .device

Clean-up models

Clean up most of the GNNs, and keep only the essential graph layers:

  • GCN
  • GIN
  • GAT
  • Gated-GCN
  • PNA

Clean up the redundant MLP classes:

  • base_layers.MLP
  • gin_layer.MLP
  • dgl_layer.MLPReadout

Add support for the OGB datasets

Add support for the datasets from the Open Graph Benchmark (OGB), mainly the MolHIV, MolTox21, and MolPCBA datasets.

I would suggest creating a new DataModule similar to the DGLFromSmilesDataModule.

Replace intermittent pooling by VirtualNode

The method _forward_intermittent_pool_layers from goli.dgl.networks.FeedForwardDGL should be changed to a virtual node.

See the implementation of the VirtualNode class here, and the implementation in the network here.
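
A minimal sketch of a virtual-node layer (names are illustrative): pool all node features of a graph into a global node, update it with an MLP, and broadcast it back to the nodes:

import torch
import torch.nn as nn

class VirtualNode(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, h, batch, vn=None):
        # h: (num_nodes, dim); batch: (num_nodes,) node-to-graph index.
        num_graphs = int(batch.max()) + 1
        if vn is None:
            vn = h.new_zeros(num_graphs, h.size(1))
        # Pool the nodes into their graph's virtual node, then update it.
        vn = self.mlp(vn + torch.zeros_like(vn).index_add_(0, batch, h))
        # Broadcast each graph's virtual node back to its nodes.
        return h + vn[batch], vn

layer = VirtualNode(dim=8)
h, vn = layer(torch.randn(5, 8), torch.tensor([0, 0, 0, 1, 1]))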

Default parameters of PNA do not use activation functions

I realized that the default parameters of PNA do not use any activation in the conv layers. I am currently exploring whether this makes sense or not. We are still outperforming the SOTA without activation, but I'm unsure it is a good idea.

There are a few reasons why it could still work:

  • Concatenating multiple aggregators forces the network to always learn a smaller embedding of the same information, and acts as a form of non-linearity
  • No activation function allows the network to learn more easily
  • Some of the operations used, like max and min, capture a non-linear response of the neighbours.

But I still have some tests to do.
