Comments (2)
Hi, I saw your discussion in the other thread. It is possible to train a neural ODE function that depends on extra parameters in addition to the training times and ys (not sure if this is what you want?). Here is an example that trains a neural ODE function to behave like a sine wave with controllable amplitude and period, as opposed to the astroNN doc example, where the amplitude and period are fixed. To generate the sine wave you want, you provide the initial condition, the time array, and aux, which here holds the amplitude and period. It is also possible to use the Keras API instead of your own training loop (the Keras API should be much faster and takes care of most things).
So you can modify the following example (or I can do it later today once I finish my work):
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import time
from tqdm import tqdm
import pylab as plt
import numpy as np
import tensorflow as tf
from types import MethodType
from astroNN.nn.losses import mean_squared_error, mean_absolute_error
from astroNN.shared.nn_tools import cpu_fallback, gpu_memory_manage
from astroNN.neuralode import odeint
from tensorflow.python.keras.engine import data_adapter
from astroNN.neuralode.dop853 import dop853
from astroNN.neuralode.runge_kutta import rk4
cpu_fallback()
gpu_memory_manage()
# tf.summary.trace_on(graph=True)
ts_size = 10
size = 10000
t = np.linspace(0, 25, size)
dt = t[1] - t[0]
t = t + np.random.uniform(-dt/2, dt/2, size)
# initial condition
true_y0 = [0., 2.]
true_y_np = np.stack([np.sin(2*t), 2*np.cos(2*t)]).T
true_y = tf.constant(true_y_np, dtype=tf.float32)
def train_step(self, data):
    # Unpack the data. Its structure depends on your model and
    # on what you pass to `fit()`.
    t, y = data
    with tf.GradientTape() as tape:
        y_pred = odeint(lambda x, t, aux: tf.squeeze(self([tf.expand_dims(x, axis=0), tf.expand_dims(aux, axis=0), tf.expand_dims(t, axis=0)])), y[:, 0], t['input'],
                        aux=t['aux'], method='dop853')  # Forward pass
        # Compute the loss value
        # (the loss function is configured in `compile()`)
        loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
    # Compute gradients
    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)
    # Update weights
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))
    # Update metrics (includes the metric that tracks the loss)
    self.compiled_metrics.update_state(y, y_pred)
    # Return a dict mapping metric names to current value
    return {m.name: m.result() for m in self.metrics}
# Note: shape must be a tuple, so (2,) rather than (2)
inputs = tf.keras.Input(shape=(2,), name='input')
aux = tf.keras.Input(shape=(2,), name='aux')
ts = tf.keras.Input(shape=(1,))
dense1 = tf.keras.layers.Dense(32, activation='relu')(tf.keras.layers.concatenate([inputs, aux]))
dense2 = tf.keras.layers.Dense(32, activation='relu')(dense1)
dense3 = tf.keras.layers.Dense(2)(dense2)
model = tf.keras.Model(inputs=[inputs, aux, ts], outputs=[dense3])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss=mean_absolute_error)
# Replace the default Keras training step with the ODE-integrating one above
model.train_step = MethodType(train_step, model)
s = np.random.choice(np.arange(size - ts_size, dtype=np.int64), 2000, replace=False)
batch_parameter = tf.abs(tf.random.uniform([2000, 2], 0.3, 2))
batch_t = tf.cast(tf.stack([t[_s:_s+ts_size] for _s in s]), dtype=tf.float32)
batch_y = tf.stack([batch_parameter[:, 1:]*tf.sin(batch_parameter[:, 0:1]*batch_t),
                    batch_parameter[:, 1:]*batch_parameter[:, 0:1]*tf.cos(batch_parameter[:, 0:1]*batch_t)], axis=2)
rlreduce = tf.keras.callbacks.ReduceLROnPlateau(monitor='loss', factor=0.5, min_delta=0.001,
                                                patience=4, min_lr=1e-6, mode='min',
                                                verbose=2)
model.fit({'input': batch_t, 'aux': batch_parameter}, batch_y, batch_size=64, epochs=15, callbacks=[rlreduce])
from itertools import permutations
perms = list(permutations([0.5, 1., 1.5], 2))
batch_parameter = tf.constant(perms)
batch_parameter_expanded = tf.expand_dims(batch_parameter, axis=1)
t_test = tf.cast(t[:5000], dtype=tf.float32)
t_test = tf.stack([t_test]*6)
batch_y = tf.stack([batch_parameter[:, 1:]*tf.sin(batch_parameter[:, 0:1]*t_test),
                    batch_parameter[:, 1:]*batch_parameter[:, 0:1]*tf.cos(batch_parameter[:, 0:1]*t_test)], axis=2)  # (batch, T, D)
batch_y0 = batch_y[:, 0, :]
min_idx, max_idx = tf.argmin(batch_parameter), tf.argmax(batch_parameter)
y_pred = odeint(lambda x, t, aux: tf.squeeze(model([tf.expand_dims(x, axis=0), tf.expand_dims(aux, axis=0), tf.expand_dims(t, axis=0)])), batch_y0, t_test,
                aux=batch_parameter, method='dop853')
# One figure per row of batch_parameter (the six test parameter pairs)
for i in range(6):
    plt.figure()
    plt.plot(t[:5000], y_pred[i, :, 0], label='Predict')
    plt.plot(t[:5000], batch_y[i, :, 0], ls='--', label='Ground Truth')
    plt.legend(loc='best')
    plt.show()
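For reference, the ground truth used above is y = A sin(ωt) together with its derivative Aω cos(ωt), which is the solution of y'' = -ω² y. A minimal sketch of that parameter-dependent ODE, independent of astroNN and TensorFlow, is given here. The names `sine_rhs` and `rk4_solve` are hypothetical helpers written only for illustration, and the plain RK4 integrator is a stand-in, not the dop853 solver used in the example. It only shows how an `aux` argument can be threaded through the right-hand side.

```python
import math


def sine_rhs(y, t, aux):
    """Right-hand side of y'' = -omega**2 * y as a first-order system."""
    omega, _amplitude = aux  # amplitude enters only through the initial condition
    position, velocity = y
    return [velocity, -omega ** 2 * position]


def rk4_solve(rhs, y0, ts, aux):
    """Integrate dy/dt = rhs(y, t, aux) over the time grid ts with classic RK4."""
    ys = [list(y0)]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        h = t1 - t0
        y = ys[-1]
        k1 = rhs(y, t0, aux)
        k2 = rhs([a + 0.5 * h * b for a, b in zip(y, k1)], t0 + 0.5 * h, aux)
        k3 = rhs([a + 0.5 * h * b for a, b in zip(y, k2)], t0 + 0.5 * h, aux)
        k4 = rhs([a + h * b for a, b in zip(y, k3)], t1, aux)
        ys.append([a + h / 6.0 * (p + 2 * q + 2 * r + s)
                   for a, p, q, r, s in zip(y, k1, k2, k3, k4)])
    return ys


omega, amplitude = 1.5, 0.5                 # illustrative values
ts = [i * 0.01 for i in range(501)]         # time grid from 0 to 5
y0 = [0.0, amplitude * omega]               # [A*sin(0), A*omega*cos(0)]
ys = rk4_solve(sine_rhs, y0, ts, aux=(omega, amplitude))

# Compare the integrated first component with the closed-form sine wave
max_err = max(abs(y[0] - amplitude * math.sin(omega * t)) for y, t in zip(ys, ts))
```

In the Keras example the dense network plays the role of `sine_rhs`, so it has to learn this dependence on the parameters from the sampled trajectories.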
from astronn.
I just had time to look at the discussion in the other thread, and the above is not what you are looking for. I have coded up another example for your application, based on the code you provided. k1, k2, k3 are initialized to 0.5, 0.5, 0.5 in the neural ODE model, and the true training values of k1, k2, k3 are 1, 2, 3 respectively. Training is very quick using dop853, so I am not sure if something is wrong with your TensorFlow installation. Without much hyperparameter tuning, the model can recover k1, k2, k3 close to 1, 2, 3 respectively.
If you have further questions I am happy to answer, as I don't want to spend too much time adding comments to the code.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import time
from tqdm import tqdm
import pylab as plt
import numpy as np
import tensorflow as tf
from types import MethodType
from astroNN.nn.losses import mean_squared_error, mean_absolute_error, zeros_loss
from astroNN.shared.nn_tools import cpu_fallback, gpu_memory_manage
from astroNN.neuralode import odeint
from tensorflow.python.keras.engine import data_adapter
from astroNN.neuralode.dop853 import dop853
from astroNN.neuralode.runge_kutta import rk4
cpu_fallback()
gpu_memory_manage()
# tf.keras.backend.set_floatx('float64')
class Kinetics(tf.keras.Model):
    def __init__(self, k1, k2, k3):
        super().__init__()
        self.k1, self.k2, self.k3 = k1, k2, k3

    def call(self, y, *args):
        s0, s1, s2 = y[0], y[1], y[2]
        d_0 = - self.k1 * s0 + self.k2 * s1
        d_1 = - self.k2 * s1 - self.k3 * s1 + self.k1 * s0
        d_2 = self.k3 * s1
        return tf.stack([d_0, d_1, d_2])


class Kinetics_model(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.k1 = tf.Variable(.5, trainable=True, dtype=tf.float32)
        self.k2 = tf.Variable(.5, trainable=True, dtype=tf.float32)
        self.k3 = tf.Variable(.5, trainable=True, dtype=tf.float32)

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        t, y = data
        with tf.GradientTape() as tape:
            y_pred = odeint(lambda x, t: tf.squeeze(self(tf.expand_dims(x, axis=0), tf.expand_dims(t, axis=0))), y[:, 0], t,
                            method='dop853', precision=tf.float32)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
            # tf.print(y, y_pred)
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

    def call(self, y, *args):
        s0, s1, s2 = y[:, 0], y[:, 1], y[:, 2]
        d_0 = - self.k1 * s0 + self.k2 * s1
        d_1 = - self.k2 * s1 - self.k3 * s1 + self.k1 * s0
        d_2 = self.k3 * s1
        return tf.stack([d_0, d_1, d_2])
NUM_SAMPLES = 2000
t = tf.cast(tf.linspace(0., 10., num=NUM_SAMPLES), dtype=tf.float32)
y_init = tf.constant([1., 0., 0.], dtype=tf.float32)
# Compute the reference trajectory
ref_func = Kinetics(1., 2., 3.)
ref_traj = odeint(ref_func, y_init, t)
model = Kinetics_model()
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1), loss=mean_squared_error)
model.fit(tf.stack([t[:5]]*400), tf.stack(tf.split(ref_traj, num_or_size_splits=400)), batch_size=64, epochs=5, shuffle=True)
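The `fit` call above reuses `t[:5]` for every one of the 400 windows. That appears to rely on two properties: the time grid is uniform, and the kinetics right-hand side does not depend on t, so only the time differences within a window matter. Here is a minimal sketch in plain Python of the same system and windowing, assuming a simple RK4 integrator in place of astroNN's `odeint`. The names `kinetics_rhs` and `rk4_step` are hypothetical helpers written only for illustration. Because the three derivatives sum to zero, s0 + s1 + s2 should stay at its initial value, which gives a quick sanity check on the reference trajectory.

```python
def kinetics_rhs(s, k1, k2, k3):
    """Derivatives of the three species, matching Kinetics.call above."""
    s0, s1, s2 = s
    return [-k1 * s0 + k2 * s1,
            k1 * s0 - (k2 + k3) * s1,
            k3 * s1]


def rk4_step(s, h, k1, k2, k3):
    """Advance the state by one classic RK4 step of size h."""
    a = kinetics_rhs(s, k1, k2, k3)
    b = kinetics_rhs([x + 0.5 * h * d for x, d in zip(s, a)], k1, k2, k3)
    c = kinetics_rhs([x + 0.5 * h * d for x, d in zip(s, b)], k1, k2, k3)
    d = kinetics_rhs([x + h * e for x, e in zip(s, c)], k1, k2, k3)
    return [x + h / 6.0 * (p + 2 * q + 2 * r + w)
            for x, p, q, r, w in zip(s, a, b, c, d)]


NUM_SAMPLES = 2000
h = 10.0 / (NUM_SAMPLES - 1)          # uniform grid on [0, 10]
traj = [[1.0, 0.0, 0.0]]              # same initial condition as y_init
for _ in range(NUM_SAMPLES - 1):
    traj.append(rk4_step(traj[-1], h, 1.0, 2.0, 3.0))

# Split into 400 consecutive windows of 5 steps, like tf.split(ref_traj, 400).
# The first row of each window acts as that window's initial condition.
windows = [traj[i:i + 5] for i in range(0, NUM_SAMPLES, 5)]

# Total concentration should be conserved along the whole trajectory
max_mass_err = max(abs(sum(s) - 1.0) for s in traj)
```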