
gflownet's Issues

Nested dataclasses do not reinitialize

We use nested dataclasses in Config. This is generally fine since we only instantiate a Config once, but in some cases (e.g. testing) we may want to instantiate several. This does not work as expected, because the inner instances belong to the Config class rather than to its instances.

  • replace the Inner() defaults in config dataclasses with field(default_factory=Inner) (see the sketch below)
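
A minimal sketch of the proposed change, with a hypothetical Inner field for illustration: using field(default_factory=Inner) makes every Config instance get its own Inner, instead of all instances sharing a class-level default.

from dataclasses import dataclass, field

@dataclass
class Inner:
    learning_rate: float = 1e-4  # hypothetical field, for illustration only

@dataclass
class Config:
    # Bad: a single Inner instance becomes a class-level default shared by every Config.
    # inner: Inner = Inner()
    # Good: each Config() call builds a fresh Inner.
    inner: Inner = field(default_factory=Inner)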

Make a generic conditional information class

We currently pass conditional information (cond_info) around as dicts where side information is stored under certain keys, and the model eventually receives just the encoding. It would be useful to standardize this, and in doing so generalize things like thermometer encoding and the to- and from-list operations on minibatches of conditional information.
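
A minimal sketch of one possible standardization; the class and field names here are hypothetical, not the repo's API. The idea is a small container that owns its encoding and its minibatch conversions instead of a loose dict.

from dataclasses import dataclass
from typing import List
import torch

@dataclass
class ConditionalInfo:
    beta: torch.Tensor      # e.g. temperature values, shape (batch,)
    encoding: torch.Tensor  # what the model actually consumes, shape (batch, d)

    def to_list(self) -> List["ConditionalInfo"]:
        # Split a minibatch into per-trajectory items.
        return [ConditionalInfo(self.beta[i : i + 1], self.encoding[i : i + 1]) for i in range(len(self.beta))]

    @staticmethod
    def from_list(items: List["ConditionalInfo"]) -> "ConditionalInfo":
        # Re-collate per-trajectory items into a minibatch.
        return ConditionalInfo(torch.cat([c.beta for c in items]), torch.cat([c.encoding for c in items]))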

Potential modification of flat_rewards after batch-creation

Experiments using a ReplayBuffer have led to surprising results: the model was unable to learn if our list of flat_rewards (tensors of shape (2,)) was stacked into a tensor of shape (batch, 2) before being pushed into the buffer. We were able to fix this issue by creating a copy of each item pushed into the replay buffer (note: copying only flat_rewards may have been sufficient).

Our hypothesis is that flat_rewards (and potentially other tensors) are modified after the batch is created. This was harmless when the batch was discarded after the parameter update (on-policy), but it became harmful once the same modification also altered the flat_rewards stored in the replay buffer, causing the model to train on wrong trajectory-reward pairs when those trajectories were later re-sampled from the buffer. It would be worth identifying which operation changes the buffer data, to confirm that it is intentional and should indeed occur.
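
A minimal sketch of the workaround described above; the push signature is hypothetical. The point is to clone (or deepcopy) data before it enters the buffer, so later in-place modifications of the training batch cannot corrupt stored entries.

import copy

def push_safe(replay_buffer, trajectory, flat_rewards, cond_info):
    # Store copies so the buffer owns its own memory.
    replay_buffer.push(
        copy.deepcopy(trajectory),
        flat_rewards.clone(),      # copying only flat_rewards may already be enough
        copy.deepcopy(cond_info),
    )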

Make multiprocessing terminate gracefully

Current multiprocessing/threading routines are not explicitly stopped; they rely on the objects they belong to being garbage collected. This sometimes produces aesthetically displeasing logs in which all the threads emit errors.
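
A minimal sketch of one way to stop a worker thread explicitly instead of relying on garbage collection; the class and method names are illustrative, not the repo's API.

import threading

class Worker:
    def __init__(self):
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self):
        while not self._stop.is_set():
            ...  # do one unit of work

    def terminate(self):
        # Ask the loop to exit and wait for it, instead of letting the thread die with the process.
        self._stop.set()
        self._thread.join(timeout=5)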

Harmonize use of masks

Masks are currently used somewhat ad-hoc here and there, in particular:

  • GraphSampler explicitly uses them, but this should be abstracted away
  • GraphTransformerGFN and derivatives end up applying masks themselves, but this feels like the "wrong" place

Instead, it would be good for GraphActionCategorical to do all the accounting of masks and mask-related code.
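
As an illustration of the kind of accounting that could live inside GraphActionCategorical, here is a minimal sketch of applying a mask to logits (the function is hypothetical, not the repo's API):

import torch

def masked_logits(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Invalid actions get -inf, so they receive zero probability under a softmax
    # and can never be sampled.
    return logits.masked_fill(~mask.bool(), float("-inf"))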

Update docker image

Maciej recommends switching to training-latest.

(this may involve upgrading all the torch_geometric wheels)

Convert configs to standard library

#93 introduces better configs but uses a simple (homemade) library to do so. It would be good to put in the effort to upgrade this to a more standard library such as OmegaConf or Hydra.
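
A minimal sketch of what the OmegaConf route could look like, assuming the existing dataclass-based Config; the inner config and field names here are hypothetical.

from dataclasses import dataclass, field
from omegaconf import OmegaConf

@dataclass
class OptimizerConfig:
    learning_rate: float = 1e-4  # hypothetical field, for illustration

@dataclass
class Config:
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)

# OmegaConf wraps nested dataclasses as a type-checked "structured" config.
cfg = OmegaConf.structured(Config)
cfg = OmegaConf.merge(cfg, OmegaConf.from_cli())  # e.g. `python main.py optimizer.learning_rate=3e-4`
print(OmegaConf.to_yaml(cfg))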

Fix GraphTransformer forward function

I think I found an inconsistency between the GraphTransformer docstring and its implementation, located at gflownet/models/graph_transformer.py. I don't think it is a critical flaw, but it may impact performance, as it is not intended behaviour.

Here is what the current docstring says:

The per node outputs are the concatenation of the final (post graph-convolution) node embeddings
and of the final virtual node embedding of the graph each node corresponds to.

The per graph outputs are the concatenation of a global mean pooling operation, of the final
virtual node embeddings, and of the conditional information embedding.

And here is the current implementation of GraphTransformer:

1. class GraphTransformer(nn.Module):
2.     ...
3.     def forward(self, g: gd.Batch, cond: torch.Tensor):
4.         ...
5.         glob = torch.cat([gnn.global_mean_pool(o[: -c.shape[0]], g.batch), o[-c.shape[0] :]], 1)
6.         o_final = torch.cat([o[: -c.shape[0]]], 1)
7.         return o_final, glob

There are two problems:

  • (line 6) The per node outputs do not concatenate the virtual node embeddings.
  • (line 5) The per graph outputs are concatenated with the virtual node embeddings, not with the conditional information embedding. Either the docstring or the code should be fixed.

If we take the first option for the second problem (fix the docstring) and fix the code for the first problem, the code would look something like this:

class GraphTransformer(nn.Module):
    ...
    def forward(self, g: gd.Batch, cond: torch.Tensor):
        ...
        n_final = o[: -c.shape[0]] # final node embeddings (without virtual nodes)
        v_final = o[-c.shape[0] :] # final virtual node embeddings
        glob = torch.cat([gnn.global_mean_pool(n_final, g.batch), v_final], 1)
        o_final = torch.cat(
            [n_final, v_final.repeat_interleave(torch.bincount(g.batch), dim=0)], 1
        )
        return o_final, glob

And the corresponding docstring would be something like:

The per graph outputs are the concatenation of a global mean pooling operation over the final node embeddings and of the final virtual node embeddings.

Restore stem multi-connectivity in fragment environment

In the original (Bengio et al. 2021) fragment environment, some fragments had stem atoms that could be used multiple times (e.g. carbon atoms with more than one open valence electron).

In the current implementation, stem atoms can only be attached to once.

Add a MANIFEST.in

Currently non-.py files are not bundled in the wheel (e.g. when doing a pip install git+http://.../gflownet.git), which means some files (like gflownet/envs/frag_72.txt) end up missing. This can be fixed by adding a MANIFEST.in file.
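
A minimal sketch of what such a MANIFEST.in could contain; the exact patterns depend on the package layout, so these paths are assumptions.

include gflownet/envs/frag_72.txt
recursive-include gflownet *.txt

Depending on the build setup, include_package_data=True (or an equivalent package_data entry) may also be needed for these files to land in the wheel.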

Investigate multiprocessing memory usage further

It's not clear that we are doing multiprocessing in the best way possible.

  • we have been seeing problems with too many shared tensors being created; the current patch for that is #59 (which naively pickles everything!)
  • we obviously waste time copying tensors and objects from/to workers

A great enhancement for this project would be to dig deeper into those problems.

Turn action tuples into its own type

As raised in #42, we should probably not be passing around an unspecific Tuple[int, int, int] to represent graph action indices.

To improve the situation, we should create a new class for this data type and refactor the code to use it.
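
A minimal sketch of one possible replacement; the class name and field meanings below are hypothetical and would need to follow how the indices are actually used in the code.

from typing import NamedTuple

class GraphActionIndex(NamedTuple):
    action_type: int  # index of the action type
    row: int          # which node/edge the action applies to
    col: int          # which value/attachment is chosen

action = GraphActionIndex(action_type=1, row=3, col=0)
action_type, row, col = action  # still unpacks like the old tuple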

Interpretation of edge_attr in FragMolBuildingEnvContext

Hi!

Thanks for the great work on GFlowNets. I am going through the code to understand and reproduce the multi-objective molecule generation experiments (from the Multi-objective GFlowNets paper).

I am a little bit confused by the interpretation of the edge attributes created by FragMolBuildingEnvContext.graph_to_Data.

The edge_attr tensor is initialized as

edge_attr = torch.zeros((len(g.edges) * 2, self.num_edge_dim))

and then filled as follows

for i, e in enumerate(g.edges):
    ad = g.edges[e]
    a, b = e
    for n, offset in zip(e, [0, self.num_stem_acts]):
        idx = ad.get(f'{int(n)}_attach', -1) + 1
        edge_attr[i * 2, idx] = 1
        edge_attr[i * 2 + 1, idx] = 1
        ...

I do not completely understand how I should interpret the features edge_attr[i * 2, :] and edge_attr[i * 2 + 1, :] of the edge number i.

If I understand correctly, in the code above idx takes value 0 when the attribute f'{a}_attach' (f'{b}_attach') is not set. Otherwise, idx is the 1-based index of the stem of node a (b). In particular, if both attributes f'{a}_attach' and f'{b}_attach' are set to 0, then the edge_attr[i * 2, :] = [0, 1, 0, 0, 0, ... ].

Questions:

  • Is edge_attr[2 * i, :] supposed to be the one-hot encoding of the stems of nodes a and b connected by the i-th edge?
  • If that's the case, shouldn't the code be something like this?

self.num_edge_dim = (most_stems + 1) * 2

...

edge_attr = torch.zeros((len(g.edges) * 2, self.num_edge_dim))

...

for i, e in enumerate(g.edges):
    ad = g.edges[e]
    a, b = e
    for n, offset in zip(e, [0, self.num_stem_acts]):
        idx = ad.get(f'{int(n)}_attach', -1) + 1
        edge_attr[i * 2, idx + offset] = 1
        edge_attr[i * 2 + 1, idx + offset] = 1
        ...

Make code pass prebuild checks

This seems to involve:

  • Linting code properly with flake8
  • Configuring dependencies so that mypy can run
  • Fixing minor security issues

Exception in thread Thread-1 (_run_pareto_accumulation)

Hi guys,

First, thank you for making this resource available; GFlowNets seem like a really promising method.

I've installed it locally following the instructions in the README, and now I'm trying the sEH fragment-based MOO task example in /gflownet/src/gflownet/tasks/seh_frag_moo.py. However, the calculation always dies with the error message below. I've also tried seh_frag.py, and it dies with the same message.

$ python seh_frag_moo.py
[...]
18/06/2024 18:13:28 - INFO - logger - Final generation steps completed - sampled_reward_avg:0.39 igd:0.42 lifetime_igd_frontOnly:0.43 PCent:1.87 lifetime_PCent_frontOnly:0.69
Exception in thread Thread-1 (_run_pareto_accumulation):
Traceback (most recent call last):
  File "/blue/lic/seabra/.conda/envs/gflownet/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/blue/lic/seabra/.conda/envs/gflownet/lib/python3.10/threading.py", line 953, in run

Here is my configuration:

$ python
Python 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import torch_geometric as gn
>>> torch.__version__
'2.1.2+cu121'
>>> torch.version.cuda
'12.1'
>>> torch.cuda.get_device_name()
'NVIDIA A100-SXM4-80GB'
>>> torch.cuda.get_device_properties(0)
_CudaDeviceProperties(name='NVIDIA A100-SXM4-80GB', major=8, minor=0, total_memory=81050MB, multi_processor_count=108)
>>> gn.__version__
'2.4.0'

Am I missing something here? Thanks a lot!

Make the code work with forkserver/spawn

Right now, setting the multiprocessing start method to forkserver or spawn causes things to not work. The code jumps from the cycle method to MPObjectProxy's getattr, which does not make any sense.
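
For reference, a minimal sketch of the triggering change (not a fix): this is the kind of call that currently breaks the object proxies. torch.multiprocessing mirrors the standard library's set_start_method.

import torch.multiprocessing as mp

if __name__ == "__main__":
    # Selecting a non-fork start method is what currently breaks the MPObjectProxy machinery.
    mp.set_start_method("spawn")  # or "forkserver"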

Duplicate logarithm in `seh_frag_moo`

I have been working with the seh_frag_moo.py script, and I noticed what looks like a bug in the computation of the log reward.

In the SEHMOOTask class, the function cond_info_to_logreward first calls self.pref_cond.transform, which runs scalar_logreward = (flat_reward * cond_info["preferences"]).sum(1).clamp(min=1e-30).log() in the MultiObjectiveWeightedPreferences class, thereby taking a logarithm of the rewards.

Then, two lines later, cond_info_to_logreward calls self.temperature_conditional.transform, which runs scalar_logreward = linear_reward.squeeze().clamp(min=1e-30).log() in the TemperatureConditional class, thereby taking a second logarithm since the linear_reward variable actually holds the scalar_logreward value from the previous function call.

Since the first logarithm generally produces negative numbers, the second call clamps all these negatives to 1e-30 and then takes the log of that, giving -4144.6533 as the reward for every molecule.

Should one of these clamp & logarithm operations be removed to prevent the double logarithm?
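
A minimal numeric sketch of the double clamp-and-log behaviour described above (illustrative only; the reward value and any temperature scaling are assumptions):

import torch

flat_scalar = torch.tensor([0.5])           # hypothetical positive scalarized reward
first = flat_scalar.clamp(min=1e-30).log()  # ≈ -0.69, already a log-reward
second = first.clamp(min=1e-30).log()       # the negative value is clamped up to 1e-30,
                                            # so this is log(1e-30) ≈ -69.08 for any input
print(second)                               # tensor([-69.0776])

Every input therefore collapses to the same constant; if that constant is further scaled by the temperature beta (for example, beta = 60 gives roughly -4144.65), it matches the value reported above, which is consistent with the double-logarithm hypothesis.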
