# -*- coding: utf-8 -*-
"""neural_odes.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/199_y6r0UClGpDMmDITLLlbsdxmc9aZsB

https://implicit-layers-tutorial.org/neural_odes/

Neural Ordinary Differential Equations

If we want to build a continuous-time or continuous-depth model, differential equation solvers are a useful tool. But how exactly can we treat odeint as a layer for building deep models? The previous chapter showed how to compute its gradients, so the only thing missing is to give it some parameters. This chapter will show how and why to do so.
In this chapter we won’t be using any deep learning frameworks. Instead, we’ll build everything from scratch using differentiable Numpy commands available through JAX.

Preliminaries: Training a residual network
As a warm-up, we can define a simple deep neural network in only a few lines:
"""

import jax.numpy as jnp

def mlp(params, inputs):
  # A multi-layer perceptron, i.e. a fully-connected neural network.
  for w, b in params:
    outputs = jnp.dot(inputs, w) + b  # Linear transform
    inputs = jnp.tanh(outputs)        # Nonlinearity, fed to the next layer
  return outputs

"""
mlp is simply a composition of linear and nonlinear layers. Its parameters params are a list of weight matrices and bias vectors.
To make larger models, we can always chain together or compose layers. As a standard example, chaining together several smaller networks such as mlp, adding each block’s input to its output, gives a residual network:"""

def resnet(params, inputs, depth):
  # Chain `depth` residual blocks, each adding its block's input to its output.
  for i in range(depth):
    inputs = mlp(params, inputs) + inputs
  return inputs

"""To fit this model to data, we also need a loss, an initializer, and an optimizer:"""

import numpy.random as npr
from jax import jit, grad

resnet_depth = 3
def resnet_squared_loss(params, inputs, targets):
  preds = resnet(params, inputs, resnet_depth)
  return jnp.mean(jnp.sum((preds - targets)**2, axis=1))

def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
  return [(scale * rng.randn(m, n), scale * rng.randn(n))
          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

# A simple gradient-descent optimizer.
@jit
def resnet_update(params, inputs, targets):
  grads = grad(resnet_squared_loss)(params, inputs, targets)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

"""As a sanity check, let’s fit our resnet to a toy 1D dataset (green circles) and plot the predictions of the trained model (blue curve):"""

# Toy 1D dataset.
inputs = jnp.reshape(jnp.linspace(-2.0, 2.0, 10), (10, 1))
targets = inputs**3 + 0.1 * inputs

# Hyperparameters.
layer_sizes = [1, 20, 1]
param_scale = 1.0
step_size = 0.01
train_iters = 1000

# Initialize and train.
resnet_params = init_random_params(param_scale, layer_sizes)
for i in range(train_iters):
  resnet_params = resnet_update(resnet_params, inputs, targets)

# Plot results.
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(6, 4), dpi=150)
ax = fig.gca()
ax.scatter(inputs, targets, lw=0.5, color='green')
fine_inputs = jnp.reshape(jnp.linspace(-3.0, 3.0, 100), (100, 1))
ax.plot(fine_inputs, resnet(resnet_params, fine_inputs, resnet_depth), lw=0.5, color='blue')
ax.set_xlabel('input')
ax.set_ylabel('output')

"""Building a neural ODE
Similar to a residual network, a neural ODE (or ODE-Net) takes a simple layer as a building block and chains many copies of it together to build a bigger model. In particular, our “base layer” is going to specify the dynamics of an ODE, and we’re going to chain the outputs of these base layers together according to the logic of an ODE solver.

We can build such a dynamics function, which takes the current state, the time, and the parameters, by concatenating the state with the current time and sending the result as the input to mlp:
"""

def nn_dynamics(state, time, params):
  state_and_time = jnp.hstack([state, jnp.array(time)])
  return mlp(params, state_and_time)

"""The remaining part of our model that we need to specify is how to combine evaluations of this dynamics layer. We could use any solver. JAX’s odeint function implements the standard adapative-step Dormand-Price solver."""

from jax.experimental.ode import odeint

def odenet(params, input):
  start_and_end_times = jnp.array([0.0, 1.0])
  init_state, final_state = odeint(nn_dynamics, input, start_and_end_times, params)
  return final_state

"""Without loss of generality, we can make the integration time go from 0 to 1.
That’s it! We’ve defined an ODE net. Below, we’ll talk a bit more about what’s happening inside of odeint, but for now, let’s hook it up to an optimizer and see if we can fit it to data!

Batching an ODE Net
To support batching (evaluating the ODE-Net on more than one training example) we can simply use JAX’s vmap function, which automatically adds batching dimensions. This transformation is non-trivial, since odeint contains while loops and control flow, but JAX can handle it automatically. The vmapped odeint effectively runs an independent adaptive solver for each batch element, waiting for the last one to finish before returning all the final states together. It still combines the calls to the dynamics function into one efficient vectorized call shared across all batch elements.
In environments that don’t have vmap, what’s typically done instead is to create one giant ODE that combines the dynamics of every example in the batch, solve it all in one call to odeint, and then split up the results across the batch (sketched below).
"""

from jax import vmap
batched_odenet = vmap(odenet, in_axes=(None, 0))
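
"""As an aside, the following is a minimal sketch (not from the original tutorial) of that manual alternative: stack the whole batch into one large ODE state and solve it with a single call to odeint. The helper names manually_batched_dynamics and manually_batched_odenet are illustrative; note that, unlike the vmapped version, every example here shares one adaptive step-size schedule."""

def manually_batched_dynamics(state, time, params):
  # `state` holds the whole batch at once, shape (batch_size, 1).
  # Broadcast the scalar time so each example sees it as an extra input.
  t = jnp.ones((state.shape[0], 1)) * time
  return mlp(params, jnp.concatenate([state, t], axis=1))

def manually_batched_odenet(params, inputs):
  # One odeint call integrates every example in the batch simultaneously.
  start_and_end_times = jnp.array([0.0, 1.0])
  _, final_states = odeint(manually_batched_dynamics, inputs,
                           start_and_end_times, params)
  return final_states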

"""What remains is simply to initialize the parameters, hook up the model to the loss function, and train the ODE-Net:"""

# We need to change the input dimension to 2, to allow time-dependent dynamics.
odenet_layer_sizes = [2, 20, 1]

def odenet_loss(params, inputs, targets):
  preds = batched_odenet(params, inputs)
  return jnp.mean(jnp.sum((preds - targets)**2, axis=1))

@jit
def odenet_update(params, inputs, targets):
  grads = grad(odenet_loss)(params, inputs, targets)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

# Initialize and train ODE-Net.
odenet_params = init_random_params(param_scale, odenet_layer_sizes)

for i in range(train_iters):
  odenet_params = odenet_update(odenet_params, inputs, targets)

# Plot resulting model.
fig = plt.figure(figsize=(6, 4), dpi=150)
ax = fig.gca()
ax.scatter(inputs, targets, lw=0.5, color='green')
fine_inputs = jnp.reshape(jnp.linspace(-3.0, 3.0, 100), (100, 1))
ax.plot(fine_inputs, resnet(resnet_params, fine_inputs, resnet_depth), lw=0.5, color='blue')
ax.plot(fine_inputs, batched_odenet(odenet_params, fine_inputs), lw=0.5, color='red')
ax.set_xlabel('input')
ax.set_ylabel('output')
plt.legend(('Resnet predictions', 'ODE Net predictions'))

"""The two regression methods both match the data, but extrapolate slightly differently.

Activation trajectories
In a deep residual network, we can examine the activations between each block. In an ODE-Net, we can instead examine the activation trajectories as a function of depth:
"""

fig = plt.figure(figsize=(6, 4), dpi=150)
ax = fig.gca()

@jit
def odenet_times(params, input, times):
  def dynamics_func(state, time, params):
    return mlp(params, jnp.hstack([state, jnp.array(time)]))
  return odeint(dynamics_func, input, times, params)

times = jnp.linspace(0.0, 1.0, 200)

for i in fine_inputs:
  ax.plot(odenet_times(odenet_params, i, times), times, lw=0.5)

ax.set_xlabel('input / output')
ax.set_ylabel('time / depth')

"""In this toy setting where there is only one hidden unit, the trajectories can never cross each other, limiting the classes of functions that can be learned. However, this limitation can be overcome (if desired) by adding auxiliary dimensions to the network’s input that are discarded at the network output.

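
"""As a brief aside, the following is a minimal sketch (not part of the original tutorial) of that augmentation trick: pad each input with extra zero-valued dimensions, integrate in the higher-dimensional space, and discard the auxiliary dimensions at the output. The name augmented_odenet, the single extra dimension, and the layer sizes in the comment are illustrative assumptions."""

def augmented_odenet(params, input, n_aug=1):
  # Append `n_aug` auxiliary zero dimensions to the input state.
  aug_input = jnp.concatenate([input, jnp.zeros(n_aug)])
  start_and_end_times = jnp.array([0.0, 1.0])
  _, final_state = odeint(nn_dynamics, aug_input, start_and_end_times, params)
  # Keep only the original dimension; in the augmented space the 1D
  # trajectories are free to pass around one another.
  return final_state[:1]

# The dynamics mlp must match the augmented sizes, e.g. for this 1D problem:
# aug_params = init_random_params(param_scale, [2 + n_aug, 20, 1 + n_aug])

"""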
What form can the dynamics take?
There are a few restrictions, which we will discuss later, needed to make the ODE solution well-defined and unique. But in general, the dynamics can be almost any tractable, differentiable, parametric function. In other words, odeint is a layer that takes in another layer to specify its dynamics function. This layer can be a fully-connected net, a convnet, a U-net, or even some kinds of transformer!

Where can we use odeint layers?
The short answer is: anywhere you can use a residual net, you can use an ODE-Net. Both require the output to be the same size as the input.

Software
Here are a few of the more comprehensive toolkits that let one fit neural ODEs:
TorchDiffEq - A PyTorch library purpose-built for building and fitting neural ODE models.
JAX - A general-purpose numerical computing framework for Python, which includes a differentiable Dopri5 solver.
JuliaDiffEq - A comprehensive suite of differential equation solvers for the Julia language.
TorchDyn - A suite of model templates, tutorials, and application notebooks.

Stochastic and Partial Differential Equations
Besides ordinary differential equations, there are many other variants of differential equations that can be fit by gradients, and developing new model classes based on differential equations is an active research area. The solution of almost any type of differential equation can be seen as a layer!
Here are some pointers to recent work in these areas:
ICLR Workshop on Integration of Deep Neural Models and Differential Equations
Scalable Gradients for Stochastic Differential Equations
Learning composable energy surrogates for PDE order reduction
Amortized finite element analysis for fast PDE-constrained optimization
"""