
JAX-PI

This repository is a comprehensive implementation of physics-informed neural networks (PINNs). It integrates several advanced network architectures and training algorithms from the papers listed under Citation below.
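The core idea of a PINN can be sketched in a few lines of JAX. Take the 1D advection equation u_t + c u_x = 0 with initial condition u(0, x) = u_0(x); this is the equation used in the Quickstart below. A network u_θ(t, x) is trained so that two quantities are small:

- the PDE residual, whose derivatives are computed by automatic differentiation, at randomly sampled collocation points;
- the mismatch with the initial condition.

This is a minimal sketch under stated assumptions. The plain tanh MLP, the equal (unweighted) loss terms and all helper names (init_mlp, u_net, advection_residual, pinn_loss) are hypothetical illustrations, not jaxpi's API. The repository's models add the architectures and training algorithms from the cited papers on top of this basic recipe.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Initialize a plain tanh MLP as a list of (weights, bias) pairs (hypothetical helper)."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        w = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((w, jnp.zeros(n_out)))
    return params


def u_net(params, t, x):
    """Scalar network prediction u_theta(t, x)."""
    h = jnp.array([t, x])
    for w, b in params[:-1]:
        h = jnp.tanh(h @ w + b)
    w, b = params[-1]
    return (h @ w + b)[0]


def advection_residual(u_fn, t, x, c):
    """PDE residual r = u_t + c * u_x, with both derivatives taken by autodiff."""
    u_t = jax.grad(u_fn, argnums=0)(t, x)
    u_x = jax.grad(u_fn, argnums=1)(t, x)
    return u_t + c * u_x


def pinn_loss(params, t_r, x_r, x_0, u_0, c):
    """Unweighted sum of the mean squared residual and the initial-condition error."""
    u_fn = lambda t, x: u_net(params, t, x)
    r = jax.vmap(lambda t, x: advection_residual(u_fn, t, x, c))(t_r, x_r)
    u_pred_0 = jax.vmap(lambda x: u_fn(0.0, x))(x_0)
    return jnp.mean(r ** 2) + jnp.mean((u_pred_0 - u_0) ** 2)


if __name__ == "__main__":
    key_p, key_t, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
    params = init_mlp(key_p, [2, 32, 32, 1])
    t_r = jax.random.uniform(key_t, (256,))  # collocation times in [0, 1)
    x_r = jax.random.uniform(key_x, (256,), maxval=2 * jnp.pi)  # collocation points in [0, 2*pi)
    x_0 = jnp.linspace(0.0, 2 * jnp.pi, 128)
    loss, grads = jax.value_and_grad(pinn_loss)(params, t_r, x_r, x_0, jnp.sin(x_0), 1.0)
    print("loss:", float(loss))
```

The gradients returned by jax.value_and_grad have the same pytree structure as params. They can be passed to any gradient-based optimizer.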

This repository also provides an extensive set of benchmark examples that demonstrate the effectiveness and robustness of our implementation. Training runs on a single GPU or on multiple GPUs, while evaluation is currently limited to a single GPU.

Updates

  • May 2024: We have released the code for our latest paper, "PirateNets: Physics-informed Deep Learning with Residual Adaptive Networks". See the pirate branch of this repository for the implementation and examples.

Installation

Ensure that you have Python 3.8 or later installed. The code runs on GPUs only. We strongly recommend using the most recent versions of JAX and jaxlib, together with compatible CUDA and cuDNN versions. The code has been tested and confirmed to work with the following versions:

  • JAX 0.4.26
  • CUDA 12.4
  • cuDNN 8.9

You can install the latest CUDA 12 build of JAX, which pulls in a matching jaxlib, with the following commands. For other CUDA setups, follow the official JAX installation guide:

pip3 install -U pip
pip3 install -U "jax[cuda12]"

Install JAX-PI with the following commands:

git clone https://github.com/PredictiveIntelligenceLab/jaxpi.git
cd jaxpi
pip install .
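After installing, a quick sanity check can confirm that the environment matches the setup listed above. The following is a stdlib-only sketch, and the helper names (parse_version, at_least, check_environment) are hypothetical. It reports:

- whether Python is 3.8 or later;
- which JAX version is installed;
- which backend JAX selected, via jax.default_backend().

We treat the tested JAX version (0.4.26) only as a soft reference point, since newer versions are recommended.

```python
import sys
from importlib import metadata

TESTED_JAX = "0.4.26"  # version listed in this README's tested setup (a reference, not a hard minimum)


def parse_version(text):
    """Return the leading numeric components, e.g. '0.4.26.dev1' -> (0, 4, 26)."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric component such as 'dev1'
        parts.append(int(digits))
    return tuple(parts)


def at_least(installed, reference):
    """True if `installed` is the same as or newer than `reference`."""
    return parse_version(installed) >= parse_version(reference)


def check_environment():
    """Collect Python, JAX and backend information without failing on missing packages."""
    report = {"python_ok": sys.version_info >= (3, 8), "jax_version": None, "backend": None}
    try:
        report["jax_version"] = metadata.version("jax")
    except metadata.PackageNotFoundError:
        return report
    try:
        import jax
        report["backend"] = jax.default_backend()  # "cpu" means no GPU backend was found
    except Exception:  # e.g. a broken or mismatched jaxlib install
        pass
    return report


if __name__ == "__main__":
    info = check_environment()
    print(info)
    if info["jax_version"] and not at_least(info["jax_version"], TESTED_JAX):
        print(f"Warning: JAX {info['jax_version']} is older than the tested {TESTED_JAX}")
    if info["backend"] in (None, "cpu"):
        print("Warning: JAX is not using a GPU backend; this code is GPU-only.")
```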

Quickstart

We use Weights & Biases to log and monitor training metrics. Before proceeding, make sure Weights & Biases is installed and logged in to your account; see the official Weights & Biases quickstart guide for setup instructions.

To illustrate how to use our code, we will use the advection equation as an example. First, from the repository root, navigate to the advection directory within the examples folder:

cd examples/advection

To train the model, run the following command:

python3 main.py 

To customize your experiment configuration, you may want to specify a different config file as follows:

python3 main.py --config=configs/sota.py 

Our code automatically supports multi-GPU execution. You can specify the GPUs you want to use with the CUDA_VISIBLE_DEVICES environment variable. For example, to use the first two GPUs (0 and 1), use the following command:

CUDA_VISIBLE_DEVICES=0,1 python3 main.py

Note on Memory Usage: Different models and examples may require varying amounts of GPU memory. If you encounter an out-of-memory error, you can decrease the batch size using the --config.batch_size_per_device option.
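Multi-GPU training in JAX is typically data-parallel. Each visible device processes its own slice of the batch, so the global batch is batch_size_per_device × the number of devices, and lowering batch_size_per_device reduces the memory needed on each GPU.

This reading of the option is an assumption based on its name, not a confirmed description of jaxpi's internals. The sketch below shows the standard pmap pattern it implies. Names such as shard_batch and mean_square are hypothetical, and the batch size is a placeholder.

jax.local_device_count() counts only the GPUs left visible by CUDA_VISIBLE_DEVICES. On a machine without GPUs it reports the CPU devices.

```python
from functools import partial

import jax
import jax.numpy as jnp


def shard_batch(batch, n_devices):
    """Reshape a global batch (B, ...) into (n_devices, B // n_devices, ...).

    Rows that do not divide evenly across devices are dropped.
    """
    per_device = batch.shape[0] // n_devices
    trimmed = batch[: n_devices * per_device]
    return trimmed.reshape((n_devices, per_device) + batch.shape[1:])


@partial(jax.pmap, axis_name="devices")
def mean_square(x):
    """Per-device mean of x**2, then averaged across devices with a pmean collective."""
    return jax.lax.pmean(jnp.mean(x ** 2), axis_name="devices")


if __name__ == "__main__":
    n_devices = jax.local_device_count()  # respects CUDA_VISIBLE_DEVICES
    batch_size_per_device = 8  # placeholder value, not a jaxpi default
    global_batch = jnp.ones((n_devices * batch_size_per_device, 2))
    print(n_devices, mean_square(shard_batch(global_batch, n_devices)))
```

Because pmap maps over the leading axis, that axis must not exceed the number of local devices. This is why the batch is reshaped per device rather than passed whole.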

To evaluate the model's performance, you can switch to evaluation mode with the following command:

python3 main.py --config.mode=eval
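Flags such as --config.mode=eval and --config.batch_size_per_device override a single field of the experiment configuration from the command line. The syntax matches the ml_collections config_flags convention, which we assume main.py uses.

The stdlib-only helper below is a hypothetical illustration of what such an override does to a nested configuration. It is not the implementation in jaxpi or ml_collections. The config values in the example are placeholders.

```python
def apply_override(config, flag):
    """Apply one '--config.a.b=value' override to a nested dict, in place.

    The new value is cast to the type of the existing field, so unknown
    fields and bad booleans fail loudly instead of silently adding keys.
    """
    prefix = "--config."
    if not flag.startswith(prefix) or "=" not in flag:
        raise ValueError(f"expected '--config.<field>=<value>', got {flag!r}")
    path, raw = flag[len(prefix):].split("=", 1)
    *parents, leaf = path.split(".")
    node = config
    for key in parents:  # walk down nested sections, e.g. 'training'
        node = node[key]
    if leaf not in node:
        raise KeyError(f"unknown config field: {path}")
    old = node[leaf]
    if isinstance(old, bool):  # check bool before int: bool is a subclass of int
        if raw.lower() not in ("true", "false"):
            raise ValueError(f"expected true/false for {path}, got {raw!r}")
        node[leaf] = raw.lower() == "true"
    else:
        node[leaf] = type(old)(raw)  # e.g. '512' -> 512 for an int field
    return config


if __name__ == "__main__":
    # Hypothetical config values for illustration only.
    config = {"mode": "train", "batch_size_per_device": 1024, "training": {"max_steps": 1000}}
    apply_override(config, "--config.mode=eval")
    apply_override(config, "--config.batch_size_per_device=512")
    print(config)
```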

Examples

The following table summarizes our results on each benchmark. Each row gives the relative $L^2$ error together with the corresponding model checkpoint and Weights & Biases log.

| Benchmark | Relative $L^2$ Error | Checkpoint | Weights & Biases |
| --- | --- | --- | --- |
| Allen–Cahn equation | $5.37 \times 10^{-5}$ | allen_cahn | allen_cahn |
| Advection equation | $6.88 \times 10^{-4}$ | adv | adv |
| Stokes flow | $8.04 \times 10^{-5}$ | stokes | stokes |
| Kuramoto–Sivashinsky equation | $1.61 \times 10^{-1}$ | ks | ks |
| Lid-driven cavity flow | $1.58 \times 10^{-1}$ | ldc | ldc |
| Navier–Stokes flow in tori | $3.53 \times 10^{-1}$ | ns_tori | ns_tori |
| Navier–Stokes flow around a cylinder | - | ns_cylinder | ns_cylinder |
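The relative $L^2$ error is conventionally defined as $\|u_\theta - u\|_2 / \|u\|_2$. Here $u_\theta$ is the model prediction and $u$ is the reference solution on the evaluation grid.

We assume the table uses this standard definition; the evaluation grid for each benchmark is set by its example. The helper name relative_l2_error below is hypothetical.

```python
import jax.numpy as jnp


def relative_l2_error(u_pred, u_ref):
    """Relative L2 error ||u_pred - u_ref||_2 / ||u_ref||_2 over all grid points."""
    u_pred = jnp.ravel(u_pred)  # flatten (t, x, ...) grids into one vector
    u_ref = jnp.ravel(u_ref)
    return jnp.linalg.norm(u_pred - u_ref) / jnp.linalg.norm(u_ref)


if __name__ == "__main__":
    ref = jnp.array([3.0, 4.0])  # ||ref||_2 = 5
    print(relative_l2_error(jnp.array([3.0, 4.5]), ref))  # ≈ 0.1, since ||[0, 0.5]||_2 / 5 = 0.1
```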

Decaying Navier–Stokes flow in tori

[Figure: ns_tori]

Vortex shedding

[Figure: ns_cylinder, three panels]

Gray–Scott

[Figure: Gray–Scott]

Ginzburg–Landau

[Figure: Ginzburg–Landau]

Citation

@article{wang2023expert,
  title={An Expert's Guide to Training Physics-informed Neural Networks},
  author={Wang, Sifan and Sankaran, Shyam and Wang, Hanwen and Perdikaris, Paris},
  journal={arXiv preprint arXiv:2308.08468},
  year={2023}
}

@article{wang2024piratenets,
  title={PirateNets: Physics-informed Deep Learning with Residual Adaptive Networks},
  author={Wang, Sifan and Li, Bowen and Chen, Yuhan and Perdikaris, Paris},
  journal={arXiv preprint arXiv:2402.00326},
  year={2024}
}
