Sorry, this does not actually exploit the trick to its fullest. The eSCN trick becomes faster than the naive implementation when you combine it with a linear layer. I implemented that combination in the experimental directory: https://github.com/e3nn/e3nn-jax/blob/main/e3nn_jax/experimental/linear_shtp.py
When I tried it, it was faster.
In my use case (MACE), I multiply the output of the tensor product by output.irreps.num_irreps weights generated by another neural network, and then do a scatter sum. Will the trick help here?
This is basically what I do: https://github.com/ACEsuit/mace-jax/blob/4b899de2101c6e2085ee972aeac0e46a334fd9a0/mace_jax/modules/message_passing.py#L37
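A minimal sketch of the operation described above, written in plain JAX rather than e3nn-jax: each irrep copy of the per-edge tensor-product output is scaled by its own weight (one weight per irrep, num_irreps in total), and the resulting messages are summed per receiver node, with jax.ops.segment_sum playing the role of the scatter sum. The (mul, l) list format, the contiguous component layout, and the helper names expand_per_irrep_weights and weighted_scatter_sum are illustrative assumptions, not the MACE or e3nn-jax API.

```python
import jax
import jax.numpy as jnp
import numpy as np


def expand_per_irrep_weights(weights, irreps):
    """Broadcast one weight per irrep copy onto that copy's 2l+1 components.

    Assumed (hypothetical) layout: `irreps` is a list of (mul, l) pairs
    (parity omitted, since it does not change dimensions), and each of the
    `mul` copies of an order-l irrep occupies 2l+1 contiguous components,
    in list order.
    weights: (..., num_irreps), where num_irreps is the sum of all muls.
    """
    dims = [2 * l + 1 for mul, l in irreps for _ in range(mul)]
    # For every feature component, the index of the irrep copy it belongs to.
    idx = np.repeat(np.arange(len(dims)), dims)
    return weights[..., idx]  # (..., dim)


def weighted_scatter_sum(tp_out, weights, receivers, irreps, num_nodes):
    """Scale each edge's tensor-product output per irrep, then sum per node.

    tp_out:    (n_edges, dim) tensor-product output for every edge
    weights:   (n_edges, num_irreps), e.g. produced by a radial MLP
    receivers: (n_edges,) index of the node that receives each message
    """
    messages = tp_out * expand_per_irrep_weights(weights, irreps)
    return jax.ops.segment_sum(messages, receivers, num_segments=num_nodes)
```

For example, with irreps [(2, 0), (1, 1)] (dimension 5, num_irreps 3), a weight row [1, 2, 3] expands to [1, 2, 3, 3, 3].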
It's not easy to write, but it can be done.
NequIP and MACE share the same convolution layer. Here is my implementation of NequIP with the eSCN trick: https://github.com/mariogeiger/nequip-jax/blob/main/nequip_jax/nequip_escn.py#L120-L140
Thank you for the responses and for pointing me to the NequIP code.
By the way, I did compare LinearSHTP against manually computing the tensor product followed by an e3nn linear layer, and found LinearSHTP to be slightly slower (without any external neural network weights). Code:
import jax
import e3nn_jax as e3nn
import time
from e3nn_jax.experimental.linear_shtp import LinearSHTP
import flax.linen as nn
from functools import partial


def deduce_tp_irreps(input_irreps, target_irreps):
    # Triangle rule: an SH of degree l couples input degree l_in to target degree L
    # only if l <= l_in + L, so this lmax covers every SH degree that can contribute.
    required_L = e3nn.Irreps(input_irreps).lmax + e3nn.Irreps(target_irreps).lmax
    sh_irreps = e3nn.Irreps.spherical_harmonics(required_L)
    # Irreps of input x SH, keeping only those that appear in the target.
    irreps_mid = e3nn.tensor_product(input_irreps, sh_irreps, filter_ir_out=target_irreps)
    return sh_irreps, irreps_mid


class NN_SHTP_manual(nn.Module):
    """Baseline: full tensor product with spherical harmonics, then an e3nn Linear."""
    input_irreps: e3nn.Irreps
    target_irreps: e3nn.Irreps

    def setup(self):
        self.sh_irreps, self.irreps_mid = deduce_tp_irreps(self.input_irreps, self.target_irreps)
        self.linear = e3nn.flax.Linear(irreps_in=self.irreps_mid, irreps_out=self.target_irreps)

    # Submodules are created in setup(), so @nn.compact is not needed on __call__.
    def __call__(self, inputs, directions):
        # Flax may store setup() attributes as plain tuples, so re-wrap them as Irreps.
        sh = e3nn.spherical_harmonics(input=directions, irreps_out=e3nn.Irreps(self.sh_irreps), normalize=True)
        out = e3nn.tensor_product(input1=inputs, input2=sh,
                                  filter_ir_out=e3nn.Irreps(self.irreps_mid))  # (B, n_edges, irreps_mid)
        return self.linear(out)


B = 60  # other values tried: 20, 100, 140
n_edges = 600
leading_shape = (B, n_edges)
degree = 5
input_irreps = "128x0e + 128x1o"
target_irreps = 128 * e3nn.Irreps.spherical_harmonics(degree)

nn_shtp_manual = NN_SHTP_manual(input_irreps=input_irreps, target_irreps=target_irreps)


@jax.jit
def manual(w, inputs, vectors):
    # The e3nn operations broadcast over the leading (B, n_edges) axes.
    return nn_shtp_manual.apply(w, inputs, vectors)
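

nn_shtp_opt = LinearSHTP(irreps_out=target_irreps)


# LinearSHTP is applied to a single edge here, so vmap twice over the (B, n_edges) axes;
# in_axes=None for w shares the same weights across all edges.
@jax.jit
@partial(jax.vmap, in_axes=(None, 0, 0))
@partial(jax.vmap, in_axes=(None, 0, 0))
def optimized(w, inputs, vectors):
    return nn_shtp_opt.apply(w, inputs, vectors)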


irreps_mid = deduce_tp_irreps(input_irreps, target_irreps)[1]
inputs = e3nn.normal(irreps=input_irreps, key=jax.random.PRNGKey(0), leading_shape=leading_shape)
vectors = e3nn.normal(irreps="1o", key=jax.random.PRNGKey(0), leading_shape=leading_shape)

# Sanity check of both variants (output irreps, shape, checksum, parameter count).
# These first calls also compile manual() and optimized(), so compilation is not timed below.
w_manual = nn_shtp_manual.init(jax.random.PRNGKey(0), inputs[0, 0], vectors[0, 0])
print("Manual:")
out1 = manual(w_manual, inputs, vectors)
print(out1.irreps, out1.shape, out1.array.sum(), "num_weights =", sum(x.size for x in jax.tree_util.tree_leaves(w_manual)))
del out1, w_manual

w_opt = nn_shtp_opt.init(jax.random.PRNGKey(0), inputs[0, 0], vectors[0, 0])
print("Optimized:")
out2 = optimized(w_opt, inputs, vectors)
print(out2.irreps, out2.shape, out2.array.sum(), "num_weights =", sum(x.size for x in jax.tree_util.tree_leaves(w_opt)))
del out2, w_opt
print()
del inputs, vectors


n_iters = 500


def run_manual():
    ins = e3nn.normal(irreps=input_irreps, key=jax.random.PRNGKey(0), leading_shape=leading_shape)
    vecs = e3nn.normal(irreps="1o", key=jax.random.PRNGKey(0), leading_shape=leading_shape)
    weights = nn_shtp_manual.init(jax.random.PRNGKey(0), ins[0, 0], vecs[0, 0])
    t0 = time.time()
    for i in range(1, n_iters + 1):
        # Input generation is inside the timed region, so its cost is included in the timing.
        inputs = e3nn.normal(irreps=input_irreps, key=jax.random.PRNGKey(i), leading_shape=leading_shape)
        vectors = e3nn.normal(irreps="1o", key=jax.random.PRNGKey(i), leading_shape=leading_shape)
        out = manual(weights, inputs, vectors)
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), out)
        del inputs, vectors, out
    t1 = time.time()
    print("Manual:", (t1 - t0) / n_iters, "s")
    del ins, vecs, weights


def run_optimized():
    ins = e3nn.normal(irreps=input_irreps, key=jax.random.PRNGKey(0), leading_shape=leading_shape)
    vecs = e3nn.normal(irreps="1o", key=jax.random.PRNGKey(0), leading_shape=leading_shape)
    weights = nn_shtp_opt.init(jax.random.PRNGKey(0), ins[0, 0], vecs[0, 0])
    t0 = time.time()
    for i in range(1, n_iters + 1):
        # Input generation is inside the timed region, so its cost is included in the timing.
        inputs = e3nn.normal(irreps=input_irreps, key=jax.random.PRNGKey(i), leading_shape=leading_shape)
        vectors = e3nn.normal(irreps="1o", key=jax.random.PRNGKey(i), leading_shape=leading_shape)
        out = optimized(weights, inputs, vectors)
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), out)
        del inputs, vectors, out
    t1 = time.time()
    print("Optimized:", (t1 - t0) / n_iters, "s")
    del ins, vecs, weights


print(leading_shape)
run_manual()
run_optimized()
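The timing loops above regenerate random inputs on every iteration, so that cost is part of both measurements. A minimal alternative sketch, assuming a JAX version that provides jax.block_until_ready (which accepts any pytree, so it should also handle an IrrepsArray): the caller creates the inputs once, a few warm-up calls take compilation out of the measurement, and only the jitted function is timed. The helper name benchmark and its default iteration counts are illustrative choices, not part of e3nn-jax.

```python
import time

import jax


def benchmark(fn, *args, n_iters=100, n_warmup=3):
    """Return the average wall-clock time in seconds of one call fn(*args).

    The caller creates the inputs once, so only fn is timed. The warm-up
    calls trigger JIT compilation before the clock starts.
    """
    if n_iters < 1:
        raise ValueError("n_iters must be >= 1")
    for _ in range(n_warmup):
        jax.block_until_ready(fn(*args))
    t0 = time.perf_counter()
    for _ in range(n_iters):
        out = fn(*args)
    # JAX dispatches asynchronously; block on the final result before stopping the clock.
    jax.block_until_ready(out)
    return (time.perf_counter() - t0) / n_iters
```

In the script above this would be called as benchmark(manual, weights, inputs, vectors) and benchmark(optimized, weights, inputs, vectors), with weights taken from nn_shtp_manual.init and nn_shtp_opt.init respectively.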
How much faster was it when you ran it? And given that it was not faster for me for the TP + linear operation, should I expect it to be faster with the neural network weights?
For context, I am currently working on an applied project, so it would help a lot if the trick actually speeds things up.
I'm surprised, because you even tried a large L (L=5). I don't remember the numbers, but it was faster even for a small L (L=2).
Could you try comparing NEQUIPLayerFlax with NEQUIPESCNLayerFlax (from nequip_jax import NEQUIPLayerFlax, NEQUIPESCNLayerFlax)?
https://github.com/mariogeiger/nequip-jax
Hi, I tested this as well (by modifying your test.py file).
Same result: the eSCN version was a bit slower.
I have no clue. Maybe try a much larger n_edges, like 2**18?
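Rough arithmetic for that suggestion, assuming 2**18 means the total number of edges (the benchmark above uses B * n_edges = 60 * 600 = 36,000) and float32 storage. The target irreps 128 * Irreps.spherical_harmonics(5) hold 128 copies of each order l = 0..5 with 2l + 1 components each, i.e. 4608 components per edge. The intermediate tensor-product output (irreps_mid) is larger still, so these numbers are lower bounds.

```python
# Feature dimension of target_irreps = 128 * Irreps.spherical_harmonics(5):
# 128 copies of each order l = 0..5, each with 2l + 1 components.
degree = 5
target_dim = 128 * sum(2 * l + 1 for l in range(degree + 1))  # 128 * 36 = 4608

BYTES_PER_FLOAT32 = 4


def output_bytes(total_edges):
    """Bytes for the (total_edges, target_dim) float32 output array alone."""
    return total_edges * target_dim * BYTES_PER_FLOAT32


current_edges = 60 * 600    # B * n_edges in the benchmark: 36,000 edges
suggested_edges = 2 ** 18   # 262,144 edges

print(output_bytes(current_edges) / 2**20, "MiB")    # 632.8125 MiB
print(output_bytes(suggested_edges) / 2**30, "GiB")  # 4.5 GiB
```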
I think I'm done testing this; thanks for the responses anyway.