Comments (8)

mariogeiger commented on August 15, 2024

Sorry, this is actually not the trick exploited to its maximum. Their trick becomes better than the naive implementation when you combine it with a linear layer. I implemented it in the experimental directory: https://github.com/e3nn/e3nn-jax/blob/main/e3nn_jax/experimental/linear_shtp.py

When I tried it, it was faster.
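
A minimal usage sketch of LinearSHTP (the irreps, keys, and shapes here are just illustrative; the module fuses the tensor product with the spherical harmonics of a direction and the following linear layer into one weighted operation):

import jax
import e3nn_jax as e3nn
from e3nn_jax.experimental.linear_shtp import LinearSHTP

shtp = LinearSHTP(irreps_out="32x0e + 32x1o")

x = e3nn.normal("32x0e + 32x1o", jax.random.PRNGKey(0))  # input features of one edge
d = e3nn.normal("1o", jax.random.PRNGKey(1))             # direction vector of that edge

w = shtp.init(jax.random.PRNGKey(2), x, d)
y = shtp.apply(w, x, d)  # same irreps content as tensor_product(x, sh(d)) followed by a linear layer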


Kalyan0821 commented on August 15, 2024

In my use case (MACE), I multiply the output of the tensor product with output.irreps.num_irreps weights generated by another neural network, and then do a scatter sum. Will the trick help here?
It is basically this: https://github.com/ACEsuit/mace-jax/blob/4b899de2101c6e2085ee972aeac0e46a334fd9a0/mace_jax/modules/message_passing.py#L37
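
In pseudocode, the pattern is roughly this (the names and the scatter_sum call are illustrative, not the actual mace-jax code):

import e3nn_jax as e3nn

def message_passing_step(node_feats, edge_sh, edge_weights, senders, receivers):
    # tensor product of the sender features with the edge spherical harmonics
    messages = e3nn.tensor_product(node_feats[senders], edge_sh)
    # one externally generated scalar weight per irrep:
    # edge_weights has shape (n_edges, messages.irreps.num_irreps)
    messages = messages * edge_weights
    # scatter-sum the messages onto the receiving nodes
    return e3nn.scatter_sum(messages, dst=receivers, output_size=node_feats.shape[0])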


mariogeiger commented on August 15, 2024

It's not easy to write, but you can do it.

NequIP and MACE share the same convolution layer. Here is my implementation of NequIP with the eSCN trick: https://github.com/mariogeiger/nequip-jax/blob/main/nequip_jax/nequip_escn.py#L120-L140


Kalyan0821 commented on August 15, 2024

Thank you for the responses and for pointing me to the NequIP code.
By the way, I did compare LinearSHTP with manually computing the tensor product followed by an e3nn linear layer, and actually found LinearSHTP to be a little slower (without any external neural network weights). Code:

import jax
import e3nn_jax as e3nn
import time
from e3nn_jax.experimental.linear_shtp import LinearSHTP
import flax.linen as nn
from functools import partial


def deduce_tp_irreps(input_irreps, target_irreps):
    # lmax of the spherical harmonics needed for input (x) sh to reach the target lmax
    required_L = e3nn.Irreps(input_irreps).lmax + e3nn.Irreps(target_irreps).lmax
    sh_irreps = e3nn.Irreps.spherical_harmonics(required_L)
    # irreps of the full tensor product, filtered to those appearing in the target
    irreps_mid = e3nn.tensor_product(input_irreps,
                                     sh_irreps,
                                     filter_ir_out=target_irreps)
    return sh_irreps, irreps_mid


class NN_SHTP_manual(nn.Module):
    input_irreps: e3nn.Irreps
    target_irreps: e3nn.Irreps

    def setup(self):
        self.sh_irreps, self.irreps_mid = deduce_tp_irreps(self.input_irreps, self.target_irreps)
        self.linear = e3nn.flax.Linear(irreps_in=self.irreps_mid,
                                       irreps_out=self.target_irreps,
                                       )

    def __call__(self, inputs, directions):

        out = e3nn.tensor_product(input1=inputs,
                                  input2=e3nn.spherical_harmonics(input=directions, irreps_out=e3nn.Irreps(self.sh_irreps), normalize=True),
                                  filter_ir_out=e3nn.Irreps(self.irreps_mid)
                                )  # (B, n_edges, irreps_mid)
        out = self.linear(out)
        return out


# B = 20
B = 60
# B = 100
# B = 140

n_edges = 600
leading_shape = (B, n_edges)
degree = 5
input_irreps = "128x0e + 128x1o"
target_irreps = 128 * e3nn.Irreps.spherical_harmonics(degree)


nn_shtp_manual = NN_SHTP_manual(input_irreps=input_irreps, target_irreps=target_irreps)
@jax.jit
def manual(w, inputs, vectors):
    out = nn_shtp_manual.apply(w, inputs, vectors)
    return out

nn_shtp_opt = LinearSHTP(irreps_out=target_irreps)
@jax.jit
@partial(jax.vmap, in_axes=(None, 0, 0))
@partial(jax.vmap, in_axes=(None, 0, 0))
def optimized(w, inputs, vectors):
    out = nn_shtp_opt.apply(w, inputs, vectors)
    return out


inputs = e3nn.normal(irreps=input_irreps, key=jax.random.PRNGKey(0), leading_shape=leading_shape)
vectors = e3nn.normal(irreps="1o", key=jax.random.PRNGKey(0), leading_shape=leading_shape)


w_manual = nn_shtp_manual.init(jax.random.PRNGKey(0), inputs[0, 0], vectors[0, 0])
print("Manual:")
out1 = manual(w_manual, inputs, vectors)
print(out1.irreps, out1.shape, out1.array.sum(), "num_weights =", sum(x.size for x in jax.tree_util.tree_leaves(w_manual)))
del out1, w_manual

w_opt = nn_shtp_opt.init(jax.random.PRNGKey(0), inputs[0, 0], vectors[0, 0])
print("Optimized:")
out2 = optimized(w_opt, inputs, vectors)
print(out2.irreps, out2.shape, out2.array.sum(), "num_weights =", sum(x.size for x in jax.tree_util.tree_leaves(w_opt)))
del out2, w_opt
print()


del inputs, vectors
n_iters = 500
def run_manual():
    ins = e3nn.normal(irreps=input_irreps, key=jax.random.PRNGKey(0), leading_shape=leading_shape)
    vecs = e3nn.normal(irreps="1o", key=jax.random.PRNGKey(0), leading_shape=leading_shape)
    weights = nn_shtp_manual.init(jax.random.PRNGKey(0), ins[0, 0], vecs[0, 0])

    # NOTE: the timed loop also generates fresh inputs each iteration;
    # this overhead is identical for both benchmarks
    t0 = time.time()
    for i in range(1, n_iters+1):
        inputs = e3nn.normal(irreps=input_irreps, key=jax.random.PRNGKey(i), leading_shape=leading_shape)
        vectors = e3nn.normal(irreps="1o", key=jax.random.PRNGKey(i), leading_shape=leading_shape)

        out = manual(weights, inputs, vectors)
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), out)
        del inputs, vectors, out
    t1 = time.time()
    print("Manual:", (t1-t0) / n_iters, "s")
    del ins, vecs, weights

def run_optimized():
    ins = e3nn.normal(irreps=input_irreps, key=jax.random.PRNGKey(0), leading_shape=leading_shape)
    vecs = e3nn.normal(irreps="1o", key=jax.random.PRNGKey(0), leading_shape=leading_shape)
    weights = nn_shtp_opt.init(jax.random.PRNGKey(0), ins[0, 0], vecs[0, 0])

    t0 = time.time()
    for i in range(1, n_iters+1):
        inputs = e3nn.normal(irreps=input_irreps, key=jax.random.PRNGKey(i), leading_shape=leading_shape)
        vectors = e3nn.normal(irreps="1o", key=jax.random.PRNGKey(i), leading_shape=leading_shape)

        out = optimized(weights, inputs, vectors)
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), out)
        del inputs, vectors, out
    t1 = time.time()
    print("Optimized:", (t1-t0) / n_iters, "s")
    del ins, vecs, weights


print(leading_shape)
run_manual()
run_optimized()

Output:
[screenshot of the benchmark timings, 2024-04-17]

How much faster was it when you ran it? And given that it was not faster for me for the tp+linear operation, should I expect it to be faster with the neural network weights?

For context, I am currently working on an applied project, so it would help a lot if the trick actually speeds things up.


mariogeiger commented on August 15, 2024

I'm surprised, because you even tried with a large L (L=5). I don't remember the numbers, but it was faster even for small L (L=2).

Could you try NEQUIPLayerFlax vs NEQUIPESCNLayerFlax (both from nequip_jax)?

https://github.com/mariogeiger/nequip-jax
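
If it helps, a layer-agnostic timing helper like the one below could be dropped into test.py (the construction of the two layers is omitted here; see test.py in the repo for the actual arguments):

import time
import jax

def bench(apply_fn, weights, args, n_iters=100):
    out = apply_fn(weights, *args)  # warm-up call to trigger jit compilation
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), out)
    t0 = time.time()
    for _ in range(n_iters):
        out = apply_fn(weights, *args)
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), out)
    return (time.time() - t0) / n_iters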


Kalyan0821 commented on August 15, 2024

Hi, I tested this as well (I modified your test.py file).
Same thing: the eSCN version was a bit slower.


mariogeiger commented on August 15, 2024

I have no clue. Maybe try a much larger n_edges, like 2**18?
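
Something like this, reusing run_manual / run_optimized from your script (leading_shape is a module-level global there; note that each new shape triggers a fresh jit compilation, so the first timed iteration at each size includes compile time unless you add a warm-up call):

for n in [600, 2**12, 2**15, 2**18]:
    leading_shape = (1, n)  # batch of 1 to keep memory manageable at large n_edges
    print(leading_shape)
    run_manual()
    run_optimized()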


Kalyan0821 commented on August 15, 2024

I think I am done testing this. Thanks for the responses anyway.

