Code Monkey home page Code Monkey logo

kernex's Introduction

Differentiable Stencil computations in JAX

Installation |Description |Quick example |More Examples |Benchmarking

Tests pyver codestyle Downloads Open In Colab codecov DOI

๐Ÿ› ๏ธ Installation

pip install kernex

📖 Description

Kernex extends jax.vmap/jax.lax.map/jax.pmap with kmap and jax.lax.scan with kscan for general stencil computations.

โฉ Quick Example

kmap kscan
import kernex as kex
import jax.numpy as jnp

@kex.kmap(kernel_size=(3,))
def sum_all(x):
    return jnp.sum(x)

x = jnp.array([1,2,3,4,5])
print(sum_all(x))
# [ 6  9 12]
import kernex as kex 
import jax.numpy as jnp

@kex.kscan(kernel_size=(3,))
def sum_all(x):
    """Return the sum of the current 3-element window view ``x``."""
    # The body must be indented under the ``def``; a flush-left ``return``
    # is an IndentationError.
    return jnp.sum(x)

x = jnp.array([1,2,3,4,5])
print(sum_all(x))
# [ 6 13 22]
`jax.vmap` is used to sum the content of each window. `lax.scan` is used to update the array, and the window sum is calculated sequentially. The first three rows represent the three sequential steps used to get the solution in the last row.

🔢 More examples

1๏ธโƒฃ Convolution operation
import jax
import jax.numpy as jnp
import kernex as kex

@jax.jit
@kex.kmap(
    kernel_size= (3,3,3),
    padding = ('valid','same','same'))
def kernex_conv2d(x,w):
    # JAX channel first conv2d with 3x3x3 kernel_size
    return jnp.sum(x*w)
2๏ธโƒฃ Laplacian operation
# see also
# https://numba.pydata.org/numba-doc/latest/user/stencil.html#basic-usage
import jax
import jax.numpy as jnp
import kernex as kex

@kex.kmap(
    kernel_size=(3,3),
    padding= 'valid',
    relative=True) # `relative`= True enables relative indexing
def laplacian(x):
    return ( 0*x[1,-1]  + 1*x[1,0]   + 0*x[1,1] +
             1*x[0,-1]  +-4*x[0,0]   + 1*x[0,1] +
             0*x[-1,-1] + 1*x[-1,0]  + 0*x[-1,1] )

print(laplacian(jnp.ones([10,10])))
# [[0., 0., 0., 0., 0., 0., 0., 0.],
#  [0., 0., 0., 0., 0., 0., 0., 0.],
#  [0., 0., 0., 0., 0., 0., 0., 0.],
#  [0., 0., 0., 0., 0., 0., 0., 0.],
#  [0., 0., 0., 0., 0., 0., 0., 0.],
#  [0., 0., 0., 0., 0., 0., 0., 0.],
#  [0., 0., 0., 0., 0., 0., 0., 0.],
#  [0., 0., 0., 0., 0., 0., 0., 0.]]
3๏ธโƒฃ Get Patches of an array
import jax
import jax.numpy as jnp
import kernex as kex

@kex.kmap(kernel_size=(3,3),relative=True)
def identity(x):
    # similar to numba.stencil
    # this function returns the top left cell in the padded/unpadded kernel view
    # or center cell if `relative`=True
    return x[0,0]

# unlike numba.stencil , vector output is allowed in kernex
# this function is similar to
# `jax.lax.conv_general_dilated_patches(x,(3,),(1,),padding='same')`
@jax.jit
@kex.kmap(kernel_size=(3,3),padding='same')
def get_3x3_patches(x):
    # returns 5x5x3x3 array
    return x

mat = jnp.arange(1,26).reshape(5,5)
print(mat)
# [[ 1  2  3  4  5]
#  [ 6  7  8  9 10]
#  [11 12 13 14 15]
#  [16 17 18 19 20]
#  [21 22 23 24 25]]


# get the view at array index = (0,0)
print(get_3x3_patches(mat)[0,0])
# [[0 0 0]
#  [0 1 2]
#  [0 6 7]]
4๏ธโƒฃ Linear convection
Problem setup Stencil view
import jax
import jax.numpy as jnp
import kernex as kex
import matplotlib.pyplot as plt

# see https://nbviewer.org/github/barbagroup/CFDPython/blob/master/lessons/01_Step_1.ipynb

tmax,xmax = 0.5,2.0
nt,nx = 151,51
dt,dx = tmax/(nt-1) , xmax/(nx-1)
u = jnp.ones([nt,nx])
c = 0.5

# kscan moves sequentially in row-major order and updates in-place using lax.scan.

