
Comments (2)

henrysky commented on May 27, 2024

Hi, I saw your discussion in the other thread. It is possible to train a neural ODE function that depends on some parameters besides the training times and ys (not sure if this is what you want?). Here is an example of training a neural ODE function to behave like a sine wave with control over its amplitude and period, as opposed to the astroNN doc example where the amplitude and period are fixed (so to generate the sine wave you want, you provide an initial condition, a time array, and aux representing the amplitude and period). It is also possible to use the Keras API instead of writing your own training loop (the Keras API should be much faster and takes care of most of the bookkeeping).

So you can modify the following example (or I can do it later today once I finish my work):
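
As a quick aside, the target dynamics here are just simple harmonic motion: y(t) = A sin(w t) solves dy/dt = v, dv/dt = -w^2 y with initial state (0, A w), and (w, A) is what the aux input will carry. A minimal SciPy cross-check of that fact (independent of astroNN):

import numpy as np
from scipy.integrate import solve_ivp

# harmonic oscillator written as a first-order system in (y, v)
w, A = 2., 1.  # matches the fixed sine wave used for t/true_y below
sol = solve_ivp(lambda t, u: [u[1], -w**2 * u[0]], t_span=(0., 25.),
                y0=[0., A * w], dense_output=True, rtol=1e-8, atol=1e-8)
tt = np.linspace(0., 25., 100)
assert np.allclose(sol.sol(tt)[0], A * np.sin(w * tt), atol=1e-4)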

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import time
from tqdm import tqdm
import pylab as plt
import numpy as np
import tensorflow as tf
from types import MethodType

from astroNN.nn.losses import mean_squared_error, mean_absolute_error
from astroNN.shared.nn_tools import cpu_fallback, gpu_memory_manage
from astroNN.neuralode import odeint

from tensorflow.python.keras.engine import data_adapter

from astroNN.neuralode.dop853 import dop853
from astroNN.neuralode.runge_kutta import rk4

cpu_fallback()
gpu_memory_manage()

# tf.summary.trace_on(graph=True)

ts_size = 10  # length of each training sub-trajectory
size = 10000  # total number of time samples

# jittered time grid on [0, 25]
t = np.linspace(0, 25, size)
dt = t[1] - t[0]
t = t + np.random.uniform(-dt/2, dt/2, size)
# initial condition and reference trajectory: state is (y, dy/dt) for y = sin(2t)
true_y0 = [0., 2.]
true_y_np = np.stack([np.sin(2*t), 2*np.cos(2*t)]).T
true_y = tf.constant(true_y_np, dtype=tf.float32)

def train_step(self, data):
    # Unpack the data. Its structure depends on your model and
    # on what you pass to `fit()`.
    t, y = data
        
    with tf.GradientTape() as tape:
        y_pred = odeint(lambda x, t, aux: tf.squeeze(self([tf.expand_dims(x, axis=0), tf.expand_dims(aux, axis=0), tf.expand_dims(t, axis=0)])), y[:, 0], t['input'], 
                        aux=t['aux'], method='dop853')  # Forward pass
        # Compute the loss value
        # (the loss function is configured in `compile()`)
        loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

    # Compute gradients
    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)
    # Update weights
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))
    # Update metrics (includes the metric that tracks the loss)
    self.compiled_metrics.update_state(y, y_pred)
    # Return a dict mapping metric names to current value
    return {m.name: m.result() for m in self.metrics}
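
# Note: this train_step is bound onto the model with MethodType after compile()
# below, so model.fit() drives the ODE integration while Keras takes care of
# batching, callbacks, and metric bookkeeping.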

inputs = tf.keras.Input(shape=(2,), name='input')
aux = tf.keras.Input(shape=(2,), name='aux')
ts = tf.keras.Input(shape=(1,))
dense1 = tf.keras.layers.Dense(32, activation='relu')(tf.keras.layers.concatenate([inputs, aux]))
dense2 = tf.keras.layers.Dense(32, activation='relu')(dense1)
dense3 = tf.keras.layers.Dense(2)(dense2)
model = tf.keras.Model(inputs=[inputs, aux, ts], outputs=[dense3])

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss=mean_absolute_error)
model.train_step = MethodType(train_step, model)

# randomly sample 2000 sub-trajectories of length ts_size, each with its own (w, A)
s = np.random.choice(np.arange(size - ts_size, dtype=np.int64), 2000, replace=False)
batch_parameter = tf.abs(tf.random.uniform([2000, 2], 0.3, 2))
batch_t = tf.cast(tf.stack([t[_s:_s+ts_size] for _s in s]), dtype=tf.float32)
# analytic ground truth: y = A*sin(w*t), dy/dt = A*w*cos(w*t), with (w, A) = batch_parameter
batch_y = tf.stack([batch_parameter[:, 1:]*tf.sin(batch_parameter[:, 0:1]*batch_t), 
                    batch_parameter[:, 1:]*batch_parameter[:, 0:1]*tf.cos(batch_parameter[:, 0:1]*batch_t)], axis=2)


rlreduce = tf.keras.callbacks.ReduceLROnPlateau(monitor='loss', factor=0.5, min_delta=0.001,
                                                patience=4, min_lr=1e-6, mode='min',
                                                verbose=2)

model.fit({'input': batch_t, 'aux': batch_parameter}, batch_y, batch_size=64, epochs=15, callbacks=[rlreduce])


from itertools import permutations

perms = list(permutations([0.5, 1., 1.5], 2))

batch_parameter = tf.constant(perms)
batch_parameter_expanded = tf.expand_dims(batch_parameter, axis=1)
t_test = tf.cast(t[:5000], dtype=tf.float32)
t_test = tf.stack([t_test]*6)
batch_y = tf.stack([batch_parameter[:, 1:]*tf.sin(batch_parameter[:, 0:1]*t_test), 
                        batch_parameter[:, 1:]*batch_parameter[:, 0:1]*tf.cos(batch_parameter[:, 0:1]*t_test)], axis=2)  # (batch, T, D)
batch_y0 = batch_y[:, 0, :]

min_idx, max_idx = tf.argmin(batch_parameter), tf.argmax(batch_parameter)

y_pred = odeint(lambda x, t, aux: tf.squeeze(model([tf.expand_dims(x, axis=0), tf.expand_dims(aux, axis=0), tf.expand_dims(t, axis=0)])), batch_y0, t_test, 
                aux=batch_parameter, method='dop853')
# compare prediction against the analytic ground truth for each (w, A) test pair
for i in range(len(perms)):
    plt.figure()
    plt.plot(t[:5000], y_pred[i, :, 0], label='Predict')
    plt.plot(t[:5000], batch_y[i, :, 0], ls='--', label='Ground Truth')
    plt.legend(loc='best')
    plt.show()

[figures: predicted vs. ground-truth trajectories for the six (w, A) test pairs]
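
If you want to roll out a single unseen (w, A) pair after training, the same calling convention works; a hypothetical follow-up (assuming, as above, that odeint broadcasts over the leading batch axis):

new_param = tf.constant([[0.8, 1.2]], dtype=tf.float32)  # (w, A), inside the training range
t_new = tf.expand_dims(tf.cast(t[:2000], dtype=tf.float32), axis=0)
# analytic initial state (y, dy/dt) at the first time point for y = A*sin(w*t)
y0_new = tf.stack([new_param[:, 1] * tf.sin(new_param[:, 0] * t_new[:, 0]),
                   new_param[:, 1] * new_param[:, 0] * tf.cos(new_param[:, 0] * t_new[:, 0])], axis=1)
y_new = odeint(lambda x, t, aux: tf.squeeze(model([tf.expand_dims(x, axis=0),
                                                   tf.expand_dims(aux, axis=0),
                                                   tf.expand_dims(t, axis=0)])),
               y0_new, t_new, aux=new_param, method='dop853')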


henrysky commented on May 27, 2024

I just had time to look at the discussion in the other thread, and the example above is not what you are looking for. I have coded up another example for your application from the code you provided. k1, k2, k3 are initialized as 0.5, 0.5, 0.5 in the neural ODE model, and the ground truth used to generate the training data is k1, k2, k3 = 1, 2, 3 respectively. Training is very quick with dop853, so I am not sure if something is wrong with your TensorFlow installation. Even without much hyperparameter tuning, the model recovers k1, k2, k3 close to 1, 2, 3.

If you have further questions I am happy to answer them, since I don't want to spend too much time adding comments to the code.
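
For reference, the system being fitted is the reaction network S0 <-> S1 (forward rate k1, backward rate k2) with S1 -> S2 (rate k3). A minimal SciPy cross-check of the same right-hand side (independent of astroNN), which also shows that total concentration is conserved:

import numpy as np
from scipy.integrate import solve_ivp

def kinetics_rhs(t, s, k1, k2, k3):
    s0, s1, s2 = s
    return [-k1 * s0 + k2 * s1,
            k1 * s0 - (k2 + k3) * s1,
            k3 * s1]

sol = solve_ivp(kinetics_rhs, (0., 10.), [1., 0., 0.], args=(1., 2., 3.),
                rtol=1e-8, atol=1e-8)
# mass conservation: s0 + s1 + s2 stays at 1 along the trajectory
assert np.allclose(sol.y.sum(axis=0), 1.)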

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import time
from tqdm import tqdm
import pylab as plt
import numpy as np
import tensorflow as tf
from types import MethodType

from astroNN.nn.losses import mean_squared_error, mean_absolute_error, zeros_loss
from astroNN.shared.nn_tools import cpu_fallback, gpu_memory_manage
from astroNN.neuralode import odeint

from tensorflow.python.keras.engine import data_adapter

from astroNN.neuralode.dop853 import dop853
from astroNN.neuralode.runge_kutta import rk4

cpu_fallback()
gpu_memory_manage()

# tf.keras.backend.set_floatx('float64')

class Kinetics(tf.keras.Model):
    """Ground-truth right-hand side: S0 <-> S1 (rates k1, k2) and S1 -> S2 (rate k3)."""

    def __init__(self, k1, k2, k3):
        super().__init__()
        self.k1, self.k2, self.k3 = k1, k2, k3

    def call(self, y, *args):
        s0, s1, s2 = y[0], y[1], y[2]

        d_0 = - self.k1 * s0 + self.k2 * s1
        d_1 = - self.k2 * s1 - self.k3 * s1 + self.k1 * s0
        d_2 = self.k3 * s1

        return tf.stack([d_0, d_1, d_2])

class Kinetics_model(tf.keras.Model):
    """Same kinetics with trainable rates; k1, k2, k3 all start at 0.5."""

    def __init__(self):
        super().__init__()
        self.k1 = tf.Variable(.5, trainable=True, dtype=tf.float32)
        self.k2 = tf.Variable(.5, trainable=True, dtype=tf.float32)
        self.k3 = tf.Variable(.5, trainable=True, dtype=tf.float32)

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        t, y = data

        with tf.GradientTape() as tape:
            y_pred = odeint(lambda x, t: tf.squeeze(self(tf.expand_dims(x, axis=0), tf.expand_dims(t, axis=0))), y[:, 0], t, 
                            method='dop853', precision=tf.float32)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
#         tf.print(y, y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}


    def call(self, y, *args):
        s0, s1, s2 = y[:, 0], y[:, 1], y[:, 2]

        d_0 = - self.k1 * s0 + self.k2 * s1
        d_1 = - self.k2 * s1 - self.k3 * s1 + self.k1 * s0
        d_2 = self.k3 * s1

        return tf.stack([d_0, d_1, d_2])
    
NUM_SAMPLES = 2000
t = tf.cast(tf.linspace(0., 10., num=NUM_SAMPLES), dtype=tf.float32)
y_init = tf.constant([1., 0., 0.], dtype=tf.float32)

# Compute the reference trajectory
ref_func = Kinetics(1., 2., 3.)
ref_traj = odeint(ref_func, y_init, t)

model = Kinetics_model()
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1.), loss=mean_squared_error)

# split the reference trajectory into 400 sub-trajectories of 5 steps each; because the
# system is autonomous and the grid is uniform, t[:5] works as the time array for all of them
model.fit(tf.stack([t[:5]]*400), tf.stack(tf.split(ref_traj, num_or_size_splits=400)), batch_size=64, epochs=5, shuffle=True)
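
After fit() returns, the recovered rate constants can be read directly off the trainable variables; the exact values vary from run to run, but they should land close to the ground truth (1, 2, 3):

print(model.k1.numpy(), model.k2.numpy(), model.k3.numpy())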


