
gflownet's People

Contributors

alexandravolokhova, alexduvalinho, alexhernandezgarcia, carriepl, carriepl-mila, dannysalem, influencefunctional, josephdviviano, manhbao-nguyen, michalkoziarski, nikita-0209, sh-divya, taoliu032, vict0rsch


gflownet's Issues

Adjust mask of ctorus

          I get that this mask is in a way fake, but for consistency, wouldn't it make more sense to set the first entry of the mask (corresponding to the generic action) to True and the second (corresponding to EOS) to False if the number of actions is equal to the trajectory length, and the other way around if any continuous action (except EOS) is valid?

Originally posted by @michalkoziarski in #193 (comment)
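
A minimal sketch of the proposed convention, assuming that True marks an invalid action; `ctorus_mask_head` is a hypothetical helper, not an existing function in the repo. It returns only the two leading "fake" entries: the generic continuous action and EOS.

```python
from typing import List


def ctorus_mask_head(n_actions: int, length_traj: int) -> List[bool]:
    """Hypothetical helper for the two leading mask entries of ctorus.

    Assumes True = invalid. Entry 0: generic (continuous) action; entry 1: EOS.
    """
    if n_actions == length_traj:
        # Trajectory is full: the generic action is invalid, only EOS is valid.
        return [True, False]
    # Otherwise a continuous action is still valid, and EOS is not.
    return [False, True]
```

With this convention the two entries always agree with what the sampler can actually do at that step, instead of being fixed placeholders.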

Need a way to restrict space groups in spacegroup.py

Need to pass a list of space group labels or numbers to spacegroup.py and have it restrict generation to only these groups.

I don't want to break it / make a mess. I assume there's an elegant way to do it, maybe in get_mask_invalid_actions_forward(), but I haven't been able to run it and figure out the logic yet.
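
One possible shape for the restriction, as a sketch only: it assumes that the part of the forward mask that selects a space group has one entry per group (index i corresponds to group i + 1) and that True means invalid. `restrict_space_groups` is hypothetical; space-group labels (e.g. "Fm-3m") would first have to be converted to numbers, which is not shown.

```python
from typing import Iterable, List, Optional

N_SPACE_GROUPS = 230  # space groups are numbered 1..230


def restrict_space_groups(
    mask_invalid: List[bool], allowed: Optional[Iterable[int]] = None
) -> List[bool]:
    """Hypothetical helper: mark every space group outside `allowed` as invalid.

    Assumes one mask entry per space group (index i <-> group i + 1) and
    True = invalid action. `allowed=None` means no restriction.
    """
    if len(mask_invalid) != N_SPACE_GROUPS:
        raise ValueError("expected one mask entry per space group")
    if allowed is None:
        return list(mask_invalid)
    allowed_set = set(allowed)
    # An entry stays invalid if it already was, and becomes invalid if not allowed.
    return [bad or (i + 1) not in allowed_set for i, bad in enumerate(mask_invalid)]


# Example: only allow Fm-3m (225) and Im-3m (229).
mask = restrict_space_groups([False] * N_SPACE_GROUPS, allowed=[225, 229])
```

Applying the restriction as a final AND-style step on top of the existing mask keeps the original validity logic untouched, which addresses the "don't want to break it" concern.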

min and norm properties of proxies can lead to silent errors

So if the caller of that property (accidentally) does anything in place on the returned tensor, that change will affect everyone else calling min in the future?
If so, it feels like it might lead to bugs that would be really difficult to track.
The same comment applies to the other proxies that implement the same strategy: torus (for min and norm) and uniform (for min).

Originally posted by @carriepl in #204 (comment)
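
The aliasing problem can be reproduced without torch. In this toy sketch a plain list stands in for the cached tensor and `CachedMinProxy` is hypothetical; with real tensors, returning `self._min.clone()` would play the role of `list(self._min)`.

```python
class CachedMinProxy:
    """Toy stand-in for a proxy that caches `min` (a list stands in for a tensor)."""

    def __init__(self):
        self._min = None

    def _compute_min(self):
        if self._min is None:
            self._min = [-1.0]  # pretend this is expensive to compute

    @property
    def min_shared(self):
        # Returns the cached object itself: every caller shares (and can mutate) it.
        self._compute_min()
        return self._min

    @property
    def min_safe(self):
        # Returns a fresh copy, so in-place edits by one caller cannot leak to others.
        self._compute_min()
        return list(self._min)


p = CachedMinProxy()
m = p.min_shared
m[0] = 0.0  # silently corrupts the cache for every later caller
```

The copy costs an allocation per access; the alternative is to keep returning the shared object and document it as read-only, which is cheaper but relies on every caller behaving.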

The proxy should not be copied into each environment instance

Currently, the proxy is set as an attribute of the environments and the base environment implements the methods proxy2reward() and reward2proxy() that determine the conversion between proxy outputs and reward. The environment also implements the methods reward() and reward_batch(), which call the proxy and the conversion methods. This is probably not ideal for various reasons.

I no longer see a good reason to keep the proxy and these methods within the environment. It seems possible, and a good idea, to completely detach the environment and the proxy. Some proxies need information from the environment, which is currently set via the call to Env.setup_proxy(), which calls the proxy's setup() method. But this could just be done elsewhere.

Now, in terms of alternatives, I am not completely settled on what the best option would be. In particular, where should the methods that convert between proxy and reward go?

  • In the (base) proxy?
  • In the GFlowNet agent?
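
To make the first option concrete, here is a sketch in which the proxy owns proxy2reward() and reward2proxy() and the environment holds no proxy at all. This illustrates one option, not a decision; `ProxyWithReward` is hypothetical, and the exponential conversion reward = exp(-beta * proxy) is an assumption chosen only because it is easy to invert.

```python
import math
from typing import Callable, List, Sequence


class ProxyWithReward:
    """Hypothetical proxy that owns the proxy <-> reward conversion."""

    def __init__(self, score_fn: Callable[[Sequence[float]], float], beta: float = 1.0):
        self.score_fn = score_fn  # maps one (already transformed) state to a proxy value
        self.beta = beta

    def __call__(self, states: Sequence[Sequence[float]]) -> List[float]:
        return [self.score_fn(s) for s in states]

    def proxy2reward(self, proxy_values: Sequence[float]) -> List[float]:
        # Assumed conversion: the proxy value is an energy, reward = exp(-beta * energy).
        return [math.exp(-self.beta * v) for v in proxy_values]

    def reward2proxy(self, rewards: Sequence[float]) -> List[float]:
        # Exact inverse of proxy2reward.
        return [-math.log(r) / self.beta for r in rewards]

    def rewards(self, states: Sequence[Sequence[float]]) -> List[float]:
        return self.proxy2reward(self(states))


proxy = ProxyWithReward(score_fn=sum, beta=2.0)
```

Under this layout the agent would call `proxy.rewards(states)` directly, and any environment information the proxy needs would be passed to it once at setup time rather than through the environment.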

Flexible Policy Definition

  • Policies were originally MLPs.
  • Now we need to be able to use arbitrary function approximators.
  • This will be a library-level change that will affect all projects.
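
A minimal sketch of what "arbitrary function approximators" could mean at the interface level, assuming a policy only needs a callable that maps a batch of state encodings to a batch of action logits. `Policy` here is hypothetical and checks the output contract rather than the model's type.

```python
from typing import Callable, List, Sequence

Logits = List[float]
Model = Callable[[Sequence[Sequence[float]]], List[Logits]]


class Policy:
    """Hypothetical policy wrapper: accepts any model with the right call signature."""

    def __init__(self, model: Model, n_actions: int):
        self.model = model
        self.n_actions = n_actions

    def __call__(self, states: Sequence[Sequence[float]]) -> List[Logits]:
        logits = self.model(states)
        # Validate the contract (one logit per action per state), not the model class.
        if any(len(row) != self.n_actions for row in logits):
            raise ValueError("model must return n_actions logits per state")
        return logits


# An MLP, a transformer or a GNN would all fit; here, a trivial uniform model.
uniform = Policy(lambda batch: [[0.0] * 3 for _ in batch], n_actions=3)
```

Because the wrapper depends only on the call signature, swapping the MLP for another architecture would not require changes in the rest of the library.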

`gflownet/gflownet.py` - should not store `self.env` as well as `self.env_maker`

Currently, self.env is stored in the GFlowNet class, because other classes expect an env instance (to access various methods / attributes).

Since the gflownet now stores a class factory rather than a class instance, we should figure out another way to communicate with these other classes (instead of storing an env instance).
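
One possible pattern, sketched under the assumption that env_maker is something like a `functools.partial` over the environment class: the agent keeps only the factory and other components receive the factory and build their own instance. `GridEnv`, `Agent` and `Evaluator` are hypothetical stand-ins.

```python
from functools import partial
from typing import Callable, List


class GridEnv:
    """Hypothetical stand-in for an environment class."""

    def __init__(self, length: int = 3):
        self.length = length
        self.state: List[int] = [0, 0]


class Agent:
    """Keeps only the factory (no self.env) and hands out fresh instances on demand."""

    def __init__(self, env_maker: Callable[[], GridEnv]):
        self.env_maker = env_maker

    def make_env(self) -> GridEnv:
        return self.env_maker()


class Evaluator:
    """Hypothetical consumer that used to read agent.env; it now receives the factory."""

    def __init__(self, env_maker: Callable[[], GridEnv]):
        self.template_env = env_maker()  # private instance for reading attributes/methods

    def grid_length(self) -> int:
        return self.template_env.length


agent = Agent(partial(GridEnv, length=5))
evaluator = Evaluator(agent.env_maker)
```

Each consumer then owns its instance, so no component can accidentally mutate a shared env held by the agent.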

Integrate Crystal and CCrystal

Both environments seem to share quite a bit of functionality; it would be good to refactor them so that it's not copy-pasted between them.

Branch Cleanup

We should clean out all the old / dead branches.

There are like a billion.

Please :)

Decide on a format for conditional modelling.

Unless I'm missing it, I don't see a method for conditional generation anywhere. For my purposes it would be:

  1. load conditions in batches from a dataloader
  2. assign each env to a condition
  3. encode the condition (via a graph model; the condition in this case is a molecule and its encoding is a vector)
  4. concatenate the condition encoding to the GFN policy input
  5. train as normal
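
Steps 2 to 4 above can be sketched as a single function, assuming each environment in the batch is paired with one condition by position and that both the state and the condition are encoded as flat feature vectors. `conditional_policy_inputs` and `toy_encoder` are hypothetical; a real encoder would be the graph model mentioned in step 3.

```python
from typing import Callable, List, Sequence


def conditional_policy_inputs(
    state_encodings: List[List[float]],
    conditions: Sequence[object],
    encode_condition: Callable[[object], List[float]],
) -> List[List[float]]:
    """Env i is assigned conditions[i]; its encoding is appended to env i's state features."""
    if len(state_encodings) != len(conditions):
        raise ValueError("need exactly one condition per environment")
    return [s + encode_condition(c) for s, c in zip(state_encodings, conditions)]


def toy_encoder(smiles: str) -> List[float]:
    # Hypothetical stand-in for a graph model: [string length, number of 'C' characters].
    return [float(len(smiles)), float(smiles.count("C"))]


inputs = conditional_policy_inputs([[0.0], [1.0]], ["CC", "CO"], toy_encoder)
```

The policy's input dimension then grows by the size of the condition encoding, which is the only change the policy model itself would need.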

If possible, it would also be ideal if the conditioning model could be updated during training along with the policy model, though I could probably find a way to pretrain one that is at least 'ok' if necessary. I haven't read deep enough into the conditional gflownet work to know what's optimal here. In my case, the distribution of high-scoring samples is both very sharp and extremely sensitive to the conditions.

In evaluation mode, for speed, we could call the conditioning model once and use the same encoding at all generation steps. For training, particularly if the conditioning model is being updated, it is probably fine to call it together with the policy at every action.
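
The evaluation/training split described above could look like this sketch, assuming the same batch of conditions is used for every step of one generation run. `ConditionEncoding` is hypothetical; `n_calls` exists only to make the caching visible.

```python
from typing import Callable, List, Optional, Sequence


class ConditionEncoding:
    """Hypothetical wrapper: in eval mode, encode a batch of conditions once and reuse it
    at every generation step; in training mode, re-encode at each step so that an
    updated conditioning model is always the one being used."""

    def __init__(self, encoder: Callable[[object], List[float]], training: bool):
        self.encoder = encoder
        self.training = training
        self._cached: Optional[List[List[float]]] = None
        self.n_calls = 0

    def reset(self) -> None:
        # Call whenever a new batch of conditions is loaded.
        self._cached = None

    def __call__(self, conditions: Sequence[object]) -> List[List[float]]:
        if self.training or self._cached is None:
            self.n_calls += 1
            encoded = [self.encoder(c) for c in conditions]
            if not self.training:
                self._cached = encoded
            return encoded
        return self._cached


eval_enc = ConditionEncoding(lambda c: [float(len(c))], training=False)
for _ in range(5):  # five generation steps on the same batch
    out = eval_enc(["CC", "C"])
```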

I have started playing with this locally but don't want to conflict with any planned format.

Compute and log variances of the log probs

One more thing: I'd add computing and tracking two variances of the log probs:

  1. variance over samples of logprobs_estimates (to understand better the behaviour of the correlation coefficient over the training)
  2. median over samples of the variances of the logprobs_estimates over trajectories for each sample (to get a sense of how noisy the estimation is). The math is a bit tricky here, as we use the log of the mean as the estimate, not just the mean. But there are some workarounds: https://stats.stackexchange.com/questions/418313/variance-of-x-and-variance-of-logx-how-to-relate-them
    But in any case, we will need to compute the empirical var(P_F(tau) / P_B(tau)) / n_traj for each sample and then work with it a bit to get the variance of the log-mean estimate.

Originally posted by @AlexandraVolokhova in #167 (comment)
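
A sketch of both statistics, assuming each sample comes with a list of log ratios log P_F(tau) - log P_B(tau) over its trajectories and that logprobs_estimates is the log of the mean ratio. For the second statistic it uses the delta method, Var(log Y) ≈ Var(Y) / E[Y]^2, which is one of the workarounds of the kind discussed at the linked page; treat it as an approximation. `log_ratio_stats` is hypothetical.

```python
import math
import statistics
from typing import List, Sequence, Tuple


def logmeanexp(xs: Sequence[float]) -> float:
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs) / len(xs))


def log_ratio_stats(log_ratios_per_sample: Sequence[Sequence[float]]) -> Tuple[float, float]:
    """Entry [i][j] = log P_F(tau_j) - log P_B(tau_j) for sample i.
    Needs at least 2 samples and at least 2 trajectories per sample."""
    estimates: List[float] = []
    log_est_vars: List[float] = []
    for log_ratios in log_ratios_per_sample:
        n_traj = len(log_ratios)
        estimates.append(logmeanexp(log_ratios))
        m = max(log_ratios)
        scaled = [math.exp(x - m) for x in log_ratios]  # ratios / exp(m), avoids overflow
        mean = statistics.fmean(scaled)
        var_of_mean = statistics.variance(scaled) / n_traj  # var(P_F / P_B) / n_traj (scaled)
        # Delta method; the ratio var/mean^2 does not depend on the exp(m) scaling.
        log_est_vars.append(var_of_mean / mean**2)
    # 1. variance over samples of the estimates; 2. median over samples of their variances.
    return statistics.variance(estimates), statistics.median(log_est_vars)
```

Working with ratios rescaled by the per-sample maximum keeps the computation stable even when the log ratios are large in magnitude.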

Check uses of copy()

  • Might be a good idea to rename the copy() method in the common utils as copy_state() and move it to the base environment.
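
A sketch of what a `copy_state()` on the base environment might look like, assuming states can be nested lists. `BaseEnv` here is a hypothetical stand-in, and the choice of `copy.deepcopy` is an assumption; the point it illustrates is that a shallow copy of a nested state still shares its inner lists.

```python
import copy
from typing import Any


class BaseEnv:
    """Hypothetical stand-in for the base environment."""

    @staticmethod
    def copy_state(state: Any) -> Any:
        # deepcopy so that nested containers (e.g. lists of lists) are not shared.
        return copy.deepcopy(state)


s = [[0, 1], [2]]
shallow = list(s)
shallow[0].append(9)  # also modifies s[0], because the inner list is shared
```

Having a single method on the environment also gives one place to specialize the copy for states that are tensors or other custom objects.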

replay.pkl

Small detail in gflownet.py: enumerate returns an iterator with no length, so tqdm(enumerate(...)) does not print a full-width progress bar, as tqdm cannot know the length of enumerate(...). Simple fix: enumerate(tqdm(...)) 😄
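
The underlying reason can be checked with the standard library alone: tqdm reads len() of its iterable to size the bar, and an enumerate object does not define `__len__`, whereas the list it wraps does.

```python
from collections.abc import Sized

items = ["a", "b", "c"]

print(isinstance(items, Sized))             # True
print(isinstance(enumerate(items), Sized))  # False: enumerate objects have no __len__

# So the sized iterable should go into tqdm, and enumerate should wrap the bar:
# for i, x in enumerate(tqdm(items)): ...
```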

Issue

Not sure you want to fix this in this PR, but testing a gflownet creates files (like replay.pkl) in the current working directory and pollutes it, which is particularly dangerous when that directory is tracked by git.
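
One way to avoid this, sketched with the standard library: write artifacts under an explicit directory instead of the working directory, and point that directory at a temporary location in tests. `save_replay` is a hypothetical helper, not the repo's actual saving code.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any


def save_replay(buffer: Any, logdir: str) -> Path:
    """Hypothetical helper: write replay.pkl under `logdir`, never the CWD."""
    path = Path(logdir) / "replay.pkl"
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(buffer, f)
    return path


# In a test, use a throwaway directory so nothing lands in the repository.
with tempfile.TemporaryDirectory() as tmp:
    saved = save_replay([("state", 1.0)], tmp)
    with open(saved, "rb") as f:
        loaded = pickle.load(f)
# The temporary directory, and replay.pkl with it, is deleted on exit.
```

With pytest, the built-in `tmp_path` fixture serves the same purpose as the `TemporaryDirectory` context manager here.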

Create training README

Batch size:

  • forward: number of forward trajectories to include in the training batch. These are on-policy trajectories, possibly with random actions (if random_action_prob > 0) or with a tempered policy (if temperature < 1.0).
  • train: number of backward trajectories to include in the training batch, sampled (backwards) from data points in a "training set"
  • replay: number of backward trajectories to include in the training batch, sampled (backwards) from data points in the replay buffer.

The total number of trajectories in the training batch is the sum of the above.
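
As a small sketch of the rule above (the dataclass is hypothetical; only the three field names come from the text):

```python
from dataclasses import dataclass


@dataclass
class BatchSize:
    forward: int = 0  # on-policy forward trajectories
    train: int = 0    # backward trajectories from data points in the training set
    replay: int = 0   # backward trajectories from data points in the replay buffer

    @property
    def total(self) -> int:
        # The training batch contains the sum of all three kinds of trajectories.
        return self.forward + self.train + self.replay


batch_size = BatchSize(forward=10, train=5, replay=2)
```

For example, forward=10, train=5 and replay=2 gives 17 trajectories per training batch.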

state2proxy

  1. state2proxy can be either state2oracle or state2obs. The input of state2oracle is a list of states, whereas the input of state2obs is a single state. To use them interchangeably as state2proxy, the input format should be the same.
  2. We need to add support for a transformer-friendly data transformation. This would be as simple as changing the data type of the state to int, or applying no transformation (None). (In the latter case, the conversion to int can be done within the forward call of the transformer itself.) state2oracle cannot be used as the transformer-friendly transformation because:
    a. for example, for the grid, if we use state2proxy = state2oracle, the input states to the transformer would be [-1, -1] (oracle-friendly) instead of [0, 0], and the embedding for negative indices is not defined;
    b. the oracle does not necessarily take indices as input, which the embedding layer requires.
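
A sketch of both points, with hypothetical names throughout: `batchify` gives a single-state function such as state2obs the same list-of-states signature as state2oracle, and the two grid transforms contrast an oracle-friendly encoding with a transformer-friendly one. The mapping of indices in [0, length - 1] to [-1, 1] is an assumption consistent with [0, 0] becoming [-1, -1].

```python
from typing import Callable, List, Sequence

State = List[int]


def batchify(single_fn: Callable[[State], list]) -> Callable[[Sequence[State]], list]:
    """Wrap a single-state transform so it takes a list of states, like state2oracle."""

    def batched(states: Sequence[State]) -> list:
        return [single_fn(s) for s in states]

    return batched


def grid_state2oracle(states: Sequence[State], length: int) -> List[List[float]]:
    # Assumed oracle-friendly encoding: indices 0..length-1 mapped linearly to [-1, 1].
    return [[-1.0 + 2.0 * x / (length - 1) for x in s] for s in states]


def grid_state2tokens(states: Sequence[State]) -> List[List[int]]:
    # Transformer-friendly encoding: non-negative integer indices for an embedding lookup.
    return [[int(x) for x in s] for s in states]
```

With this split, the transformer always sees valid embedding indices, while the oracle keeps whatever scaled representation it expects.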

Delete unused branches

We should clear out the outstanding PRs / old branches to make it easier for us to organize the work.

sample_batch args in `test_top_k` in gflownet.py

I just cloned a fresh copy of the repo in WSL and tried a run with all default configs, and it looks like the sample_batch on line 906 in gflownet.py is getting the wrong arguments (at least, it crashes every time). It looks like the parent function test_top_k was added on August 1.

call:

for b in batch_with_rest(0, self.logger.test.n_top_k, self.batch_size.forward):
    gfn_states += self.sample_batch(
        self.env, len(b), train=False, progress=progress
    )[0]

function definition:

def sample_batch(
    self,
    n_forward: int = 0,
    n_train: int = 0,
    n_replay: int = 0,
    train=True,
    progress=False,
):
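
Comparing the two snippets, the positional self.env lands in n_forward and len(b) lands in n_train. A sketch of a corrected call, passing the count by keyword, is below. `StubGFlowNet` is a stand-in with the same sample_batch signature, and `batch_with_rest` is reimplemented here under the assumption that it yields consecutive index chunks of size step, the last one possibly shorter.

```python
from typing import Iterator, List


def batch_with_rest(start: int, stop: int, step: int) -> Iterator[List[int]]:
    # Assumed behaviour: consecutive chunks of `step` indices; the last may be shorter.
    for i in range(start, stop, step):
        yield list(range(i, min(i + step, stop)))


class StubGFlowNet:
    """Stand-in with the same sample_batch signature as in gflownet.py."""

    def __init__(self, n_top_k: int, batch_size_forward: int):
        self.n_top_k = n_top_k
        self.batch_size_forward = batch_size_forward

    def sample_batch(self, n_forward=0, n_train=0, n_replay=0, train=True, progress=False):
        # Placeholder sampling: returns (states, info), matching the [0] indexing above.
        return [f"x{i}" for i in range(n_forward)], {}

    def collect_top_k_samples(self, progress=False) -> List[str]:
        gfn_states: List[str] = []
        for b in batch_with_rest(0, self.n_top_k, self.batch_size_forward):
            # The env is not a parameter of sample_batch; pass the count by keyword.
            gfn_states += self.sample_batch(n_forward=len(b), train=False, progress=progress)[0]
        return gfn_states


states = StubGFlowNet(n_top_k=7, batch_size_forward=3).collect_top_k_samples()
```

Using keywords for the counts also guards against the same mistake if parameters are added or reordered later.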