# Build the scan stencil through the `kex` alias: this snippet imports the
# package as `import kernex as kex`, so the bare name `kernex` is unbound.
F = kex.kscan(
    kernel_size=(3, 3),
    padding=((1, 1), (1, 1)),
    # n for time axis, i for spatial axis (optional naming)
    named_axis={0: 'n', 1: 'i'},
    relative=True,
)


# boundary condition as a function
def bc(u):
    return 1

# initial condition as a function
def ic1(u):
    return 1

def ic2(u):
    return 2

def linear_convection(u):
    return ( u['i','n-1'] - (c*dt/dx) * (u['i','n-1'] - u['i-1','n-1']) )


F[:,0]  = F[:,-1] = bc # assign 1 for left and right boundary for all t

# square wave initial condition
F[:,:int((nx-1)/4)+1] = F[:,int((nx-1)/2):] = ic1
F[0:1, int((nx-1)/4)+1 : int((nx-1)/2)] = ic2

# assign linear convection function for
# interior spatial location [1:-1]
# and start from t>0  [1:]
F[1:,1:-1] = linear_convection

kx_solution = F(jnp.array(u))

plt.figure(figsize=(20,7))
for line in kx_solution[::20]:
    plt.plot(jnp.linspace(0,xmax,nx),line)
5๏ธโƒฃ Gaussian blur
import jax
import jax.numpy as jnp
import kernex as kex

def gaussian_blur(image, sigma, kernel_size):
    """Blur ``image`` with a normalized, separable Gaussian kernel.

    Args:
        image: array to blur; the stencil below uses a two-dimensional
            ``(kernel_size, kernel_size)`` window, so a 2D array is
            presumably expected (confirm against callers).
        sigma: standard deviation of the Gaussian, in pixels.
        kernel_size: number of taps along each side of the square kernel.

    Returns:
        The result of applying the weighted-sum stencil to ``image`` with
        ``"same"`` padding.
    """
    half_width = (kernel_size - 1) / 2.0
    taps = jnp.linspace(-half_width, half_width, kernel_size)
    # A Gaussian with standard deviation sigma is exp(-x^2 / (2 * sigma^2)).
    # The previous form multiplied by rsqrt(sigma) (= sigma ** -0.5), which
    # does not scale the kernel width with sigma as a standard deviation.
    w = jnp.exp(-0.5 * jnp.square(taps / sigma))
    # Outer product turns the 1D profile into the 2D kernel; normalize so the
    # weights sum to one.
    w = jnp.outer(w, w)
    w = w / w.sum()

    @kex.kmap(kernel_size=(kernel_size, kernel_size), padding="same")
    def conv(window):
        # Weighted sum of the current window view.
        return jnp.sum(window * w)

    return conv(image)
6๏ธโƒฃ Depthwise convolution
import jax
import jax.numpy as jnp
import kernex as kex

@jax.jit
@jax.vmap
@kex.kmap(
    kernel_size= (3,3),
    padding = ('same','same'))
def kernex_depthwise_conv2d(x,w): 
    return jnp.sum(x*w)

h,w,c = 5,5,2
k=3

x = jnp.arange(1,h*w*c+1).reshape(c,h,w)
w = jnp.arange(1,k*k*c+1).reshape(c,k,k)
print(kernex_depthwise_conv2d(x,w))
7๏ธโƒฃ Average pooling 2D
import jax
import jax.numpy as jnp
import kernex as kex


@jax.vmap  # vectorize over the channel dimension
@kex.kmap(kernel_size=(3, 3), strides=(2, 2))
def avgpool_2d(x):
    """Return the mean of the current 3x3 window view ``x``.

    The stencil is declared with strides of 2 along both window axes and is
    vectorized over the leading axis by ``jax.vmap``.
    """
    # define the kernel for the Average pool operation over the spatial dimensions
    return jnp.mean(x)
8๏ธโƒฃ Runge-Kutta integration
# lets solve dydt = y, where y0 = 1 and y(t)=e^t
# using Runge-Kutta 4th order method
# f(t,y) = y
import jax  # needed for jax.jit below
import jax.numpy as jnp
import kernex as kex
import matplotlib.pyplot as plt


t = jnp.linspace(0, 1, 5)
y = jnp.zeros(5)
# row 0 holds the solution y, row 1 holds the time samples t
x = jnp.stack([y, t], axis=0)
dt = t[1] - t[0]  # 0.25 (five samples on [0, 1])


def f(tn, yn):
    """ right-hand side of the ODE: f(t, y) = y """
    return yn


def ic(x):
    """ initial condition y0 = 1 """
    return 1.


