NB. src/j/dual.ijs

NB. Demo script: dual numbers represented as boxed matrices, used to obtain
NB. exact derivatives (first and second order, one and two variables).
echo 'Dual numbers for automatic differentiation'

NB. Load the math/mt addon; dexp, dsin and dcos below call geexp_mt_
NB. (matrix exponential) from its mt locale.
load 'math/mt'


NB. Dual numbers are numbers of the form a + b * epsilon where epsilon^2 = 0
NB. The dual number a + b * epsilon can be represented by a matrix :
NB.  a 0
NB.  b a
NB. We will represent it by a boxed matrix :

NB. x dual1 y : the dual number x + y*epsilon, boxed as the 2x2 matrix
NB.   x 0
NB.   y x
dual1 =: 4 : 0
 toprow =. x , 0
 < toprow ,: y , x
)

NB. Operations on dual numbers unbox their arguments, apply the operation to the matrices, and box the result; each verb has rank 0, so it applies item by item to arrays of boxed dual numbers
NB. Arithmetic on boxed dual numbers.  u&.> ("u under open") opens each boxed
NB. operand, applies the matrix verb u and re-boxes the result, atom by atom
NB. (rank 0), so every verb here also works item-wise on arrays of dual numbers.
NB. [: : v makes a verb dyad-only and v : [: makes it monad-only: calling the
NB. unused valence is a domain error.
dplus  =: [: : (+&.>)              NB. x + y
dminus =: [: : (-&.>)              NB. x - y
dopp   =: -&.> : [:                NB. - y
dtimes =: [: : (+/ . *&.>)         NB. x * y   (matrix product)
dinv   =: %.&.> : [:               NB. 1 % y   (matrix inverse)
ddiv   =: [: : ((+/ . * %.)&.>)    NB. x % y   (x times the inverse of y)
dexp   =: geexp_mt_&.> : [:        NB. exp y   (geexp_mt_ : matrix exponential)
NB. sin M = (exp(iM) - exp(-iM)) % 2i  and  cos M = (exp(iM) + exp(-iM)) % 2
dsin   =: (3 : '((geexp_mt_ 0j1 * y) - geexp_mt_ - 0j1 * y) % 0j2')&.>
dcos   =: (3 : '((geexp_mt_ 0j1 * y) + geexp_mt_ - 0j1 * y) % 2')&.>

NB. Demo: product of two dual numbers; the epsilon part follows the product
NB. rule (5*1 + 3*8 = 29).
echo 'Example : (5 + 3 * epsilon) * (8 + epsilon) = 40 + 29 * epsilon'
echo (5 dual1 3) dtimes (8 dual1 1)

NB. Demo: a list of two boxed dual numbers.  The d* verbs have rank 0, so
NB. dplus/ . dtimes is an inner product over lists of dual numbers.
a =: (3 dual1 1) , (5 dual1 1)
echo 'Vector a = 3+epsilon 5+epsilon :'
echo a
echo 'Scalar product of a by itself = (3+epsilon)^2 + (5+epsilon)^2 = 34+16*epsilon :'
echo a dplus/ . dtimes a

NB. The Taylor series gives : f(x + epsilon) = f(x) + f'(x) * epsilon + 1/2! f''(x) * epsilon^2 + 1/3! f'''(x) * epsilon^3 + ...
NB. But with epsilon^2 = 0 it gives simply : f(x + epsilon) = f(x) + f'(x) * epsilon
NB. Example : 
NB. f(x) = x^2 + 3*x + 5  f(6) = 59
NB. f'(x) = 2*x + 3       f'(6) = 15
NB.                       f(6+epsilon) = 59+15*epsilon

NB. f y : y^2 + 3*y + 5 evaluated on a boxed 2x2 dual number y
f =: 3 : 0
 three =. 3 dual1 0
 five =. 5 dual1 0
 (y dtimes y) dplus (three dtimes y) dplus five
)

NB. Seeding the epsilon part with 1 makes the result's epsilon part equal f'(6).
echo 'f(6+epsilon) = 59+15*epsilon'
echo f (6 dual1 1)   

NB. Generalization with several variables :
NB. a + b*epsilon + c*zeta + ...
NB. represented by matrix :
NB.  a 0 0 ...
NB.  b a 0 ...
NB.  c 0 a ...
NB.  ...

NB. x dual y : the dual number x + (0{y)*epsilon + (1{y)*zeta + ... boxed as a
NB. square matrix of side 1+#y with x on the diagonal and 0,y down column 0
NB. (all other entries 0).  With a scalar y it gives the same matrix as x dual1 y.
dual =: 4 : 0
 sz =. 1 + # y
 < (x * = i. sz) + (0 , y) ,. (sz , <: sz) $ 0
)

NB. Demo: a two-variable dual number shown as its 3x3 matrix.
echo 'Example : 5 + 3*epsilon + zeta'
echo 5 dual 3 1

NB. Example : 
NB. g(x,y) = x*y + 3*x + 5*y  g(10,20) = 330 
NB. dg(x,y)/dx = y + 3        dg(10,20)/dx = 23
NB. dg(x,y)/dy = x + 5        dg(10,20)/dy = 15
NB.                           g(10+epsilon,20+zeta) = 330+23*epsilon+15*zeta

NB. x g y : x*y + 3*x + 5*y on boxed 3x3 dual numbers (as built by dual)
g =: 4 : 0
 cx =. (3 dual 0 0) dtimes x
 cy =. (5 dual 0 0) dtimes y
 (x dtimes y) dplus cx dplus cy
)

NB. epsilon tracks d/dx and zeta tracks d/dy: seed x with 1 0 and y with 0 1.
echo 'g(10+epsilon,20+zeta) = 330+23*epsilon+15*zeta'
echo (10 dual 1 0) g (20 dual 0 1)

NB. Two-variable dual numbers a + b*epsilon + c*zeta as boxed 3x3 matrices.
NB. value y : the constant y (y times the identity); epsilon, zeta : generators.
value =: 3 : '< y * = i. 3'
epsilon =: < 0 0 0 , 1 0 0 ,: 0 0 0
zeta =: < 0 0 0 , 0 0 0 ,: 1 0 0

NB. Same evaluation of g as the previous demo, with the arguments written as
NB. value + generator instead of via dual.
echo ((value 10) dplus epsilon) g ((value 20) dplus zeta)

NB. x m y : x*y - sin(y) on boxed dual numbers.
NB. At (2,3): m = 6 - sin 3 (~5.85888), dm/dx = 3, dm/dy = 2 - cos 3 (~2.98999).
m =: 4 : 0
 prod =. x dtimes y
 prod dminus dsin y
)

NB. Expected roughly 5.85888 + 3*epsilon + 2.98999*zeta.  dsin goes through
NB. complex exponentials, so entries may be complex-typed.
echo 'm(2+epsilon,3+zeta)'
echo (2 dual 1 0) m (3 dual 0 1)

NB. Generalization for second derivative 
NB. a + b*epsilon + c*epsilon^2 with epsilon^3 = 0
NB. represented by matrix :
NB.  a 0 0
NB.  b a 0
NB.  c b a
NB. epsilon represented by matrix :
NB.  0 0 0
NB.  1 0 0
NB.  0 1 0

NB. Second-order epsilon (epsilon^3 = 0): maps 1 -> epsilon -> epsilon^2.
eps =: < 0 0 0 , 1 0 0 ,: 0 1 0

NB. Example :
NB. h(x) = x^3 + 3*x^2 + 5*x  h(10) = 1350
NB. h'(x) = 3*x^2 + 6*x + 5   h'(10) = 365
NB. h''(x) = 6*x + 6          h''(10) = 66
NB.                           h(10+epsilon) = 1350 + 365*epsilon + 1/2 * 66 * epsilon^2 = 1350 + 365*epsilon + 33*epsilon^2

NB. h y : y^3 + 3*y^2 + 5*y on boxed dual numbers.  Uses whatever value is
NB. defined when h runs (the 3x3 one at the call below).
h =: 3 : 0
 sq =. y dtimes y
 (y dtimes sq) dplus ((value 3) dtimes sq) dplus (value 5) dtimes y
)

NB. J parses right to left, so this is h applied to ((value 10) dplus eps),
NB. i.e. h(10+epsilon).  Row 2 of column 0 holds h''(10)/2 = 33.
echo 'h(10+epsilon) = 1350 + 365*epsilon + 1/2 66 * epsilon^2'
echo h (value 10) dplus eps

NB. Generalization with second derivatives and two variables
NB. a + b*epsilon + c*zeta + d*epsilon^2 + e*epsilon*zeta + f*zeta^2
NB. represented by matrix :
NB.  a 0 0 0 0 0
NB.  b a 0 0 0 0
NB.  c 0 a 0 0 0
NB.  d b 0 a 0 0
NB.  e c b 0 a 0
NB.  f 0 c 0 0 a

NB. Second-order, two-variable numbers as boxed 6x6 matrices over the basis
NB. 1, epsilon, zeta, epsilon^2, epsilon*zeta, zeta^2 (indices 0..5).
NB. value y : the constant y (y times the 6x6 identity).
value =: 3 : '< y * = i. 6'

NB. Column j of a generator's matrix is (generator * basis j) in that basis.
NB. epsilon*1 = epsilon, epsilon*epsilon = epsilon^2, epsilon*zeta = epsilon*zeta,
NB. i.e. ones at (row,col) = (1,0) (3,1) (4,2); degree-3 products are dropped.
epsilon =: < 1 (1 0 ; 3 1 ; 4 2)} 6 6 $ 0

echo 'epsilon :'
echo epsilon

NB. zeta*1 = zeta, zeta*epsilon = epsilon*zeta, zeta*zeta = zeta^2 :
NB. ones at (2,0) (4,1) (5,2).
zeta =: < 1 (2 0 ; 4 1 ; 5 2)} 6 6 $ 0

echo 'zeta :'
echo zeta

NB. Check the multiplication table.  Each product should be a matrix whose
NB. only nonzero entry is a 1 in column 0, at row 3 (epsilon^2), row 5 (zeta^2)
NB. or row 4 (epsilon*zeta) respectively.
echo 'epsilon^2 :'
echo epsilon dtimes epsilon

echo 'zeta^2 :'
echo zeta dtimes zeta

echo 'epsilon*zeta :'
echo epsilon dtimes zeta

NB. Example : 
NB. k(x,y) = 3*x^2 + 5*y^2 + 6*x*y + 8*x + 9*y  k(10,1) = 454
NB. dk(x,y)/dx = 6*x + 6*y + 8                  dk(10,1)/dx = 74
NB. dk(x,y)/dy = 10*y + 6*x + 9                 dk(10,1)/dy = 79
NB. d^2k(x,y)/dx^2 = 6
NB. d^2k(x,y)/dy^2 = 10
NB. d^2k(x,y)/dxdy = 6
NB. k(x+epsilon,y+zeta) = k(x,y) + dk(x,y)/dx*epsilon + dk(x,y)/dy*zeta + 1/2*d^2k(x,y)/dx^2*epsilon^2 + 1/2*d^2k(x,y)/dy^2*zeta^2 + d^2k(x,y)/dxdy*epsilon*zeta
NB. k(10+epsilon,1+zeta) = 454 + 74*epsilon + 79*zeta + 3*epsilon^2 + 5*zeta^2 + 6*epsilon*zeta 

NB. x k y : 3*x^2 + 5*y^2 + 6*x*y + 8*x + 9*y on boxed 6x6 second-order numbers
k =: 4 : 0
 xx =. x dtimes x
 yy =. y dtimes y
 xy =. x dtimes y
 ((value 3) dtimes xx) dplus ((value 5) dtimes yy) dplus ((value 6) dtimes xy) dplus ((value 8) dtimes x) dplus (value 9) dtimes y
)

NB. Column 0 of the result lists the coefficients of 1, epsilon, zeta,
NB. epsilon^2, epsilon*zeta, zeta^2 : expected 454 74 79 3 6 5.
echo 'k(10+epsilon,1+zeta) = 454 + 74*epsilon + 79*zeta + 3*epsilon^2 + 5*zeta^2 + 6*epsilon*zeta'
echo ((value 10) dplus epsilon) k ((value 1) dplus zeta)