
# -*- coding: utf-8 -*-
"""ipl.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1Y6150HtsF6d_UxIo4EH_kkKhbtvHqgQK

**Intelligent Proof Logic: Automated theorem proving powered by deep learning**

by Jacques Bailhache (jacques.bailhache@gmail.com)

**The formal system: Proof Logic**

See http://log.chez.com/text/logic/tuto-pl-v8.html for a tutorial introduction to Proof Logic or http://log.chez.com/text/logic/tuto-pl-v8.html for a more complete presentation.

Proof Logic is a logic formalism inspired by combinatory logic and lambda calculus, using De Bruijn indices. The indices are written \*, '\*, ''\*, ... and λ x is written [x].

In Proof Logic, a proposition is an equality a = b, and everything is a proof. An expression is a proof of its equality with itself.

Succinctly, a proof can be:
 - a symbol (which proves its equality with itself)
 - the initial index * (which proves * = *)
 - 'x, where x is a proof (concretely an index)
 - \[x\], where x is a proof
 - the application x y, where x and y are proofs: if x proves a = b and y proves c = d, then x y proves a c = b d
 - x ; y, where x and y are proofs, which is generalized transitivity: if the equalities proved by x and y share one common term, then x ; y proves the equality of the remaining terms; for example, if x proves a = b and y proves b = c, then x ; y proves a = c
 - an axiom x = y, which proves x = y (the operator "=" may be used only in axioms)

**Teaching a neural network to find the proof of a given equality**

Random proofs are generated and their conclusions (equalities a = b) are computed. The neural network is trained with the conclusion as input and the proof as output, i.e. it is trained to find the proof when it is given the conclusion.

**Encoding of proofs for the neural network**

Inputs and outputs of the neural network are binary, so proofs must be encoded as binary sequences. They are first encoded as sequences of numbers, and these sequences of numbers are then one-hot encoded (the number n is encoded by a sequence of bits with a 1 at position n and 0s at all other positions).
A proof can be represented by a "partial" binary tree: subtrees below symbols and \* are ignored, and the right subtrees below the unary operators ' and [] are ignored. For example, the binary tree corresponding to [*] a (the identity applied to the symbol a) is (where i means ignored):

         apply
        /     \
       []      a
      /  \    / \
     *    i  i   i


Each position in this binary tree is assigned a number: the index in the sequence of numbers representing the proof at which the number encoding the node at that tree position is placed. The root of the tree is assigned the number 0. The positions of the left subtree are assigned the odd numbers (1, 3, 5, ...) and the positions of the right subtree the even numbers (2, 4, 6, ...). For each subtree, the first number (1 for the left, 2 for the right) is assigned to its root, and the following numbers are assigned alternately to the left and right subtrees of that subtree, which gives the following numbering:

          0
       /     \
      1       2
     / \     / \
    3   5   4   6

Numbers are assigned to the different nodes:
 - 0 : *
 - 1 : '
 - 2 : []
 - 3 : application
 - 4 : generalized transitivity
 - ...
 - n : first symbol
 - n+1 : second symbol
 - ...
 - p : first axiom
 - p+1 : second axiom
 - ...

 For example, for the proof [*] a, we have 3 at position 0, 2 at position 1, 6 at position 2 (if 6 represents the symbol a), and 0 at position 3; the numbers at the other positions are not significant. This gives the sequence of numbers:

 [3, 2, 6, 0, ...]
"""

# Proof Logic
# Proof =
#  - ('SMB', s)
#  - ('ANY',)
#  - ('VAR',)
#  - ('NXV', x)
#  - ('FNC', x)
#  - ('APL', x, y)
#  - ('GTR', x, y)
#  - ('EQU', x, y)
#  - ('RLR', x, y)
#  - ('RED', x, y)
#  (...)
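
# Illustrative example of this encoding: the proof [*] a from the introduction
# (the identity function applied to the symbol a) is the nested tuple
#     ('APL', ('FNC', ('VAR',)), ('SMB', 'a'))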

full_red = False
orig_reduce1 = False

def equal(x, y):
    if x == y: return True
    if not x: return False
    if not y: return False
    if len(x) != len(y): return False
    for i in range(len(x)):
        if x[i] != y[i]: return False
    return True

def depth(x):
  if (not x): return 0
  if x[0] == 'SMB' or x[0] == 'ANY' or x[0] == 'VAR' or x[0] == 'AXM': return 0
  if x[0] == 'NXV' or x[0] == 'FNC' or x[0] == 'RLR' or x[0] == 'RED': return 1 + depth(x[1])
  if x[0] == 'APL' or x[0] == 'GTR' or x[0] == 'EQU':
    d1 = depth(x[1])
    d2 = depth(x[2])
    if d1 >= d2: return 1 + d1
    return 1 + d2

def shift(u, x):
    if equal(x, u):
        return ('NXV', x)
    if (not x):
        return x
    if x[0] == 'SMB': return x
    if x[0] == 'ANY': return x
    if x[0] == 'VAR': return x
    if x[0] == 'FNC': return ('FNC', shift(('NXV', u), x[1]))
    if (len(x) < 3): return (x[0], shift(u, x[1]))
    return (x[0], shift(u, x[1]), shift(u, x[2]))

def subst(u, a, b):
    if equal(u, a): return b
    if (not a): return a
    if a[0] == 'NXV' and equal(a[1], u): return u
    if a[0] == 'SMB' or a[0] == 'VAR' or a[0] == 'ANY': return a
    if a[0] == 'FNC': return ('FNC', subst(('NXV', u), a[1], shift(('VAR',), b)))
    if len(a) < 3: return (a[0], subst(u, a[1], b))
    return (a[0], subst(u, a[1], b), subst(u, a[2], b))

# x contains y ?
def cont(x, y):
    if equal(x, y): return True
    if not x: return False
    if x[0] == 'SMB' or x[0] == 'VAR' or x[0] == 'ANY': return False
    if len(x) < 3: return cont(x[1], y)
    return cont(x[1], y) or cont(x[2], y)
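
# contsp: does x contain y "up to shape"? Either x contains y, or (unless
# full_red is set) x and y have the same top operator and each subterm of x
# contains, in this same loose sense, the corresponding subterm of y.
# reduce2 uses this to detect that reduction has started to loop or grow.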

def contsp(x, y):
    if cont(x, y): return True
    if full_red: return False
    if not x: return False
    if not y: return False
    if x[0] != y[0]: return False
    if len(x) < 2:
        x1 = None
        x2 = None
    elif len(x) == 2:
        x1 = x[1]
        x2 = None
    else:
        x1 = x[1]
        x2 = x[2]
    if len(y) < 2:
        y1 = None
        y2 = None
    elif len(y) == 2:
        y1 = y[1]
        y2 = None
    else:
        y1 = y[1]
        y2 = y[2]
    return contsp(x1, y1) and contsp(x2, y2)


def reduce1(x):
    if x is None: return x
    if x[0] == 'SMB': return x
    if x[0] == 'ANY': return x
    if x[0] == 'VAR': return x
    if x[0] == 'APL' and x[1][0] == 'FNC': return subst(('VAR',), x[1][1], x[2])
    if x[0] == 'FNC' and x[1][0] == 'APL' and x[1][2][0] == 'VAR': return x[1][1]
    if len(x) < 3: return (x[0], reduce1(x[1]))
    return (x[0], reduce1(x[1]), reduce1(x[2]))

MAX = 10000

def reduce2(x):
    if not x: return x
    a = []
    n = 0
    y = x
    while True:
        n += 1
        a.append(y)
        z = reduce1(y)
        if n >= MAX: break
        found = False
        for a1 in a:
            if contsp(z, a1):
                found = True
                break
        if found: return y
        y = z
    return z

def reduce(x):
    return reduce2(x)
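
# Quick sanity check (illustrative): beta-reducing ([*] a), the identity applied
# to the symbol a, yields the symbol a.
assert reduce(('APL', ('FNC', ('VAR',)), ('SMB', 'a'))) == ('SMB', 'a')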

def eq(x, y):
    return equal(x, y)

def eqr(x, y):
    return eq(x, y) or eq(x, reduce(y)) or eq(reduce(x), y) or eq(reduce(x), reduce(y))

def sides(x):
    if not x: return (None, None)
    if x[0] == 'SMB' or x[0] == 'ANY' or x[0] == 'VAR': return (x, x)
    if x[0] == 'EQU': return (x[1], x[2])
    if x[0] == 'RED': return sides(reduce(x[1]))
    if x[0] == 'RLR':
        lr = sides(x[1])
        return (reduce(lr[0]), reduce(lr[1]))
    if x[0] == 'GTR':
        lr1 = sides(x[1])
        l1 = lr1[0]
        r1 = lr1[1]
        lr2 = sides(x[2])
        l2 = lr2[0]
        r2 = lr2[1]
        if eqr(l1, l2): return (r1, r2)
        if eqr(r1, l2): return (l1, r2)
        if eqr(l1, r2): return (r1, l2)
        if eqr(r1, r2): return (l1, l2)
        return (x, x)
    lr1 = sides(x[1])
    if len(x) < 3:
        return ((x[0], lr1[0]), (x[0], lr1[1]))
    lr2 = sides(x[2])
    return ((x[0], lr1[0], lr2[0]), (x[0], lr1[1], lr2[1]))

def left(x):
    return sides(x)[0]

def right(x):
    return sides(x)[1]
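
# Illustrative check of generalized transitivity: chaining an axiom a = b with
# an axiom b = c proves a = c (the common term b is eliminated). The helper
# names _axm_ab and _axm_bc exist only for this check.
_axm_ab = ('EQU', ('SMB', 'a'), ('SMB', 'b'))
_axm_bc = ('EQU', ('SMB', 'b'), ('SMB', 'c'))
assert sides(('GTR', _axm_ab, _axm_bc)) == (('SMB', 'a'), ('SMB', 'c'))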

def abstr(u, v, x):
    if equal(v, x): return u
    if not cont(x, v): return x
    if not x: return x
    if x[0] == 'FNC': return ('FNC', abstr(('NXV', u), v, x[1]))
    if len(x) < 2:
        return (x[0],)
    if len(x) == 2:
        return (x[0], abstr(u, v, x[1]))
    return (x[0], abstr(u, v, x[1]), abstr(u, v, x[2]))

def pllambda(v, x):
    if x[0] == 'APL' and equal(x[2], v) and not cont(x[1], v): return x[1]
    return ('FNC', abstr(('VAR',), v, x))
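
# Illustrative: abstracting the symbol a over itself gives the identity [*],
# exactly the term used for DI below.
assert pllambda(('SMB', 'a'), ('SMB', 'a')) == ('FNC', ('VAR',))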


def string_of_proof(x):
    if x is None: return 'None'
    if x[0] == 'SMB': return x[1]
    if x[0] == 'ANY': return '_'
    if x[0] == 'VAR': return '*'
    if x[0] == 'NXV': return "'" + string_of_proof(x[1])
    if x[0] == 'FNC': return '[' + string_of_proof(x[1]) + ']'
    if x[0] == 'APL': return '(' + string_of_proof(x[1]) + ' ' + string_of_proof(x[2]) + ')'
    if x[0] == 'GTR': return '(' + string_of_proof(x[1]) + ' ; ' + string_of_proof(x[2]) + ')'
    if x[0] == 'EQU': return '(' + string_of_proof(x[1]) + ' = ' + string_of_proof(x[2]) + ')'
    if x[0] == 'RED': return '@' + string_of_proof(x[1])
    if x[0] == 'RLR': return '$' + string_of_proof(x[1])
    if x[0] == 'AXM': return f"AXM{x[1]}"
    return '?'
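
# Illustrative: the printed form matches the notation of the introduction.
assert string_of_proof(('APL', ('FNC', ('VAR',)), ('SMB', 'a'))) == '([*] a)'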


def deduce(x):
    concl = sides(x)
    print('The proof')
    print('    ' + string_of_proof(x))
    print('proves')
    print('    ' + string_of_proof(concl[0]))
    print('equals')
    print('    ' + string_of_proof(concl[1]))
    return concl

DI = ('EQU', ('SMB', 'I'), pllambda(('SMB', 'a'), ('SMB', 'a')))
DK = ('EQU', ('SMB', 'K'), pllambda(('SMB', 'a'), pllambda(('SMB', 'b'), ('SMB', 'a'))))
DS = ('EQU', ('SMB', 'S'), pllambda(('SMB', 'a'), pllambda(('SMB', 'b'), pllambda(('SMB', 'c'), ('APL', ('APL', ('SMB', 'a'), ('SMB', 'c')), ('APL', ('SMB', 'b'), ('SMB', 'c')))))))
SKKI = ('GTR', ('APL', ('APL', DS, DK), DK), DI)
sides(SKKI)
deduce(SKKI)

ext = ('GTR', ('APL', ('APL', DS, ('APL', DK, ('SMB', 'x'))), ('APL', DK, ('SMB', 'y'))), ('APL', DK, ('APL', ('SMB', 'x'), ('SMB', 'y'))))
print(sides(ext))
deduce(ext)


def ap(f, x):
    return ('APL', f, x)

def ap2(f, x, y):
    return ap(ap(f, x), y)

def ap3(f, x, y, z):
    return ap(ap(ap(f, x), y), z)

def gtr(x, y):
    return ('GTR', x, y)

def equ(x, y):
    return ('EQU', x, y)

def rlr(x):
    return ('RLR', x)


I = ('FNC', ('VAR',))

x = ('SMB', 'x')
y = ('SMB', 'y')
z = ('SMB', 'z')
parent = ('SMB', 'parent')
grandparent = ('SMB', 'grandparent')
allan = ('SMB', 'allan')
brenda = ('SMB', 'brenda')
charles = ('SMB', 'charles')
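
# Family-relations example. gpRule1 encodes "if x is a parent of y and y is a
# parent of z, then x is a grandparent of z" as the equality
#   (parent x y) ((parent y z) (grandparent x z)) = I
# Proved facts like (parent allan brenda) equal I, so each application
# collapses to its argument, leaving (grandparent x z) = I.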

gpRule1 = pllambda(x, pllambda(y, pllambda(z, equ(ap(ap2(parent, x, y), ap(ap2(parent, y, z), ap2(grandparent, x, z))), I))))
gpAxiom1 = equ(ap2(parent, allan, brenda), I)
gpAxiom2 = equ(ap2(parent, brenda, charles), I)
gpTheorem1 = rlr(gtr(ap2(grandparent, allan, charles),
                     gtr(ap(gpAxiom2, ap2(grandparent, allan, charles)),
                         gtr(ap(gpAxiom1, ap(ap2(parent, brenda, charles), ap2(grandparent, allan, charles))),
                             ap3(gpRule1, allan, brenda, charles)))))

print(sides(gpTheorem1))
deduce(gpTheorem1)

axm1 = ('EQU', ('SMB', 'a'), ('SMB', 'b'))
axm2 = ('EQU', ('SMB', 'b'), ('SMB', 'c'))

# ops = ['ANY', 'VAR', 'NXV', 'FNC', 'APL', 'GTR', 'RLR', 'RED']
ops = ['VAR', 'NXV', 'FNC', 'APL', 'GTR', 'RLR', 'RED']

#smbs = ['a', 'b', 'c']
#axms = [DI, DK, DS]

smbs = ['a', 'b', 'c']
axms = [axm1, axm2]


# def num_op(op):
#     if op in ops: return ops.index(op) + 1
#     return 0

arities = {
    'SMB': 0,
    'ANY': 0,
    'VAR': 0,
    'NXV': 1,
    'FNC': 1,
    'APL': 2,
    'GTR': 2,
    'EQU': 2,
    'RLR': 1,
    'RED': 1,
    'AXM': 0,
}

initial_node = 0

def dict_of_proof_1(x, dict, start, step):
    if x[0] == 'SMB':
        if x[1] in smbs:
            dict[start] = initial_node + len(ops) + smbs.index(x[1])
        else:
            dict[start] = 0
    elif x[0] == 'AXM':
        dict[start] = initial_node + len(ops) + len(smbs) + x[1]
    elif x[0] in ops:
        dict[start] = initial_node + ops.index(x[0])
    else:
        dict[start] = 0
    if arities[x[0]] > 0:
        dict_of_proof_1(x[1], dict, start+step, step*2)
    if arities[x[0]] > 1:
        dict_of_proof_1(x[2], dict, start+2*step, step*2)

def dict_of_proof(x):
    dict = {}
    dict_of_proof_1(x, dict, 0, 1)
    return dict

def proof_of_dict_1(dict, start, step):
    numop = dict[start]
    if numop < initial_node or numop >= initial_node+len(ops)+len(smbs)+len(axms):
        return None
    if numop >= initial_node+len(ops)+len(smbs):
        # return ('AXM', numop-1-len(ops)-len(smbs))
        return axms[numop-initial_node-len(ops)-len(smbs)]
    if numop >= initial_node+len(ops):
        return ('SMB', smbs[numop-initial_node-len(ops)])
    op = ops[numop-initial_node]
    if arities[op] < 1:
        return (op,)
    y = proof_of_dict_1(dict, start+step, step*2)
    if arities[op] < 2:
        return (op, y)
    z = proof_of_dict_1(dict, start+2*step, step*2)
    return (op, y, z)

def proof_of_dict(dict):
    return proof_of_dict_1(dict, 0, 1)

print(proof_of_dict(dict_of_proof(('APL',('APL',('VAR',),('VAR',)),('VAR',)))))
print(proof_of_dict(dict_of_proof(('AXM',1))))
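
# Illustrative: the encoding of [*] a from the introduction. With the ops list
# used here the symbol a is node number len(ops) + 0 = 7 (the introduction's
# schematic example used 6), at positions 0 (APL), 1 (FNC), 2 (a), 3 (VAR).
assert dict_of_proof(('APL', ('FNC', ('VAR',)), ('SMB', 'a'))) == {0: 3, 1: 2, 2: 7, 3: 0}
assert proof_of_dict({0: 3, 1: 2, 2: 7, 3: 0}) == ('APL', ('FNC', ('VAR',)), ('SMB', 'a'))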

import random

random.seed(32)

def random_proof():
    opnum = random.randrange(len(ops)+len(smbs)+len(axms))
    if opnum >= len(ops)+len(smbs):
        return ('AXM', opnum-len(ops)-len(smbs))
    if opnum >= len(ops):
        return ('SMB', smbs[opnum-len(ops)])
    op = ops[opnum]
    if arities[op] < 1:
        return (op,)
    y = random_proof()
    if arities[op] < 2:
        return (op, y)
    z = random_proof()
    return (op, y, z)

def subst_axm(x):
    op = x[0]
    if op == 'AXM':
        return axms[x[1]]
    if op == 'SMB':
        return x
    if arities[op] < 1:
        return (op,)
    y = subst_axm(x[1])
    if arities[op] < 2:
        return (op, y)
    z = subst_axm(x[2])
    return (op, y, z)
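
# Illustrative: axiom placeholders are replaced by the stored axioms.
assert subst_axm(('GTR', ('AXM', 0), ('AXM', 1))) == ('GTR', axm1, axm2)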

def prove_rnd(a, b):
  #print("Random proofs:")
  random.seed(32)
  for i in range(100000):
    x = random_proof()
    y = subst_axm(x)
    concl = sides(y)
    if equal(concl[0], a) and equal(concl[1], b):
        print("---------------------------------------------------------------------")
        print(i)
        print(x)
        print(y)
        deduce(y)
        break

a = ('SMB', 'a')
b = ('SMB', 'b')
c = ('SMB', 'c')

# prove_rnd(('SMB', 'a'), ('SMB', 'a'))
# prove_rnd(('SMB', 'b'), ('SMB', 'b'))
# prove_rnd(('SMB', 'c'), ('SMB', 'c'))
# prove_rnd(('SMB', 'a'), ('SMB', 'b'))
# prove_rnd(('SMB', 'b'), ('SMB', 'a'))
# prove_rnd(('SMB', 'b'), ('SMB', 'c'))
# prove_rnd(('SMB', 'c'), ('SMB', 'b'))
# prove_rnd(('SMB', 'a'), ('SMB', 'c'))
# prove_rnd(('SMB', 'c'), ('SMB', 'a'))

# prove_rnd(ap(a, b), ap(a, c))
# prove_rnd(ap(a, c), ap(c, a))
# prove_rnd(ap(b, a), ap(b, c))
# prove_rnd(ap(b, c), ap(b, a))


def dict_of_list(l):
    dict = {}
    for i in range(len(l)):
        dict[i] = l[i]
    return dict

def list_of_dict(dict, n=0):
    if n == 0:
      m = max(dict.keys())
    else:
      m = n-1
    l = [0 for i in range(m+1)]
    for i, x in dict.items():
        l[i] = x
    return l
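
# Illustrative: flattening the dict encoding of [*] a yields the sequence shown
# in the introduction (with 7 for the symbol a under this ops list, not 6).
assert list_of_dict(dict_of_proof(('APL', ('FNC', ('VAR',)), ('SMB', 'a')))) == [3, 2, 7, 0]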


def proofs_of_size(n):
    proofs = []
    for op in ops:
        if arities[op] == 0:
            proofs.append((op,))
    for s in smbs:
        proofs.append(('SMB', s))
    for i in range(len(axms)):
        proofs.append(('AXM', i))
    if n > 0:
        proofs1 = proofs_of_size(n-1)
        for x in proofs1:
            for op in ops:
                if arities[op] == 1:
                    proofs.append((op, x))
            for y in proofs1:
                for op in ops:
                    if arities[op] == 2:
                        proofs.append((op, x, y))
    return proofs
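
# Illustrative: size 0 gives the 6 atomic proofs (*, three symbols, two axioms);
# size 1 adds 6*4 unary and 6*6*2 binary combinations on top of them.
assert len(proofs_of_size(0)) == 6
assert len(proofs_of_size(1)) == 6 + 6*4 + 6*6*2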

for x in proofs_of_size(1):
    print(string_of_proof(x))

def f2(p2, y):
    f = p2[0]
    p = p2[1]
    x = p2[2]
    for op in ops:
        if arities[op] == 2:
            f(p, (op, x, y))

def f1(p1, x):
    f = p1[0]
    p = p1[1]
    n = p1[2]
    for op in ops:
        if arities[op] == 1:
            f(p, (op, x))
    treat_proofs_of_size(n, f2, (f, p, x))

def treat_proofs_of_size(n, f, p):
    for op in ops:
        if arities[op] == 0:
            f(p, (op,))
    for s in smbs:
        f(p, ('SMB', s))
    for i in range(len(axms)):
        f(p, ('AXM', i))
    if n > 0:
        treat_proofs_of_size(n-1, f1, (f, p, n-1))

def f(p, x):
    print(string_of_proof(x))

print("")
treat_proofs_of_size(1, f, 0)

l = []
def fa(p, x):
    l.append(x)

treat_proofs_of_size(1, fa, 0)

print("")
for x in l:
    print(string_of_proof(x))

import torch

"""Generate target_nproofs proofs and store them and their conclusions."""

random.seed(32)
mx = 0
ma = 0
mb = 0

n = 0
target_nproofs = 10000
nproofs = 0
maxdepth = 4
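# A proof of depth maxdepth occupies tree positions 0 .. 2**(maxdepth+1)-2,
# so its flattened encoding needs at most 2**(maxdepth+1)-1 = 31 entries.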
maxlen = 2**(maxdepth+1)-1
nnodes = initial_node+len(ops)+len(smbs)+len(axms)
print(f"maxlen={maxlen} nnodes={nnodes}")

lproofs = []
lconclusions = []

bproofs = torch.LongTensor([])
bconclusions = torch.FloatTensor([])

for i in range(20000):
  x = random_proof()
  if depth(x) <= maxdepth:
    y = subst_axm(x)
    concl = sides(y)
    a = concl[0]
    b = concl[1]

    dx = dict_of_proof(x)
    lx = list_of_dict(dx)

    da = dict_of_proof(a)
    la = list_of_dict(da)

    db = dict_of_proof(b)
    lb = list_of_dict(db)

    if len(lx) <= maxlen and len(la) <= maxlen and len(lb) <= maxlen:
      lx = list_of_dict(dx, maxlen)
      la = list_of_dict(da, maxlen)
      lb = list_of_dict(db, maxlen)
      lc = la + lb
      #print(f"lx:{len(lx)} lc:{len(lc)}")
      lproofs.append(lx)
      lconclusions.append(lc)
      nproofs += 1
      if nproofs % 1000 == 0:
        print(f"nproofs={nproofs}")

      tx = torch.LongTensor(lx)
      bx = torch.nn.functional.one_hot(tx,nnodes)
      fx = torch.flatten(bx)
      bproofs = torch.cat((bproofs, fx.unsqueeze(0)))

      tc = torch.LongTensor(lc)
      bc = torch.nn.functional.one_hot(tc, nnodes)
      fc = torch.flatten(bc)
      bconclusions = torch.cat((bconclusions, fc.unsqueeze(0)))

      if nproofs >= target_nproofs: break

print(f"nproofs={nproofs} i={i}")
print(bproofs.shape)
print(bconclusions.shape)
print(bproofs[0])
print(bconclusions[0])

from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

print(len(bconclusions))
print(len(bproofs))

"""Extract parts of stored conclusions and proofs for training and test."""

train_bconclusions = bconclusions[0:8000]
test_bconclusions = bconclusions[8000:9000]
train_bproofs = bproofs[0:8000]
test_bproofs = bproofs[8000:9000]

# train_bconclusions = bconclusions[0:16000]
# test_bconclusions = bconclusions[16000:17000]
# train_bproofs = bproofs[0:16000]
# test_bproofs = bproofs[16000:17000]

"""Create the network, train it and test it."""

training_data = TensorDataset(train_bconclusions.float(), train_bproofs.long())
test_data = TensorDataset(test_bconclusions.float(), test_bproofs.long())

batch_size = 8

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

#for X, y in test_dataloader:
#    print(f"Shape of X [N, C, H, W]: {X.shape}")
#    print(f"Shape of y: {y.shape} {y.dtype}")
#    break

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

nnil = 4096

# Define model
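# Shape sketch (derived from the constants above): the input is the two one-hot
# encoded sides of the conclusion, 2*maxlen*nnodes = 2*31*12 = 744 bits; the two
# hidden layers have nnil = 4096 units; the output is one encoded proof tree,
# maxlen*nnodes = 31*12 = 372 logits.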
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(2*maxlen*nnodes, nnil),
            nn.ReLU(),
            nn.Linear(nnil, nnil),
            nn.ReLU(),
            # nn.Linear(nnil, nnil),
            # nn.ReLU(),
            nn.Linear(nnil, maxlen*nnodes)
        )

    def forward(self, x):
        # x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)

def loss_fn(output, target):
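    # Sum of squared errors between the raw network outputs and the one-hot
    # target bits (a regression-style loss; per-position cross-entropy would
    # be a common alternative).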
    loss = torch.sum((output - target) ** 2)
    return loss

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn, epochs, epoch):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            # print(f"X:{X.shape}")
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            #print(f"y:{y.shape}")
            #print(f"pred:{pred.shape}")
            # y1 = torch.matmul(y.float(), torch.FloatTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).to(device)).long()
            # correct += (pred.argmax(1) == y1).type(torch.float32).sum().item()
            for i in range(X.shape[0]):  # robust to a final batch smaller than batch_size
              p = pred[i]
              unflatp = torch.unflatten(p, 0, (maxlen, nnodes))
              decodedp = torch.argmax(unflatp, dim=1)
              lp = decodedp.tolist()
              dp = dict_of_list(lp)
              xp = proof_of_dict(dp)
              cp = sides(xp)
              t = y[i]
              unflatt = torch.unflatten(t, 0, (maxlen, nnodes))
              decodedt = torch.argmax(unflatt, dim=1)
              lt = decodedt.tolist()
              dt = dict_of_list(lt)
              xt = proof_of_dict(dt)
              ct = sides(xt)
              # if equal(xp, xt):
              if equal(cp[0], ct[0]) and equal(cp[1], ct[1]):
                correct += 1
              else:
                if epoch == epochs-1:
                  # print(f"Target: {string_of_proof(xt)} ; Predicted: {string_of_proof(xp)}")
                  print("Target:")
                  deduce(xt)
                  print("Predicted:")
                  deduce(xp)
    test_loss /= num_batches
    # correct /= size
    #print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    print(f"Test Error: \n {correct} correct results, Avg loss: {test_loss:>8f} \n")

epochs = 40
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn, epochs, t)
print("Done!")

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

model = NeuralNetwork()
model.load_state_dict(torch.load("model.pth"))

sa = ('SMB', 'a')
sb = ('SMB', 'b')
sc = ('SMB', 'c')

def fnc(x): return ('FNC', x)

"""Ask the network to search for a proof of a = b."""

def prove(a, b):

  da = dict_of_proof(a)
  la = list_of_dict(da, maxlen)

  db = dict_of_proof(b)
  lb = list_of_dict(db, maxlen)

  lc = la + lb
  tc = torch.LongTensor(lc)
  bc = torch.nn.functional.one_hot(tc, nnodes)
  fc = torch.flatten(bc)
  uc = fc.unsqueeze(0)

  input = uc.float()
  print(input.shape)

  output = model(input)
  print(output.shape)
  # print(output[0])

  unflat = torch.unflatten(output[0], 0, (maxlen, nnodes))
  print(unflat.shape)
  # print(unflat)
  decoded = torch.argmax(unflat, dim=1)
  print(decoded)

  lx = decoded.tolist()
  print(lx)
  dx = dict_of_list(lx)
  print(dx)
  x = proof_of_dict(dx)
  print(x)
  print(string_of_proof(x))
  deduce(x)
  print("")

"""Search for some proofs."""

prove(sa, sa)
prove(sa, sb)
prove(sb, sa)
prove(sa, sc)
prove(ap(sa, sb), ap(sb, sc))