def rk4(x):
    """ runge kutta 4th order integration step

    `x` is the relative-indexed window view; the step reads the previous
    column (index -1) and returns the value for the centre cell [0, 0].
    """
    # ┌────┬────┬────┐      ┌──────┬──────┬──────┐
    # │ y0 │*y1*│ y2 │      │[0,-1]│[0, 0]│[0, 1]│
    # ├────┼────┼────┤ ==>  ├──────┼──────┼──────┤
    # │ t0 │ t1 │ t2 │      │[1,-1]│[1, 0]│[1, 1]│
    # └────┴────┴────┘      └──────┴──────┴──────┘
    t0 = x[1, -1]
    y0 = x[0, -1]
    k1 = dt * f(t0, y0)
    k2 = dt * f(t0 + dt / 2, y0 + 1 / 2 * k1)
    k3 = dt * f(t0 + dt / 2, y0 + 1 / 2 * k2)
    k4 = dt * f(t0 + dt, y0 + k3)
    yn_1 = y0 + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return yn_1


# window of 2 rows (y, t) by 3 columns
F = kex.kscan(kernel_size=(2, 3), relative=True, padding=((0, 1)))

F[0:1, 1:] = rk4
F[0, 0] = ic
# compile the solver
solver = jax.jit(F.__call__)
y = solver(x)[0, :]

plt.plot(t, y, '-o', label='rk4')
plt.plot(t, jnp.exp(t), '-o', label='analytical')
plt.legend()

img

kernex's People

Contributors

asem000 avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar

Forkers

dmarx srib bbsun

kernex's Issues

Support Pytrees

Does kernex support Pytrees? I did not find an example. It would be very useful to support moving-window filters with "global" weights or simply multiple inputs, such as a cross-channel bilateral filter in my case.

Repro:

import jax.numpy as jnp
import kernex

@kernex.kmap(kernel_size=(3, 3))
def kernel(tree):
    x, y = tree
    return jnp.sum(x * jnp.square(y))

data = jnp.arange(20 * 30).reshape((20, 30))
out = kernel((data, data))

raises

Traceback (most recent call last):
  File "/home/clemisch/kernex_tree.py", line 52, in <module>
    out = kernel((data, data))
          ^^^^^^^^^^^^^^^^^^^^
  File "/home/clemisch/venvs/11/lib64/python3.11/site-packages/kernex/interface/kernel_interface.py", line 131, in call
    self.shape = array.shape
                 ^^^^^^^^^^^
AttributeError: 'tuple' object has no attribute 'shape'

`mode="reflect"` in padding is incorrect

I think mode="reflect" for padding_kwargs is incorrect:

import jax.numpy as jnp
import kernex

@kernex.kmap(
    kernel_size=(3,),
    padding=("same"),
    relative=False,
    padding_kwargs=dict(mode="reflect"),
)
def f(x):
    return x

x = jnp.array([1, 2, 3, 4, 5])
y = f(x)
z = jnp.pad(x, 1, mode="reflect")

print("x: ", x)
print("y: ", y)
print("z: ", z)

gives

x:  [1 2 3 4 5]
y:  [[3 1 2]     # <-- the `3` is incorrect, should be `2`
     [1 2 3]
     [2 3 4]
     [3 4 5]
     [4 5 4]]
z:  [2 1 2 3 4 5 4]   # <-- here, the first element is `2`

The Kernex output reflects incorrectly: the first element is 3 instead of 2.

add argnum/argname

Kernex operates on views of the function's first array argument. It would be useful to add argnum/argname to select which argument to operate on, in a similar fashion to JAX transformations.

Implementation strategy

This project looks really cool!

I would love to understand at a high level how this package works -- how do you actually implement stencil computations in JAX? Do you reuse jax.lax.scan or something else? Does it support auto-diff? How does performance compare on CPU/GPU/TPU (or whichever configs you've tried)?

Document border handling with padding

How does Kernex handle borders and out-of-bounds indices, especially with padding="same"? I skimmed the source code but did not find it. Very superficially it seems to me that it performs 0-padding. Is this the case?

If yes, it would be useful to support different schemes like "mirror", "edge", or ideally dropping the out-of-bounds indices.

Repro:

from pylab import *
import jax.numpy as jnp
import kernex

@kernex.kmap(kernel_size=(11, 11), padding="same")
def kernel(patch):
    return jnp.mean(patch)

data = ones((100, 100))
out = kernel(data)

figure()
plot(out[50], "+-", label="Kernex")
plot(ones_like(out[50]), "+-", label="Ideal")
legend()

image

Add dilation

  • Dilation can be added by changing the view indices generation step value.

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.