
apax's People

Contributors

m-r-schaefer, pre-commit-ci[bot], pythonfz, tetracarbonylnickel


Forkers

chronum94

apax's Issues

Allow for python 3.11

As an extension of a general update PR (including Jax), we should also start supporting the latest Python version.

Bump Jax Version and make version explicit in GPU instructions

We should at some point in the near future bump the jax version to its newest release.
Was there a particular reason why we specified an exact version of jax?
If not we probably want to lift that restriction or add that version to the CUDA install instructions.

Currently, running the pip install... command for CUDA installs the latest jax version, while the pyproject.toml specifies 0.4.1.

Rework Model Class

All of our model components require a considerable number of parameters, leading to the model class itself also requiring all these parameters to initialize the components.

Would it be sensible to separate creation from use, and to initialize GMNN with the components themselves instead of all the component parameters?

GMNN(n_basis, n_radial, units, ...)

vs

GMNN(Descriptor, MLP)
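For illustration, a minimal sketch of the component-based variant; Descriptor and MLP are stand-ins here, not the actual apax classes:

class Descriptor:
    # Stand-in for the apax descriptor component.
    def __init__(self, n_basis, n_radial):
        self.n_basis, self.n_radial = n_basis, n_radial

class MLP:
    # Stand-in for the readout network component.
    def __init__(self, units):
        self.units = units

class GMNN:
    # The model only wires pre-built components together; all hyperparameters
    # live in the components themselves.
    def __init__(self, descriptor, readout):
        self.descriptor = descriptor
        self.readout = readout

model = GMNN(Descriptor(n_basis=7, n_radial=5), MLP(units=[512, 512]))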

output data structure

The output data structure has to be changed, since training different model versions is currently not possible: checkpoints are written directly to the model dir.

Add best Model checkpoint

We need a second checkpoint that tracks the epoch loss to determine whether the current model params are the best seen so far.
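A minimal sketch of the bookkeeping; train_one_epoch and save_checkpoint are hypothetical helpers:

# Sketch only: train_one_epoch and save_checkpoint are hypothetical.
best_loss = float("inf")
for epoch in range(n_epochs):
    params, epoch_loss = train_one_epoch(params)
    save_checkpoint(params, path="ckpts/latest")  # regular checkpoint
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        save_checkpoint(params, path="ckpts/best")  # best seen so far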

TB callback config should accept a log dir argument

Currently, every model run saves its TB logs into its own directory. We probably want to move this from model_path/model_name/tb_logs to model_path/tb_logs so it is more convenient to compare different models.

Batch Size too large

Description

If the batch size is larger than the training / test data, one receives the following error:

ZeroDivisionError: float division by zero

Expected

Something like ValueError: batch size can not be larger than the available data.
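A minimal sketch of such a check (the function name is hypothetical):

def validate_batch_size(batch_size: int, n_data: int) -> None:
    # Fail early with a descriptive message instead of a ZeroDivisionError later on.
    if batch_size > n_data:
        raise ValueError(
            f"batch size ({batch_size}) cannot be larger than the available data ({n_data})"
        )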

Issues with GMNN-calc

Turn overall loss tracking into a `clu` metric

The loss is currently accumulated outside the compiled step function. This probably has no performance impact, but for consistency it should be refactored to be part of the other metrics.
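A sketch of the standard clu pattern this could follow (the metric name is illustrative):

from clu import metrics
from flax import struct

@struct.dataclass
class TrainMetrics(metrics.Collection):
    # Average over all batch losses seen in an epoch.
    loss: metrics.Average.from_output("loss")

# Inside the compiled step function:
#   batch_metrics = TrainMetrics.single_from_model_output(loss=loss)
#   epoch_metrics = epoch_metrics.merge(batch_metrics)
# After the epoch:
#   epoch_metrics.compute()  # -> {"loss": ...}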

poetry cpu installation

After some research, I could not find an easy way to optionally install jax[cpu] or jax[cuda] with poetry.
The standard installation ships jax[cpu].
The workaround for now is to delete the poetry.lock file, change jax = {version = "0.3.25", extras = ["cpu"]} to jax = {version = "0.3.25", extras = ["cuda"]}, and run poetry install.

Transfer Learning Implementation

For various applications and downstream method development, we require at least a basic transfer learning implementation.

The implementation should include:

  • a path to a pre-trained model checkpoint in the CheckpointConfig
  • masking model parameters with optax mask (see the sketch below)

This is sufficient for an elementary discriminative transfer learning implementation.
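For the masking, a minimal sketch using optax; the parameter-tree keys and the mask itself are illustrative:

import optax

# Sketch only: frozen_mask is a hypothetical boolean pytree with the same
# structure as the model parameters (True = keep the pre-trained value fixed).
frozen_mask = {"descriptor": True, "readout": False}

tx = optax.chain(
    optax.adam(1e-3),
    # Zero out the updates of all masked (frozen) parameters.
    optax.masked(optax.set_to_zero(), frozen_mask),
)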

Allow for arbitrary model parameters for ASE and JaxMD

Currently, loading models into the ASE calc and the JaxMD inference model assumes the model was constructed with the default parameters (nn=[512,512], n_basis=7, n_radial=5) and throws an error if that was not the case.

TensorBoard Callback broken

On initialization, an AttributeError is raised since our TFModelSpoof does not have the attribute distribute_strategy.
This can probably be fixed by just using an actual Keras model.
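A minimal sketch of that fix, using an empty Keras model purely as the spoof object (the log directory is illustrative):

import tensorflow as tf

# Sketch only: a real (empty) Keras model provides the attributes the callback
# expects, e.g. distribute_strategy.
tb_callback = tf.keras.callbacks.TensorBoard(log_dir="tb_logs")
tb_callback.set_model(tf.keras.Model())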

Multi GPU training

We could experiment with multi GPU training and local SGD methods, which would be fairly simple to implement in jax.
This could potentially give some speedup when training with large batch sizes / datasets of large systems.
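A sketch of the plain data-parallel pattern in jax; loss_fn, params, and the batch sharding are hypothetical:

import jax

# Sketch only: loss_fn and the (n_devices, per_device_batch, ...) reshaping of
# the batch are assumed to exist elsewhere.
def train_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    # Average gradients across devices before the optimizer update.
    return jax.lax.pmean(grads, axis_name="devices")

p_train_step = jax.pmap(train_step, axis_name="devices")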

Implement Training On multiple system and cell sizes

Many datasets consist of differently sized systems. Currently, we assume that every structure has the same size and the same cell. In order for the code to be more generally applicable, we should implement support for this.
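One possible approach (a sketch, not necessarily how apax will implement it) is to zero-pad positions and atomic numbers up to the largest system in the dataset and mask the padded atoms later on:

import numpy as np

def pad_system(positions, numbers, n_max):
    # Sketch only: padded entries get Z = 0 ("ghost atoms") and must be masked
    # out in the model and in the losses.
    n_atoms = len(numbers)
    padded_R = np.zeros((n_max, 3), dtype=positions.dtype)
    padded_R[:n_atoms] = positions
    padded_Z = np.zeros(n_max, dtype=numbers.dtype)
    padded_Z[:n_atoms] = numbers
    return padded_R, padded_Z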

Number of NN layers is assumed to be 2

While we do allow the specification of the number of NN layers in the config, the model class assumes it to be exactly 2, failing if it's less and ignoring additional layer specifications:

self.dense1 = NTKLinear(units[0], b_init=b_init, name="dense1")
self.dense2 = NTKLinear(units[1], b_init=b_init, name="dense2")
self.dense3 = NTKLinear(1, b_init=b_init, name="dense3")
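A sketch of building the layers from the units list instead of hard-coding two hidden layers, reusing the existing NTKLinear layer (attribute names are illustrative):

self.dense_layers = [
    NTKLinear(n_units, b_init=b_init, name=f"dense{i}")
    for i, n_units in enumerate(units)
]
self.readout = NTKLinear(1, b_init=b_init, name="readout")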

Add checkpointing Interval

Checking the logs for training runs across different systems and durations reveals that saving a checkpoint after every epoch causes them to queue up in the async manager.
This could be circumvented by introducing a checkpointing interval, although I would have to test whether that actually yields a speedup, since saving is performed asynchronously.
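The interval check itself would be a one-liner in the epoch loop; ckpt_interval and save_checkpoint are hypothetical names:

# Sketch only: only write a checkpoint every ckpt_interval epochs.
if (epoch + 1) % ckpt_interval == 0:
    save_checkpoint(params, step=epoch)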

Disable loss

It would be nice if one could disable a certain loss by setting its weight to zero / null:

loss:
- name: energy
  loss_type: structures
- name: forces
  loss_type: structures
  weight: 0.0 # this
- name: forces
  loss_type: angles
  weight: null # or this
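On the implementation side, the loss builder could then simply skip entries with a falsy weight; the config fields and the Loss class are hypothetical:

# Sketch only: loss_configs mirrors the YAML above.
active_losses = [
    Loss(cfg["name"], cfg["loss_type"], cfg.get("weight", 1.0))
    for cfg in loss_configs
    if cfg.get("weight", 1.0)  # skips weight: 0.0 as well as weight: null
]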

Regression tests

We should reintroduce regression tests at some point.
This time we could make proper ones that actually perform a complete training, but are marked as slow, so you don't have to run them every time you want to test the code.

Check for int precision

When enabling fp64, we still get the following warning:

cannot safely cast value from dtype=int64 to dtype=int32. In future JAX releases this will result in an error.
  warnings.warn("scatter inputs have incompatible types: cannot safely cast "

This is most likely caused either by atomic numbers or neighbors being in int32. We should check that and make an explicit conversion.
Related to #65 
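A sketch of the explicit conversion; the array names are illustrative:

import jax.numpy as jnp

# Sketch only: cast index-like arrays to a single integer dtype so that enabling
# fp64 does not leave mixed int32/int64 inputs to the scatter.
numbers = jnp.asarray(numbers, dtype=jnp.int32)
neighbor_idx = jnp.asarray(neighbor_idx, dtype=jnp.int32)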

Add Training and MD Progress bars

For easier tracking of the training/simulation progress, and following @PythonFZ's suggestion, we should add a TQDM progress bar.

For the MD we can only reasonably add it to the outer loop, which is fine.
For training we could either have only one bar for the epochs or an additional one for the inner loop over the batches.
Considering that our epochs are typically quite fast, I would suggest only adding the progress bar for the overall training (tracking epoch number and epoch loss); a sketch follows below.

Thoughts?
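A minimal sketch with tqdm, assuming a hypothetical train_one_epoch helper:

from tqdm import trange

# Sketch only: a single bar over the epochs, showing the epoch loss as postfix.
pbar = trange(n_epochs, desc="training", unit="epoch")
for epoch in pbar:
    params, epoch_loss = train_one_epoch(params)
    pbar.set_postfix(loss=float(epoch_loss))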

Jax MD Temperature calculation and exit condition

  1. The temperature calculation seems to be wrong:

https://github.com/GM-NN/gmnn-jax/blob/58ad44f7ea582049bfe1d044129fba2a89b2d62a/gmnn_jax/md/nvt.py#L145-L146

  2. JaxMD continues even if all quantities are NaN. One could check the temperature, e.g. with isinstance(temperature, float), and exit the simulation if it is NaN (a sketch of such a check follows below).

  3. Units for the tqdm bar would also be a nice addition. I've set them to be units of fs. This would require that inner_loop * dt > 1 fs at all times, and probably also requires that the values are integers, so it could be tricky.
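For the exit condition, a minimal sketch inside the outer loop; temperature and step are assumed to be host-side values at this point:

import jax.numpy as jnp

# Sketch only: abort the run as soon as the temperature becomes NaN.
if jnp.isnan(temperature):
    raise RuntimeError(f"Temperature is NaN at step {step}; aborting simulation.")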

The angle losses need padding

Both angle losses have non-zero contributions when supplied with predictions/labels from ghost atoms,
e.g. (1.0 - dotp) / F_0_norm is NaN for a (0, 0, 0) label.
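A sketch of masking out the ghost-atom contributions; per_atom_loss and numbers are illustrative names:

import jax.numpy as jnp

# Sketch only: ghost (padded) atoms have Z = 0 and must not contribute.
mask = (numbers > 0).astype(per_atom_loss.dtype)
masked_loss = jnp.sum(per_atom_loss * mask) / jnp.sum(mask)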

Improve run abstractions / modularity

Writing a custom training loop using this code is still not terribly ergonomic. Perhaps we can improve the situation a bit by better encapsulating the input pipeline etc.

NL precomputing doesn't work for periodic systems of different size.

I was just trying to fit the HM21 dataset and it appears that neighbor lists are not correctly reallocated for differently sized periodic systems (it works for gas-phase systems):

975 in neighbor_list_fn

   972     d = partial(metric_sq, **kwargs)
   973     d = vmap(d)
   974     return lax.cond(
 > 975         jnp.any(d(position, nbrs.reference_position) > threshold_sq),
   976         (position, nbrs.error), neighbor_fn,
   977         nbrs, lambda x: x)
   978
ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * one axis had size 8: axis 0 of argument Ra of type float64[8,3];
  * one axis had size 32: axis 0 of argument Rb of type float64[32,3]
