apax-hub / apax
A flexible and performant framework for training machine learning potentials.
License: MIT License
As an extension of a general update PR (including JAX), we should also start supporting the latest Python version.
Currently the ASE calculator does not check whether the particle number or box volume changes.
As a result, the predictions for a differently sized box are completely wrong.
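A minimal sketch of such a guard, assuming the calculator follows the standard ase.calculators.calculator.Calculator interface; the class name and the re-initialization helper are illustrative, not the actual apax code:

```python
from ase.calculators.calculator import Calculator, all_changes

class ApaxCalculator(Calculator):  # illustrative stand-in for the actual ASE calculator
    implemented_properties = ["energy", "forces"]

    def calculate(self, atoms=None, properties=("energy",), system_changes=all_changes):
        Calculator.calculate(self, atoms, properties, system_changes)
        if "numbers" in system_changes or "cell" in system_changes:
            # particle number or box changed: the padded neighbor list and the
            # cached cell are stale and must be rebuilt before predicting
            self.rebuild_model_state(self.atoms)  # hypothetical re-initialization helper
        # ... run the model and fill self.results as usual
```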
We should add an additional command to evaluate the test set.
To better conform to other MD codes we should implement a config roughly like the following.
duration: 10_000 # fs
sampling_interval: 100 # fs
timestep: 0.5 # fs
Thanks to @PythonFZ for bringing this up.
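For illustration, here is how the fs-based settings above could map onto step counts (a sketch, assuming duration and sampling_interval are integer multiples of the timestep):

```python
timestep = 0.5            # fs
sampling_interval = 100   # fs
duration = 10_000         # fs

steps_per_sample = int(sampling_interval / timestep)  # 200 inner MD steps between writes
n_samples = int(duration / sampling_interval)         # 100 sampled frames
total_steps = int(duration / timestep)                # 20_000 MD steps in total
```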
We should at some point in the near future bump the jax version to its newest release.
Was there a particular reason why we specified an exact version of jax?
If not we probably want to lift that restriction or add that version to the CUDA install instructions.
Currently running the pip install... command for CUDA installs the latest jax version while the pyproject.toml specifies 0.4.1.
All of our model components require a considerable number of parameters, leading to the model class itself also requiring all these parameters to initialize the components.
Would it be sensible to separate creation and use and initialize GMNN with the components instead of all the component parameters?
GMNN(n_basis, n_radial, units, ...)
vs
GMNN(Descriptor, MLP)
The output data structure has to be changed, since training different model versions is not possible: checkpoints are written to the model directory.
Quick fix is to set XLA_PYTHON_CLIENT_PREALLOCATE=false. Will be set in a PR today.
We need a second checkpoint that tracks the epoch loss to determine whether the current model params are the best seen so far.
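A minimal sketch of the idea; run_epoch and save_checkpoint are hypothetical helpers standing in for the actual training and (async) checkpointing code:

```python
best_loss = float("inf")

for epoch in range(n_epochs):
    epoch_loss = run_epoch(state)                # hypothetical: one pass over the training data
    save_checkpoint(state, ckpt_dir / "latest")  # always keep the most recent parameters
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        save_checkpoint(state, ckpt_dir / "best")  # best parameters seen so far
```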
Currently, every model run saves its TB logs into its own directory. We probably want to move this from model_path/model_name/tb_logs to model_path/tb_logs so it is more convenient to compare different models.
This is quite annoying in combination with the progress bars.
Related to #22
If the batch size is larger than the training / test data, one receives the following error:
ZeroDivisionError: float division by zero
Something like ValueError: batch size can not be larger than the available data would be more informative.
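A sketch of the check, assuming the number of available data points is known when the input pipeline is built (variable names are illustrative):

```python
if batch_size > n_data:  # n_data: size of the train or validation split
    raise ValueError(
        f"batch size ({batch_size}) can not be larger than the available data ({n_data})"
    )
```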
It is currently not possible to train with the TensorBoard and CSV callbacks at the same time. If both are specified in the config, the metrics will just be saved in one format. No error or warning is returned.
https://github.com/GM-NN/gmnn-jax/blob/58afee506ddbc0297fe981c4550057632ef28e35/gmnn_jax/md/ase_calc.py#L39 and
https://github.com/GM-NN/gmnn-jax/blob/58afee506ddbc0297fe981c4550057632ef28e35/gmnn_jax/md/ase_calc.py#L67
should be super calls of
https://github.com/GM-NN/gmnn-jax/blob/58afee506ddbc0297fe981c4550057632ef28e35/gmnn_jax/md/ase_calc.py#L52, respectively.

https://github.com/GM-NN/gmnn-jax/blob/58afee506ddbc0297fe981c4550057632ef28e35/gmnn_jax/model/gmnn.py#L80-L91
does not load non-default values for r_max, n_basis, n_radial, nn, ...
The loss is currently accumulated outside the compiled step function. This probably doesn't have a performance impact, but for consistency it should be refactored to be part of the other metrics.
In the TF version, most of the model is in fp32. We should carefully evaluate for which parts this is possible.
There might be a bug in the atomic energy regression.
After some research, I could not find an easy way to optionally install jax[cpu] or jax[cuda] with Poetry. The standard installation ships jax[cpu]. The workaround for now is to delete the poetry.lock file, change jax = {version = "0.3.25", extras = ["cpu"]} to jax = {version = "0.3.25", extras = ["cuda"]}, and run poetry install.
For various applications and downstream method development, we require at least a basic transfer learning implementation.
The implementation should include:
This is sufficient for an elementary discriminative transfer learning implementation.
Currently, loading models into the ASE calc and the JaxMD inference model assumes the model was constructed with the default parameters (nn=[512,512], n_basis=7, n_radial=5) and throws an error if that was not the case.
The config classes should be documented.
Currently we assume that every data point has the same cell size as the first one. If others have different sizes, the model is not aware of it and produces wrong results.
Before we tackle #45 we should at least try to catch this error.
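A sketch of such a guard, assuming the dataset is a list of ASE Atoms objects (variable names are illustrative):

```python
import numpy as np

cells = np.array([atoms.get_cell()[:] for atoms in structures])
if not np.allclose(cells, cells[0]):
    raise NotImplementedError(
        "All structures are currently required to share the same cell, "
        "but differing cells were found in the dataset."
    )
```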
Low prio. Once we train on stresses we should come back to this. I've added a TODO comment in the loss class.
On initialization, an AttributeError is raised since our TFModelSpoof does not have the attribute distribute_strategy. This can probably be fixed by just using an actual Keras model.
We already have the arguments in the config, however in run.py these are ignored and only config.data_path is used. This should be easy to fix.
ASE does not allow handling all possible labels, only labels that are known to ASE. This is an issue, for example, with MAT training.
We could experiment with multi GPU training and local SGD methods, which would be fairly simple to implement in jax.
This could potentially give some speedup for training large batch_sizes / datasets of large systems.
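A minimal sketch of synchronous data-parallel training with jax.pmap (local SGD would simply average across devices less frequently); the toy model and batch layout are placeholders, not the actual apax training step:

```python
from functools import partial

import jax
import jax.numpy as jnp

LR = 1e-3

@partial(jax.pmap, axis_name="devices")
def train_step(params, batch):
    def loss_fn(p):
        pred = batch["x"] @ p                      # toy linear model as a placeholder
        return jnp.mean((pred - batch["y"]) ** 2)

    grads = jax.grad(loss_fn)(params)
    grads = jax.lax.pmean(grads, axis_name="devices")  # sync: average gradients across devices
    return params - LR * grads

n_dev = jax.local_device_count()
params = jax.device_put_replicated(jnp.zeros(4), jax.local_devices())  # replicate parameters
batch = {"x": jnp.ones((n_dev, 8, 4)), "y": jnp.ones((n_dev, 8))}      # leading axis = device axis
params = train_step(params, batch)
```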
Many datasets consist of differently sized systems. Currently we assume that everything is the same size and has the same cell. In order for the code to be more generally applicable, we should implement support for this.
Since some configs allow for extras, the validate command can be deceiving.
Also we should perform some stricter validation.
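A sketch of what stricter validation could look like, assuming the config classes are pydantic (v1-style) models; MDConfig is an illustrative name, not the actual config class:

```python
from pydantic import BaseModel, Extra

class MDConfig(BaseModel):  # illustrative, not the actual apax config class
    duration: float           # fs
    sampling_interval: float  # fs
    timestep: float           # fs

    class Config:
        extra = Extra.forbid  # unknown keys raise a ValidationError instead of being ignored
```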
Internal file handling occurs via a mix of pathlib and strings. We should refactor the code to only use the former.
I was just checking why I'm missing TensorFlow and looked into the pyproject.toml of gmnn-jax. I could not resolve my problem, but thought the following could be related: jax = 0.4.1.
This doesn't change anything tbh.
https://github.com/GM-NN/gmnn-jax/blob/327c7fe49f01027d719db466d6a337ec54b9457b/pyproject.toml#L16
While we do allow the specification of the number of NN layers in the config, the model class assumes it to be exactly 2, failing if it's less and ignoring additional layer specifications:
self.dense1 = NTKLinear(units[0], b_init=b_init, name="dense1")
self.dense2 = NTKLinear(units[1], b_init=b_init, name="dense2")
self.dense3 = NTKLinear(1, b_init=b_init, name="dense3")
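A sketch of the fix, reusing the names from the snippet above (NTKLinear, units, b_init): build the hidden stack from the configured units instead of hard-coding two layers.

```python
# inside the model's setup, replacing the hard-coded dense1/dense2/dense3
self.dense_layers = [
    NTKLinear(n_units, b_init=b_init, name=f"dense{i}")
    for i, n_units in enumerate(units)
]
self.readout = NTKLinear(1, b_init=b_init, name="readout")

# ...and in the forward pass:
# for layer in self.dense_layers:
#     x = activation(layer(x))
# return self.readout(x)
```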
Checking the logs for training runs across different systems and durations reveals that saving a checkpoint after every epoch causes them to queue up in the async manager.
This could be circumvented by introducing a checkpointing interval, although I would have to test whether that actually yields a speedup, since saving is performed asynchronously.
It seems like models work with single and double precision to some extent, so I would suggest adding the following to the config for gmnn and jaxmd:
jax_enable_x64: true
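On the implementation side, the flag has to be applied before any arrays are created, e.g. (a sketch, where cfg stands in for the parsed config):

```python
from jax import config

if cfg.jax_enable_x64:  # cfg: parsed gmnn/jaxmd config (illustrative name)
    config.update("jax_enable_x64", True)
```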
It would be nice if one could disable a certain loss by setting its weight to zero / null:
loss:
  - name: energy
    loss_type: structures
  - name: forces
    loss_type: structures
    weight: 0.0 # this
  - name: forces
    loss_type: angles
    weight: null # or this
We should reintroduce regression tests at some point.
This time we could make proper ones that actually perform a complete training, but are marked as slow so you don't have to run them every time you want to test the code.
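A sketch of how such tests could be marked, using a standard pytest marker (the test body is a placeholder):

```python
import pytest

@pytest.mark.slow  # deselect by default with `pytest -m "not slow"`, run with `pytest -m slow`
def test_full_training_regression():
    # run a short but complete training and compare the final metrics
    # against stored reference values (tolerances to be decided)
    ...
```

The marker would also need to be registered in the pytest configuration to avoid unknown-marker warnings.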
When enabling fp64, we still get the following warning:
scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32. In future JAX releases this will result in an error.
This is most likely caused either by atomic numbers or neighbors being in int32. We should check that and make an explicit conversion.
Related to #65
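A sketch of the explicit conversion (array names are illustrative; whether int32 or int64 is the right common dtype still needs to be checked):

```python
import jax.numpy as jnp

# keep integer inputs at a single, explicit dtype so the scatter does not
# have to cast between int32 and int64 when x64 is enabled
atomic_numbers = jnp.asarray(atomic_numbers, dtype=jnp.int32)
neighbor_idx = jnp.asarray(neighbor_idx, dtype=jnp.int32)
```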
For easier tracking of the training/sim progress and following @PythonFZ's suggestion, we should add a TQDM progress bar.
For the MD we can only reasonably add it to the outer loop, which is fine.
For training we could either have only one for the epochs or an additional one for the inner loop over the batches.
Considering that our epochs are typically quite fast, I would suggest only adding the progress bar for the overall training (tracking epoch number and epoch loss).
Thoughts?
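A sketch of the epoch-level bar; n_epochs, state, and run_epoch are placeholders for the actual training loop:

```python
from tqdm import trange

pbar = trange(n_epochs, desc="training", unit="epoch")
for epoch in pbar:
    epoch_loss = run_epoch(state)               # hypothetical training helper
    pbar.set_postfix(loss=f"{epoch_loss:.4e}")  # show the current epoch loss next to the bar
```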
JaxMD continues even if all quantities are NaN. One could check, e.g., whether the temperature is NaN and exit the simulation in that case.
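A sketch of the check inside the outer simulation loop; current_temperature stands in for the quantity already computed per reporting interval:

```python
import jax.numpy as jnp

if jnp.isnan(current_temperature):
    raise RuntimeError("Simulation diverged: temperature is NaN, aborting the run.")
```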
Units for the tqdm bar would also be a nice addition. I've set them to be units of fs. This would require that inner_loop * dt > 1 fs at all times and probably also that they are integers, so it could be tricky.
Both losses have non-zero contributions when supplied with predictions/labels from ghost atoms, e.g. (1.0 - dotp) / F_0_norm is NaN for a (0,0,0) label.
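A sketch of how the contribution could be masked out, assuming a per-atom mask that is 1 for real atoms and 0 for ghost/padding atoms (names follow the snippet above where possible):

```python
import jax.numpy as jnp

per_atom = (1.0 - dotp) / jnp.maximum(F_0_norm, 1e-12)  # avoid 0/0 for (0, 0, 0) labels
loss = jnp.sum(per_atom * mask) / jnp.sum(mask)         # ghost atoms contribute exactly zero
```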
Writing a custom training loop using this code is still not terribly ergonomic. Perhaps we can improve the situation a bit by better encapsulating the input pipeline etc.
This is useful for my sampling method.
I was just trying to fit the HM21 dataset and it appears that neighbor lists are not correctly reallocated for differently sized periodic systems (it works for the gas phase).
in neighbor_list_fn, line 975:

    972   d = partial(metric_sq, **kwargs)
    973   d = vmap(d)
    974   return lax.cond(
  ❱ 975       jnp.any(d(position, nbrs.reference_position) > threshold_sq),
    976       (position, nbrs.error), neighbor_fn,
    977       nbrs, lambda x: x)
ValueError: vmap got inconsistent sizes for array axes to be mapped:
* one axis had size 8: axis 0 of argument Ra of type float64[8,3];
* one axis had size 32: axis 0 of argument Rb of type float64[32,3]
Currently the starting structure is not written to the output. At least for completeness, this should be done.