Getting started with proving math theorems through reinforcement learning

An experiment at MIT's Brains, Minds, and Machines Lab

September 2019

In February 2019, I was looking for undergraduate research opportunities in mathematics at MIT and soon came across one that caught my eye.

Project Title: Creating an environment to challenge AI systems in mathematical reasoning

Project Description: Artificial neural networks are showing great promise at processing sensory information, finding patterns in data or playing games, so it seems natural to ask whether they can be applied to the uniquely human discipline of asking and answering mathematical questions. Overcoming the challenges of this problem will provide a clearer path towards other high cognitive tasks. The main objective of this project will be to create a mathematical environment to test AI agents at proving mathematical theorems, as well as designing a simple agent showcasing the effectiveness of the environment.

I went to talk to the MIT postdoctoral associate in charge of the project, Dr. Andrzej Banburski. He explained that the Brains, Minds, and Machines Lab, which studies intelligent behavior in the brain and how to replicate it in machines, had not yet studied how humans do mathematics, or whether that ability could be replicated.

Since the lab had limited experience in theorem proving and reinforcement learning, we were starting from scratch, which meant we had to answer a lot of fundamental questions.

Stage 1: Philosophical debates

I, along with the other undergraduate researchers on the project, spent hours at the blackboard each week debating fundamental questions about creating a reinforcement learning theorem prover for mathematics:

  • What areas of math might be most suited to being solved through reinforcement learning?

  • Do we need to prove that the learning environment is well defined and consistent? For an agent playing chess, it is obvious that every move keeps us within a finite, well-defined set of states. In math, should we prove that a given set of actions keeps us within a finite set of states? If not, is it acceptable for the set of states to be infinite?

  • What does a "successful agent" look like? Is it one that is given only a few proof techniques and learns which to apply, and in which order? One that can come up with its own proof techniques to prove a theorem? One that can propose its own theorems? One that can successfully apply complex proof techniques (e.g. the probabilistic method)? One that can prove a previously unproven theorem?

  • How do we model the state of the learning environment?

  • What set of actions should the agent be able to take?

  • When should we reward the agent?

  • Do we want to train the agent on human theorem-proving data, or do we want it to learn by self-play?

We spent much of the first few weeks of the project proposing and tackling many more of these fun philosophical questions. But we decided that, before finalizing our answers, we should start implementing the framework to make clear what was feasible and what was not.

Stage 2: Working on the Prover

Due to my longstanding infatuation with Mathematica, I persuaded the group to try out Mathematica's theorem-proving functionality for the formal proving aspects of the project (we already knew we wanted to use Python for the machine learning side).

Unfortunately, building up both the group theory infrastructure and the prover infrastructure in Mathematica proved tedious. After weeks we still weren't making progress, and we began considering a switch to a language better equipped for theorem proving, such as Coq.

Our supervisor decided to hold a competition: one of us would try to prove the cancellation property in Mathematica, and another would try to prove it in Coq. Whoever finished first would win the competition and determine the fate of the project.

Coq won.

Not all was lost with our attempt to use Mathematica for theorem proving: we made some interesting discoveries along the way. One was that Mathematica comes equipped with a function, FindEquationalProof, that proves equational theorems automatically. For example, we asked Mathematica to prove that if a*f = a for every element a of a group, then f is the identity element e.
In[1]: groupTheory = {
         ForAll[{a, b, c}, g[a, g[b, c]] == g[g[a, b], c]],
         ForAll[a, g[a, e] == a],
         ForAll[a, g[a, inv[a]] == e]
       }
In[2]: proof = FindEquationalProof[e == f, Append[groupTheory, ForAll[a, g[a, f] == a]]]
In[3]: proof["ProofDataset"]
Unfortunately, the proofs produced by FindEquationalProof weren't quite as readable as we would have liked (see table on right), so we made it a priority for the proofs our own agent produced to be as readable as possible.
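For comparison, here is the kind of short, readable derivation we had in mind. It was written by hand (it is not Mathematica's output) and uses only the three axioms in groupTheory (associativity, right identity, right inverse) plus the extra hypothesis instantiated at a = f, i.e. ff = f:

```latex
\begin{aligned}
f &= f e                 && \text{right identity} \\
  &= f\,(f f^{-1})       && \text{right inverse: } e = f f^{-1} \\
  &= (f f)\, f^{-1}      && \text{associativity} \\
  &= f f^{-1}            && \text{hypothesis } a f = a \text{ with } a = f \\
  &= e                   && \text{right inverse}
\end{aligned}
```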

We used GamePad's Coq-to-Python theorem-proving interface to send Coq commands from Python. Since none of us had previous experience with Coq, we spent a lot of late nights in the lab eating pizza and having fun with type errors. After many such nights, one of us finally implemented the group theory problem we wanted our agent to be able to solve: the cancellation law.

Require Export Ensembles.

Section group.
Variable U : Type.
Record group := {
  S : Ensemble U;
  id : U;
  op : U -> U -> U;
  inv : U -> U;
  ...
}.
End group.

Section theorem.
...
(* Let a, b, c be elements of a group G whose law of composition is written multiplicatively.
   If ab = ac or if ba = ca, then b = c. *)
Let H1 := In U (S G) x0.
Let H2 := In U (S G) x1.
Let H3 := In U (S G) x2.
Let H4a := op G x0 x1 = op G x0 x2.
Let H4b := op G x1 x0 = op G x2 x0.
Let H4 := H4a \/ H4b.
Theorem cancellation_law: H1 -> H2 -> H3 -> H4 -> x1 = x2.
Proof.
  intros D1. intros D2. intros D3. intros D4.
  ...
Qed.
End theorem.

Stage 3: Working on the Learner (States, Actions, and Rewards)

Before implementing a reinforcement learning algorithm, we had to decide on the states, actions, and rewards of the algorithm.


We decided that the state of the theorem-proving environment should include the theorem's assumptions, its variables, and the steps taken in the proof so far.

The harder question was how to encapsulate all of these in a single object that could serve as a key in the Q-table. We decided to simply key each state by a hash that depends only on the steps taken in the proof so far, use that hash for Q-table lookup, and hope for the best.

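class State:
    def __init__(self, variables=None, assumptions=None, goals=None):
        self.done = False
        # Use None defaults: a mutable default ({}) would be shared by every State
        self.variables = variables if variables is not None else {}
        self.assumptions = assumptions if assumptions is not None else {}
        self.goals = goals if goals is not None else {}
        self.past_actions = []  # proof steps taken so far

    def __hash__(self):
        # The Q-table key depends only on the sequence of actions taken so far
        return hash(tuple(self.past_actions))

    def __eq__(self, other):
        # Keep equality consistent with __hash__ so equal histories share a Q-table row
        return isinstance(other, State) and self.past_actions == other.past_actions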
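Python dictionaries look up a key by its hash first and then by equality, so a state class used as a Q-table key needs an __eq__ that agrees with its __hash__; otherwise a new object with the same proof history lands in a different row. The two classes below (HistoryState and HashOnlyState) are hypothetical, minimal stand-ins used only to demonstrate this:

```python
class HistoryState:
    """Hypothetical minimal state identified only by the actions taken so far."""

    def __init__(self, past_actions=()):
        self.past_actions = list(past_actions)

    def __hash__(self):
        return hash(tuple(self.past_actions))

    def __eq__(self, other):
        # Equality agrees with __hash__, so equal histories match in a dict
        return isinstance(other, HistoryState) and self.past_actions == other.past_actions


class HashOnlyState:
    """Same hash, but keeps object's identity-based equality."""

    def __init__(self, past_actions=()):
        self.past_actions = list(past_actions)

    def __hash__(self):
        return hash(tuple(self.past_actions))


history = ["LeftMultiplyByInverse", "ApplyAssociativity"]

table = {HistoryState(history): {"ApplyLeftInverseProperty": 0.5}}
print(HistoryState(history) in table)    # True: a new object with the same history finds the row

broken = {HashOnlyState(history): {"ApplyLeftInverseProperty": 0.5}}
print(HashOnlyState(history) in broken)  # False: same hash, but the objects are not equal
```

Because past_actions is a mutable list, a state object should also not be modified after it has been used as a key, since its hash would change.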

Since we didn't have a predetermined set of states (in fact, the set of states is potentially infinite, since it includes every possible sequence of proof steps the agent could generate), we implemented a Q-table that expands as the agent encounters new states.

class QTable():
    ...
    def update(self, state, next_state, action, reward, alpha, gamma):
        """Update qtable based on reward from last action"""
        self.ensure_state_in_qtable(state)  # add state to qtable if not already there
        self.ensure_state_in_qtable(next_state)  # add state to qtable if not already there
        old_value = self.q_table[state][action]
        next_max = max(self.q_table[next_state].values())  # maximum Q-value available from next state
        new_value = (1 - alpha) * old_value + alpha * (reward + gamma * next_max)
        self.q_table[state][action] = new_value
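A minimal sketch of the expanding-table idea, assuming states are keyed directly by the tuple of action names taken so far (standing in for the hash of past_actions). ExpandingQTable is a hypothetical name; collections.defaultdict creates a zero-initialized row the first time a state is looked up, which plays the role of ensure_state_in_qtable, and the update line is the same tabular Q-learning rule as in update above:

```python
from collections import defaultdict


class ExpandingQTable:
    """Hypothetical sketch: a Q-table whose rows are created on first visit."""

    def __init__(self, actions):
        self.actions = list(actions)
        # Indexing a missing state creates a fresh row of zero Q-values
        self.q_table = defaultdict(lambda: {a: 0.0 for a in self.actions})

    def update(self, state, next_state, action, reward, alpha, gamma):
        old_value = self.q_table[state][action]
        next_max = max(self.q_table[next_state].values())
        # Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (r + gamma * max_a' Q(s', a'))
        self.q_table[state][action] = (1 - alpha) * old_value + alpha * (reward + gamma * next_max)


# Usage: states are tuples of the actions applied so far
qt = ExpandingQTable(["LeftMultiplyByInverse", "ApplyAssociativity"])
s0, s1 = (), ("LeftMultiplyByInverse",)
qt.update(s0, s0, "ApplyAssociativity", -10.0, alpha=0.1, gamma=0.6)    # Coq error: state unchanged
qt.update(s0, s1, "LeftMultiplyByInverse", 0.0, alpha=0.1, gamma=0.6)  # valid step: row for s1 is added
print(len(qt.q_table), qt.q_table[s0])
```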


The only actions required to prove our test theorem, the cancellation law, were multiplying by inverses, applying associativity, applying the left inverse property of groups, and applying the identity property of groups. We also added an action that allows the agent to undo its previous action.
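To see why these four kinds of actions suffice, here is a pen-and-paper version of the ab = ac case, writing a, b, c for x0, x1, x2 and e for the identity. Each line is annotated with the action that produces it; associativity, left inverse, and identity each have to be applied once per side of the equation:

```latex
\begin{aligned}
ab &= ac                       && \text{hypothesis} \\
a^{-1}(ab) &= a^{-1}(ac)       && \text{left-multiply by } a^{-1} \\
(a^{-1}a)\,b &= (a^{-1}a)\,c   && \text{associativity (once per side)} \\
e\,b &= e\,c                   && \text{left inverse (once per side)} \\
b &= c                         && \text{identity (once per side)}
\end{aligned}
```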

Each action in our code often consisted of several Coq tactics. For example, the left-multiply action was implemented as follows:

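class Action:
    def __init__(self):
        pass

    def to_coq(self):
        # Subclasses return the Coq tactic string for this action
        raise NotImplementedError


class LeftMultiply(Action):
    def __init__(self, in_eqn, out_eqn, expr, state):
        self.in_eqn = state.assumptions[in_eqn]  # equation to multiply, looked up by name
        self.out_eqn = out_eqn                   # name of the new hypothesis
        self.expr = state.variables[expr]        # expression to multiply by on the left

    def to_coq(self):
        # Assert expr*lhs = expr*rhs, reduce it to lhs = rhs with f_equal, then close it from the assumptions
        return "assert (op G ({}) ({}) = op G ({}) ({})) as {}. f_equal. assumption.".format(
            self.expr, self.in_eqn.left.to_coq(), self.expr, self.in_eqn.right.to_coq(), self.out_eqn)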
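The other actions can be rendered the same way. The following is a hypothetical sketch, not the project's code: the helper class RewriteFirstOccurrence, the to_script function, and the hypothesis name "H5" are our own illustration, chosen to reproduce the tactic strings that appear in the agent-generated proof later in this post (e.g. `rewrite (op_assoc G) in H5 at 1.`). The trailing `2:assumption.` presumably discharges a group-membership side goal produced by the inverse and identity lemmas:

```python
class Action:
    """Base class: each action renders itself as one or more Coq tactics."""

    def to_coq(self) -> str:
        raise NotImplementedError


class RewriteFirstOccurrence(Action):
    """Hypothetical helper: rewrite with a group lemma at the first match in a hypothesis."""

    def __init__(self, lemma: str, hyp: str, side_goal: bool):
        self.lemma = lemma          # e.g. "op_assoc", "op_inv_l", "op_id_l"
        self.hyp = hyp              # hypothesis to rewrite in, e.g. "H5"
        self.side_goal = side_goal  # close goal 2 with `assumption` afterwards?

    def to_coq(self) -> str:
        tactic = f"rewrite ({self.lemma} G) in {self.hyp} at 1."
        if self.side_goal:
            tactic += " 2:assumption."
        return tactic


class ApplyAssociativity(RewriteFirstOccurrence):
    def __init__(self, hyp: str):
        super().__init__("op_assoc", hyp, side_goal=False)


class ApplyLeftInverseProperty(RewriteFirstOccurrence):
    def __init__(self, hyp: str):
        super().__init__("op_inv_l", hyp, side_goal=True)


class ApplyIdentityProperty(RewriteFirstOccurrence):
    def __init__(self, hyp: str):
        super().__init__("op_id_l", hyp, side_goal=True)


def to_script(actions) -> str:
    """Join the tactics of a sequence of actions into one Coq script fragment."""
    return " ".join(a.to_coq() for a in actions)


print(to_script([ApplyAssociativity("H5"), ApplyLeftInverseProperty("H5"), ApplyIdentityProperty("H5")]))
```

Keeping each action responsible for its own to_coq string means the learner only ever chooses among action objects, while the Coq syntax stays in one place per action.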

Reward Shaping

In keeping with our minimalist experimental setup, we kept the agent's rewards as simple as possible as well:

  • +1 for finishing the theorem

  • -10 for attempting to apply a tactic that yields a Coq error (e.g. applying associativity when it is not applicable to the goal).

When we started training the agent, we were shocked to find that it had decided its best course of action was to apply Undo over and over.

Then we realized why. The action least likely to throw a Coq error was Undo: it depended little on the current state of the goal, unlike actions such as ApplyAssociativity or ApplyIdentity.

Our agent was behaving exactly as we had incentivized it to, maximizing reward by choosing the least risky actions. So we amended the reward scheme so that the agent would undo its progress only when absolutely necessary:

  • +1 for finishing the theorem

  • -10 for attempting to apply a tactic that yields a Coq error

  • -1000 for applying Undo.
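A toy illustration of why the first scheme pulls a greedy agent toward Undo, under simplifying assumptions of our own: a single state in which ApplyAssociativity raises a Coq error and Undo succeeds, one Q-update per action with no future term (gamma = 0), and reward terms that simply add up (how the project actually combined them is not stated). The functions reward and greedy_after_one_try are hypothetical:

```python
def reward(action: str, coq_error: bool, finished: bool, amended: bool) -> float:
    """Per-step reward; amended=True adds the -1000 penalty for Undo."""
    r = 0.0
    if finished:
        r += 1.0       # +1 for finishing the theorem
    if coq_error:
        r -= 10.0      # -10 for a tactic that Coq rejects
    if amended and action == "Undo":
        r -= 1000.0    # -1000 for applying Undo (amended scheme only)
    return r


def greedy_after_one_try(amended: bool, alpha: float = 0.1) -> str:
    """Try each action once, apply a one-step Q update (gamma = 0),
    and return the action a greedy agent would pick next."""
    q = {"ApplyAssociativity": 0.0, "Undo": 0.0}
    raises_error = {"ApplyAssociativity": True, "Undo": False}
    for action, error in raises_error.items():
        r = reward(action, coq_error=error, finished=False, amended=amended)
        q[action] = (1 - alpha) * q[action] + alpha * r
    return max(q, key=q.get)


print(greedy_after_one_try(amended=False))  # Undo
print(greedy_after_one_try(amended=True))   # ApplyAssociativity
```

Under the first scheme the failed tactic drops to -1.0 while Undo stays at 0.0, so Undo looks best; with the -1000 penalty, Undo drops to -100.0 and the agent prefers to keep trying real tactics.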

Results & Takeaways

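We structured training after the OpenAI Gym interface: the environment exposes reset() to start an episode and step(action) to apply an action, which in our version returns the next state, the reward, and a done flag.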

import numpy as np

def train(self, episodes=5, alpha=0.1, gamma=0.6, epsilon=0.1):
    env = self.env
    qtable = self.qtable
    # Run episodes
    for _ in range(episodes):
        state = env.reset()
        epochs, penalties, reward = 0, 0, 0
        while not env.state.done:
            # Choose to explore or exploit
            if np.random.uniform(0, 1) < epsilon:
                action = qtable.get_random_action()  # explore the action space
            else:
                action = qtable.get_recommended_action(state)  # exploit learned Q-values
            next_state, reward, done = env.step(action)  # will return error and undo, if unsuccessful
            # See if we're done with the proof
            env.step(VerifyGoalReached())  # will return error and undo, if unsuccessful
            # Update the qtable
            qtable.update(state, next_state, action, reward, alpha, gamma)
            if reward < 0:
                penalties += 1
            state = next_state
            epochs += 1
        print("Proof generated:", env.state.past_actions)
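To see a loop of this shape converge end to end without Coq, here is a self-contained toy version. Everything about the environment is an assumption for illustration: ToyProofEnv, TARGET, ACTIONS, and greedy_proof are hypothetical; the "proof" succeeds only if four actions are applied in one fixed order; and any other action is treated like a Coq error (reward -10, state unchanged). It keeps the same hyperparameter defaults and the same three-value step() return as the training code, but uses Python's random module instead of NumPy:

```python
import random
from collections import defaultdict

# Hypothetical toy "proof": these actions must be applied in this order.
TARGET = ["LeftMultiplyByInverse", "ApplyAssociativity",
          "ApplyLeftInverseProperty", "ApplyIdentityProperty"]
ACTIONS = ["ApplyIdentityProperty", "ApplyAssociativity",
           "ApplyLeftInverseProperty", "LeftMultiplyByInverse"]


class ToyProofEnv:
    """Hypothetical stand-in for the Coq environment (no Coq involved)."""

    def reset(self):
        self.state = ()  # state = tuple of the actions applied successfully so far
        return self.state

    def step(self, action):
        if action == TARGET[len(self.state)]:
            self.state = self.state + (action,)
            done = len(self.state) == len(TARGET)
            return self.state, (1.0 if done else 0.0), done  # +1 only for finishing
        return self.state, -10.0, False  # "Coq error": penalize, leave the state as it was


def train(env, episodes=20, alpha=0.1, gamma=0.6, epsilon=0.1, seed=0, max_steps=1000):
    rng = random.Random(seed)
    q = defaultdict(lambda: {a: 0.0 for a in ACTIONS})  # rows are created on first visit
    for _ in range(episodes):
        state, done, steps = env.reset(), False, 0
        while not done and steps < max_steps:
            if rng.random() < epsilon:
                action = rng.choice(ACTIONS)  # explore
            else:
                action = max(q[state], key=q[state].get)  # exploit; ties go to the first listed action
            next_state, reward, done = env.step(action)
            next_max = max(q[next_state].values())
            q[state][action] = (1 - alpha) * q[state][action] + alpha * (reward + gamma * next_max)
            state, steps = next_state, steps + 1
    return q


def greedy_proof(env, q, max_steps=10):
    """Roll out the learned policy without exploration and return the actions taken."""
    state, done, proof = env.reset(), False, []
    while not done and len(proof) < max_steps:
        action = max(q[state], key=q[state].get)
        state, _, done = env.step(action)
        proof.append(action)
    return proof


q = train(ToyProofEnv())
print(greedy_proof(ToyProofEnv(), q))  # expected: the four actions of TARGET, in order
```

Because there is only one theorem and penalized actions never change the state, every wrong action ends up with a negative Q-value while the correct one stays non-negative and eventually positive, so the table simply memorizes a single action sequence.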

Indeed, through training, our agent eventually converged on correct proofs:

Proof generated: [LeftMultiplyByInverse, ApplyAssociativity, ApplyAssociativity, ApplyLeftInverseProperty, ApplyLeftInverseProperty, ApplyIdentityProperty, ApplyIdentityProperty, VerifyGoalReached]

We decided to define a "successful agent" in the simplest way possible: given only one theorem to prove, the agent should prove it in fewer attempts after training than before training.

# ----------------------------------------------------------------------------------
# Evaluate agent when q-table is empty (should be pretty bad)
# ----------------------------------------------------------------------------------
a = Agent(env)
episodes, total_epochs, total_penalties = a.evaluate(episodes=5)
print("BEFORE TRAINING:")
print(f"\tAverage timesteps per episode: {total_epochs / episodes}")
print(f"\tAverage penalties per episode: {total_penalties / episodes}")

# ----------------------------------------------------------------------------------
# Train agent by filling out the q-table
# ----------------------------------------------------------------------------------
print("\nTRAINING ON 5 EPISODES ...")
a.train(episodes=5)
print("Training finished.\n")
# print(a.qtable)

# ----------------------------------------------------------------------------------
# Evaluate how well agent was trained by evaluating how well it performs with new qtable
# ----------------------------------------------------------------------------------
episodes, total_epochs, total_penalties = a.evaluate(episodes=5)
print("AFTER TRAINING:")
print(f"\tAverage timesteps per episode: {total_epochs / episodes}")
print(f"\tAverage penalties per episode: {total_penalties / episodes}")

The results came back as expected! After training on only 5 episodes, the agent needed noticeably fewer timesteps per episode and made noticeably fewer attempts to apply error-yielding Coq tactics:

BEFORE TRAINING:
    Average timesteps per episode: 27.2
    Average penalties per episode: 20.2

TRAINING ON 5 EPISODES ...
Training finished.

AFTER TRAINING:
    Average timesteps per episode: 17.2
    Average penalties per episode: 10.2

These results may seem obvious: since the agent was given only one theorem to prove and a list of the necessary tactics, it could always find the right sequence of steps by brute force. Combined with reinforcement learning, speeding up the proof becomes a matter of memorizing one particular sequence, so the experiment was expected to yield results like these. In fact, both averages fell by exactly 10 (27.2 to 17.2 timesteps, 20.2 to 10.2 penalties), so the number of non-penalized steps stayed at 7.0 per episode: the entire speed-up came from avoiding error-yielding tactics. Still, for a team that had started the semester completely unsure whether reinforcement learning algorithms could prove theorems at all, these obvious results were quite a relief.

And we were, in fact, surprised by a few of our results, including that:

  • The agent learned without using neural networks.

  • The Q-table functioned correctly even though the states were represented only by hashes, which have no meaning in themselves beyond indexing a particular list of previous actions.

  • Most interesting of all, after just a few seconds with Coq, the agent seemed to generate cleaner Coq proofs than the ones we had written over the past few months.

Our human-generated proof:

Theorem cancellation_law: H1 -> H2 -> H3 -> H4 -> x1 = x2.
Proof.
  intros D1. intros D2. intros D3. intros D4.
  destruct D4 as [D4a|D4b].
  assert (op G (inv G x0) (op G x0 x1) = op G (inv G x0) (op G x0 x2)) as H5.
    f_equal. assumption.
  rewrite (op_assoc G) in H5.
  rewrite (op_inv_l G) in H5. 2:assumption.
  rewrite (op_id_l G) in H5. 2:assumption.
  rewrite (op_assoc G) in H5.
  rewrite (op_inv_l G) in H5. 2:assumption.
  rewrite (op_id_l G) in H5. 2:assumption.
  assumption.
  assert (op G (op G x1 x0) (inv G x0) = op G (op G x2 x0) (inv G x0)) as H6.
    f_equal. assumption.
  rewrite <- (op_assoc G) in H6.
  rewrite (op_inv_r G) in H6. 2:assumption.
  rewrite (op_id_r G) in H6. 2:assumption.
  rewrite <- (op_assoc G) in H6.
  rewrite (op_inv_r G) in H6. 2:assumption.
  rewrite (op_id_r G) in H6. 2:assumption.
  assumption.
Qed.

Our agent-generated proof, which contains only the branch for the ab = ac case:

Theorem cancellation_law: H1 -> H2 -> H3 -> H4 -> x1 = x2.
Proof.
  intros D1. intros D2. intros D3. intros D4.
  assert (op G (inv G x0) (op G x0 x1) = op G (inv G x0) (op G x0 x2)) as H5.
    f_equal. assumption.
  rewrite (op_assoc G) in H5 at 1.
  rewrite (op_inv_l G) in H5 at 1. 2:assumption.
  rewrite (op_id_l G) in H5 at 1. 2:assumption.
  rewrite (op_assoc G) in H5 at 1.
  rewrite (op_inv_l G) in H5 at 1. 2:assumption.
  rewrite (op_id_l G) in H5 at 1. 2:assumption.
  assumption.
Qed.

These results lend some credence to the idea that reinforcement learning through self-play has the potential to surpass human performance. That is, by learning from self-play rather than from human training data, computers may become more "creative" by exploiting their own strengths rather than imitating human ones. The power of self-play in games has been demonstrated before: AlphaGo Zero learned Go without any human game data and surpassed the version of AlphaGo that had defeated Lee Sedol, and OpenAI Five defeated top human Dota 2 players after training through self-play. The success of self-playing machines may well extend to mathematical theorem proving.

Code Repository: The code for this project is available online at:
Acknowledgements: All of the discoveries and code described here were a collaborative effort. The undergraduate research team this past semester consisted of Christian Omar Altamirano Modesto, Jessica Shi, Laura Koemmpel, Kevin Shen, and myself, supervised by Dr. Andrzej Banburski.