Requires the MNIST dataset as a CSV (train.csv, with the label in the first column), NumPy, pandas, and Matplotlib. After fewer than 500 iterations of gradient descent, accuracy on handwritten-digit recognition exceeds 91%.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Load data
data = pd.read_csv('train.csv')
data = np.array(data)
m, n = data.shape
np.random.shuffle(data)
# Split data
data_dev = data[0:1000].T
Y_dev = data_dev[0]
X_dev = data_dev[1:n] / 255.0 # Scale grayscale pixel values from [0, 255] down to [0, 1]
data_train = data[1000:m].T
Y_train = data_train[0]
X_train = data_train[1:n] / 255.0 # Same scaling for the training split
# Initialize parameters
def init_params():
W1 = np.random.randn(10, 784) * np.sqrt(2.0 / 784) # He initialization
b1 = np.zeros((10, 1))
W2 = np.random.randn(10, 10) * np.sqrt(2.0 / 10)
b2 = np.zeros((10, 1))
return W1, b1, W2, b2
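# He initialization scales each weight matrix by sqrt(2 / fan_in) so that the variance of the
# pre-activations stays roughly constant across ReLU layers, which keeps the early gradients well behaved.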
# Activation functions
def ReLU(Z):
return np.maximum(0, Z)
def softmax(Z):
Z_max = np.max(Z, axis=0, keepdims=True)
e_Z = np.exp(Z - Z_max)
return e_Z / np.sum(e_Z, axis=0, keepdims=True)
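# Quick illustrative check (not used by training): subtracting the column-wise max keeps the
# exponentials from overflowing, and each output column still sums to 1.
_demo_Z = np.array([[1000.0, -2.0], [1001.0, 0.5]])
assert np.allclose(softmax(_demo_Z).sum(axis=0), 1.0)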
# Forward propagation: each layer applies an affine transformation (weights W, biases b) to the matrix X of m samples, followed by a nonlinearity
def forward_prop(W1, b1, W2, b2, X):
Z1 = W1.dot(X) + b1
A1 = ReLU(Z1)
Z2 = W2.dot(A1) + b2
A2 = softmax(Z2)
return Z1, A1, Z2, A2
# One-hot encoding
def one_hot(Y):
num_classes = 10
one_hot_Y = np.zeros((num_classes, Y.size))
one_hot_Y[Y, np.arange(Y.size)] = 1
return one_hot_Y.astype(np.float32)
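# Quick illustrative check (not used by training): each column of the one-hot matrix contains a
# single 1, located at the row given by the corresponding label.
_demo_labels = np.array([3, 0, 9])
assert (one_hot(_demo_labels).argmax(axis=0) == _demo_labels).all()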
# Derivative of ReLU
def deriv_ReLU(Z):
return Z > 0
# Backpropagation
def back_prop(Z1, A1, Z2, A2, W2, X, Y):
m = Y.size
one_hot_Y = one_hot(Y)
dZ2 = A2 - one_hot_Y # Error at the output layer: predicted probabilities minus the one-hot encoded true labels
dW2 = 1 / m * dZ2.dot(A1.T) # Gradient of the weights in the second layer
db2 = 1 / m * np.sum(dZ2, axis=1, keepdims=True) # Gradient of the biases in the second layer
# Error at the hidden layer: the ReLU derivative ensures that only neurons that were active
# during forward propagation contribute to the gradient dZ1
dZ1 = W2.T.dot(dZ2) * deriv_ReLU(Z1)
# Gradient of the first-layer weights: the dot product with the input X gives the gradient of the
# loss with respect to W1, and the 1/m factor averages it over the batch
dW1 = 1 / m * dZ1.dot(X.T)
# Gradient of the first-layer biases
db1 = 1 / m * np.sum(dZ1, axis=1, keepdims=True)
return dW1, db1, dW2, db2
# Update parameters
def update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha):
W1 = W1 - alpha * dW1
b1 = b1 - alpha * db1
W2 = W2 - alpha * dW2
b2 = b2 - alpha * db2
return W1, b1, W2, b2
# Get predictions
def get_predictions(A2): # Select the row index with the highest probability in each column,
# i.e. the digit the network considers most likely for each sample
return np.argmax(A2, 0)
# Get accuracy
def get_accuracy(predictions, Y): # Compare predictions with the true labels elementwise,
# then divide the number of correct predictions by the total number of samples
return np.sum(predictions == Y) / Y.size
# Compute loss
def compute_loss(A2, Y):
m = Y.size
one_hot_Y = one_hot(Y) # One column per sample, all zeros except for a 1 at the true class
loss = -np.sum(one_hot_Y * np.log(A2 + 1e-9)) / m # Cross-entropy loss: multiplying by the one-hot labels keeps
# only the predicted probability of the true class for each sample; the negative log heavily penalizes confident
# wrong predictions, and dividing by m averages over the batch (the small epsilon guards against log(0))
return loss
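# Quick illustrative check (not used by training): a confident correct prediction gives a loss near 0,
# while a uniform guess over the 10 classes gives about log(10) ≈ 2.30.
_demo_A2 = np.full((10, 1), 1e-8)
_demo_A2[7, 0] = 1.0
assert abs(compute_loss(_demo_A2, np.array([7]))) < 1e-6
assert abs(compute_loss(np.full((10, 1), 0.1), np.array([7])) - np.log(10)) < 1e-6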
# Gradient descent
def gradient_descent(X, Y, iterations, alpha):
W1, b1, W2, b2 = init_params()
for i in range(iterations):
Z1, A1, Z2, A2 = forward_prop(W1, b1, W2, b2, X)
dW1, db1, dW2, db2 = back_prop(Z1, A1, Z2, A2, W2, X, Y)
W1, b1, W2, b2 = update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha)
if i % 50 == 0:
loss = compute_loss(A2, Y)
accuracy = get_accuracy(get_predictions(A2), Y)
print(f"Iteration {i}: Loss = {loss}, Accuracy = {accuracy}")
return W1, b1, W2, b2
# Train the model
W1, b1, W2, b2 = gradient_descent(X_train, Y_train, 600, 0.1)
# Evaluate on development set
Z1_dev, A1_dev, Z2_dev, A2_dev = forward_prop(W1, b1, W2, b2, X_dev)
dev_accuracy = get_accuracy(get_predictions(A2_dev), Y_dev)
print("Development Accuracy:", dev_accuracy)
# Function to plot a sample and its prediction
def plot_sample_and_prediction(X, Y, W1, b1, W2, b2):
# Select a random index
index = np.random.randint(0, X.shape[1])
# Get the sample and its label
sample = X[:, index].reshape(28, 28) # Reshape to 28x28 for MNIST
label = Y[index]
# Make a prediction
_, _, _, A2 = forward_prop(W1, b1, W2, b2, X[:, index].reshape(-1, 1))
prediction = get_predictions(A2)
# Plot the sample
plt.imshow(sample, cmap='gray')
plt.title(f"True Label: {label}, Predicted: {prediction[0]}")
plt.axis('off')
plt.show()
# Call the function
plot_sample_and_prediction(X_dev, Y_dev, W1, b1, W2, b2)
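As a quick optional extension (not part of the original script), the dev-set predictions computed above can be summarized in a confusion matrix whose entry (i, j) counts images of digit i that were classified as digit j:
preds_dev = get_predictions(A2_dev)
conf_mat = np.zeros((10, 10), dtype=int)
np.add.at(conf_mat, (Y_dev, preds_dev), 1) # unbuffered scatter-add over (true, predicted) index pairs
print(conf_mat)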
This code implements a Wasserstein GAN with gradient penalty (WGAN-GP), which solves a classic optimal transport problem between two distributions: a Gaussian prior pushed through the generator and the empirical distribution of MNIST digits. The trained WGAN generates realistic handwritten images.
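For reference, the critic estimates the Wasserstein-1 (earth mover's) distance through its Kantorovich-Rubinstein dual,

W_1(P_r, P_g) = \sup_{\|f\|_L \le 1} \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{\tilde{x} \sim P_g}[f(\tilde{x})],

and the gradient penalty softly enforces the 1-Lipschitz constraint on f, giving the losses used in the training loop below:

L_D = \mathbb{E}[f(\tilde{x})] - \mathbb{E}[f(x)] + \lambda_{gp} \, \mathbb{E}_{\hat{x}}\big[(\|\nabla_{\hat{x}} f(\hat{x})\|_2 - 1)^2\big], \qquad L_G = -\mathbb{E}[f(\tilde{x})].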
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import os
import matplotlib.pyplot as plt
# Hyperparameters
latent_dim = 100
channels = 1
img_size = 28
batch_size = 64
n_critic = 5
lr = 0.0002
lambda_gp = 10
n_epochs = 100
os.makedirs("images", exist_ok=True) # Creates directory if it doesn't exist
# Configure device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Generator
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 128),
nn.LeakyReLU(0.2),
nn.Linear(128, 256),
nn.BatchNorm1d(256),
nn.LeakyReLU(0.2),
nn.Linear(256, 512),
nn.BatchNorm1d(512),
nn.LeakyReLU(0.2),
nn.Linear(512, 1024),
nn.BatchNorm1d(1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, channels * img_size * img_size),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), channels, img_size, img_size)
return img
# Critic (Discriminator)
class Critic(nn.Module):
def __init__(self):
super(Critic, self).__init__()
self.model = nn.Sequential(
nn.Linear(channels * img_size * img_size, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 128),
nn.LeakyReLU(0.2),
nn.Linear(128, 1)
)
def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity
# Initialize networks
generator = Generator().to(device)
critic = Critic().to(device)
# Configure data loader
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
dataset = torchvision.datasets.MNIST(
root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_C = optim.Adam(critic.parameters(), lr=lr, betas=(0.5, 0.999))
# Gradient penalty calculation
def compute_gradient_penalty(critic, real_samples, fake_samples):
alpha = torch.rand(real_samples.size(0), 1, 1, 1).to(device)
interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
d_interpolates = critic(interpolates)
gradients = torch.autograd.grad(
outputs=d_interpolates,
inputs=interpolates,
grad_outputs=torch.ones(d_interpolates.size()).to(device),
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
gradients = gradients.view(gradients.size(0), -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
return gradient_penalty
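# Why the penalty targets norm 1: the Kantorovich-Rubinstein dual requires the critic to be 1-Lipschitz,
# and the optimal critic has gradient norm 1 almost everywhere on lines between coupled real and generated
# samples, so WGAN-GP penalizes deviations of the gradient norm from 1 at random interpolates
# (Gulrajani et al., 2017).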
#Initialize lists to keep track of average losses per epoch
c_losses = []
g_losses = []
# Training loop
for epoch in range(n_epochs):
total_c_loss = 0.0 # running loss totals for this epoch
total_g_loss = 0.0
num_g_steps = 0 # number of generator updates this epoch (the generator trains less often than the critic)
for i, (imgs, _) in enumerate(dataloader):
real_imgs = imgs.to(device)
# Train Critic
optimizer_C.zero_grad()
# Generate fake images
z = torch.randn(imgs.size(0), latent_dim).to(device)
fake_imgs = generator(z)
# Real and fake scores
real_validity = critic(real_imgs)
fake_validity = critic(fake_imgs.detach())
# Gradient penalty
gradient_penalty = compute_gradient_penalty(critic, real_imgs.data, fake_imgs.data)
# Critic loss
loss_C = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
loss_C.backward()
optimizer_C.step()
total_c_loss += loss_C.item() # accumulate the critic loss for this epoch
# Train Generator every n_critic steps
if i % n_critic == 0:
optimizer_G.zero_grad()
# Generate images and calculate loss
gen_imgs = generator(z)
gen_validity = critic(gen_imgs)
loss_G = -torch.mean(gen_validity)
loss_G.backward()
optimizer_G.step()
total_g_loss += loss_G.item() # accumulate the generator loss for this epoch
num_g_steps += 1 # count this generator update
# Print progress
if i % 100 == 0:
print(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] "
f"[D loss: {loss_C.item()}] [G loss: {loss_G.item()}]")
# Calculate average losses for the epoch
avg_c_loss = total_c_loss / len(dataloader)
avg_g_loss = total_g_loss / num_g_steps if num_g_steps > 0 else 0
c_losses.append(avg_c_loss)
g_losses.append(avg_g_loss)
print(f"[Epoch {epoch}/{n_epochs}] Avg Critic loss: {avg_c_loss:.4f}, Avg Generator loss: {avg_g_loss:.4f}")
# Save generated images
torchvision.utils.save_image(gen_imgs.data[:25], f"images/{epoch}.png", nrow=5, normalize=True)
# Plot the training curves
plt.figure(figsize=(10, 5))
plt.plot(c_losses, label="Critic Loss")
plt.plot(g_losses, label="Generator Loss")
plt.title("WGAN Training Convergence")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.savefig("wgan_convergence.png")
plt.show()
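As an optional check (not part of the original script), the trained generator can be sampled directly and a grid of digits written to disk; the filename here is just an example:
generator.eval() # switch BatchNorm to inference statistics
z = torch.randn(25, latent_dim, device=device)
samples = generator(z).detach()
torchvision.utils.save_image(samples, "images/wgan_samples.png", nrow=5, normalize=True)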
This code solves a dynamic optimal transport problem with a neural ODE (a continuous normalizing flow in the style of FFJORD). The initial distribution is a standard 2D Gaussian; the target is the 2D density defined by an MNIST digit image.
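A brief sketch of the underlying math: in the Benamou-Brenier (dynamic) formulation, mass is transported along a velocity field v(x, t) while minimizing the time-integrated kinetic energy,

\inf_{\rho, v} \int_0^1 \mathbb{E}_{x \sim \rho_t}\big[\tfrac{1}{2}\|v(x, t)\|^2\big] \, dt \quad \text{s.t.} \quad \partial_t \rho_t + \nabla \cdot (\rho_t v) = 0, \quad \rho_0 = \mathcal{N}(0, I), \quad \rho_1 = \text{target},

and the flow below tracks log-densities along trajectories with the instantaneous change-of-variables formula d log p(z(t))/dt = -\nabla \cdot v(z(t), t), which is what ODEfunc.forward returns. The kinetic-energy penalty added to the training loss is a time-discretized surrogate for this transport cost.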
import matplotlib
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchdiffeq import odeint
from torchvision.datasets import MNIST
from torchvision import transforms
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# Ambient dimension for 2D coordinate flow
Z_DIM = 2
MAX_VAL = 4.0 # coordinate span for histogram grid
class ImageDataset():
"""Sample from a distribution defined by an image histogram."""
def __init__(self, img):
# img: 2D numpy array (H, W)
h, w = img.shape
# grid of coordinates in [-MAX_VAL, +MAX_VAL]
xx = np.linspace(-MAX_VAL, MAX_VAL, w)
yy = np.linspace(-MAX_VAL, MAX_VAL, h)
xx, yy = np.meshgrid(xx, yy)
self.means = np.stack([xx.ravel(), yy.ravel()], axis=1) # (H*W, 2)
# flatten histogram to probability distribution
probs = img.ravel().astype(np.float64)
probs = probs / (probs.sum() + 1e-10)
self.probs = probs
# jitter scale
self.noise_std = np.array([MAX_VAL / w, MAX_VAL / h])
def sample(self, batch_size=512):
inds = np.random.choice(len(self.probs), size=batch_size, p=self.probs)
means = self.means[inds]
noise = np.random.randn(batch_size, 2) * self.noise_std
samples = means + noise
return torch.from_numpy(samples).float()
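# Sampling interpretation: the normalized pixel intensities act as a discrete probability distribution
# over the coordinate grid, and the roughly pixel-sized Gaussian jitter smooths the 28x28 image into a
# continuous 2D target density for the flow to match.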
def compute_kinetic_energy(model, z_samples, ntimes=10):
times = torch.linspace(0, 1, ntimes).to(z_samples)
z_samples = z_samples.detach().clone().requires_grad_(True)
traj, _ = model(z_samples, integration_times=times, reverse=True) # (T, B, D)
traj = traj.requires_grad_(True)
kinetic_energy = 0.0
for i in range(ntimes):
t = times[i].expand(z_samples.size(0), 1)
z = traj[i]
z_dot = model.odefunc.get_z_dot(t, z)
kinetic_energy += (z_dot ** 2).sum(dim=1).mean()
kinetic_energy /= ntimes
return kinetic_energy
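# The time-averaged squared speed above is a discretized version of the Benamou-Brenier transport cost;
# adding it to the loss encourages particles to follow short, nearly straight paths rather than arbitrary
# curves that reach the same final density.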
class ODEfunc(nn.Module):
def __init__(self, hidden_dims=(128,128)):
super(ODEfunc, self).__init__()
dims = [Z_DIM] + list(hidden_dims) + [Z_DIM]
self.layers = nn.ModuleList([
nn.Linear(dims[i] + 1, dims[i+1]) for i in range(len(dims)-1)
])
def get_z_dot(self, t, z):
z_dot = z
for i, layer in enumerate(self.layers):
tz = torch.cat([t.expand(z.size(0), 1), z_dot], dim=1)
z_dot = layer(tz)
if i < len(self.layers)-1:
z_dot = F.softplus(z_dot)
return z_dot
def forward(self, t, state):
z, delta_logpz = state
batch = z.size(0)
# z tracks gradients
if not z.requires_grad:
z.requires_grad_(True)
z_dot = self.get_z_dot(t, z)
# Compute divergence using gradients
div = 0.0
for i in range(z.size(1)):
div += torch.autograd.grad(z_dot[:, i].sum(), z, create_graph=True)[0][:, i]
return z_dot, -div.view(batch, 1)
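# The divergence is computed exactly, one coordinate at a time, with autograd; this costs one backward
# pass per dimension, which is affordable here because the state is 2-dimensional. FFJORD-style models in
# higher dimensions typically replace this loop with Hutchinson's stochastic trace estimator.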
class FfjordModel(nn.Module):
def __init__(self):
super(FfjordModel, self).__init__()
self.odefunc = ODEfunc()
def forward(self, z, delta_logpz=None, integration_times=None, reverse=False):
if delta_logpz is None:
# delta_logpz is created with gradients
delta_logpz = torch.zeros(z.size(0), 1, device=z.device, requires_grad=True)
if integration_times is None:
integration_times = torch.tensor([0.0, 1.0]).to(z)
if reverse:
integration_times = integration_times.flip(0)
# State includes z and delta_logpz
state = odeint(
self.odefunc,
(z, delta_logpz),
integration_times,
method='dopri5',
atol=1e-6,
rtol=1e-6
)
z_t, logpz_t = state
if len(integration_times) == 2:
z_t, logpz_t = z_t[1], logpz_t[1]
return z_t, logpz_t
def standard_normal_logprob(z):
# 2D standard normal log-density
return (-0.5 * (2 * np.log(2*np.pi) + (z**2).sum(1, keepdim=True)))
def save_trajectory(model, savedir='imgs', ntimes=300, memory=0.01, n=4000):
model.eval()
# Enable gradients for trajectory computation
z_samples = torch.randn(n, Z_DIM, requires_grad=True).to(device)
npts = 500
side = np.linspace(-MAX_VAL, MAX_VAL, npts)
xx, yy = np.meshgrid(side, side)
grid = torch.from_numpy(np.stack([xx.ravel(), yy.ravel()], 1)).float().to(device)
# Temporarily enable gradients for ODE solve
with torch.set_grad_enabled(True):
logp_s = standard_normal_logprob(z_samples)
times = torch.linspace(0, 1, ntimes).to(device)
z_traj, _ = model(z_samples, logp_s, integration_times=times, reverse=True)
# Detach for visualization
z_traj = z_traj.detach().cpu()
# Visualize the final frame to check if transport completed
plt.clf()
plt.hist2d(z_traj[-1,:,0], z_traj[-1,:,1], bins=200,
range=[[-MAX_VAL, MAX_VAL], [-MAX_VAL, MAX_VAL]])
plt.title("Final Frame Check")
#plt.gca().invert_yaxis()
plt.gca().set_aspect('equal', adjustable='box')
# Plot final transported samples
plt.clf()
plt.figure(figsize=(6, 6))
plt.hist2d(z_traj[-1,:,0], z_traj[-1,:,1], bins=200,
range=[[-MAX_VAL, MAX_VAL], [-MAX_VAL, MAX_VAL]])
# Add target samples as orange dots (for visual comparison)
x_target = dset.sample(n).to(device)
plt.scatter(x_target[:,0].cpu(), x_target[:,1].cpu(), color='orange', s=0.1, alpha=0.5)
plt.gca().set_aspect('equal')
plt.title("Final Frame vs Target Samples")
plt.gca().invert_yaxis()
plt.savefig("final_frame_check.png")
plt.savefig("final_frame_check.png")
os.makedirs(savedir, exist_ok=True)
for t in range(z_traj.size(0)):
plt.clf()
plt.subplot(1,2,1)
plt.hist2d(z_traj[t,:,0], z_traj[t,:,1], bins=200,
range=[[-MAX_VAL, MAX_VAL],[-MAX_VAL, MAX_VAL]])
plt.title('Transport')
plt.gca().set_aspect('equal', adjustable='box')
plt.gca().invert_yaxis()
plt.savefig(os.path.join(savedir, f"viz-{t:05d}.jpg"))
def trajectory_to_video(savedir='imgs', mp4_fn='transform.mp4'):
import subprocess
img_pattern = os.path.join(savedir, 'viz-%05d.jpg')
out = os.path.join(savedir, mp4_fn)
cmd = f'ffmpeg -y -i {img_pattern} {out}'
subprocess.call(cmd, shell=True)
if __name__ == '__main__':
# Load a single MNIST digit as a 2D histogram
transform = transforms.Compose([transforms.ToTensor()])
mnist = MNIST(root='data', train=True, download=True, transform=transform)
img_tensor, label = mnist[13] # pick one MNIST sample (index 13)
img = img_tensor.squeeze().numpy() # shape (28,28)
# Save target image PDF
plt.imshow(img, cmap='gray')
plt.savefig('target_image.pdf')
plt.close('all')
# Build 2D point-cloud dataset
dset = ImageDataset(img)
# Define model and optimizer
model = FfjordModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
model.train()
for batch in range(2000):
optimizer.zero_grad()
# Enable gradients for z_t1
z_t1 = dset.sample(1024).to(device).requires_grad_(True)
z_t0, delta_logpz = model(z_t1)
logpz_t0 = standard_normal_logprob(z_t0)
logpz_t1 = logpz_t0 - delta_logpz
nll_loss = -torch.mean(logpz_t1) # negative log-likelihood under the standard-normal prior
# kinetic energy term
kinetic_energy = compute_kinetic_energy(model, z_t1, ntimes=10)
loss = nll_loss + 2 * kinetic_energy # the weight 2 on the kinetic-energy term is a tunable hyperparameter
if batch % 100 == 0:
print(f"Batch {batch:04d}: NLL = {nll_loss.item():.5f}, KE = {kinetic_energy.item():.5f}")
loss.backward()
optimizer.step()
# Save model
model_path = 'ffjord_mnist_2d_state.tar'
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")
# Visualize flow
save_trajectory(model, ntimes=100)
trajectory_to_video()
z_t1 = dset.sample(1000).to(device)
z_t0, _ = model(z_t1)
z_t0 = z_t0.detach().cpu().numpy()
z_t1 = z_t1.detach().cpu().numpy()
plt.scatter(z_t1[:, 0], z_t1[:, 1], label='Target (t=1)', alpha=0.5)
plt.scatter(z_t0[:, 0], z_t0[:, 1], label='Mapped to prior (t=0)', alpha=0.5)
plt.legend()
plt.title("Transport check")
plt.savefig("transport_debug.png")
Here are some visualizations of the transport, starting from both a Gaussian mixture and a single Gaussian distribution.