PyTorch Lightning

https://lightning.ai/docs/pytorch/stable/starter/introduction.html 

mamba install lightning torchvision

Define a model

import os

from torch import optim, nn, utils, Tensor

from torchvision.datasets import MNIST

from torchvision.transforms import ToTensor

import lightning.pytorch as pl


# Building blocks of the autoencoder. Any nn.Module can be used here,
# including modules you already have.

# Encoder: 28 * 28 input features -> 64 hidden units -> 3-dimensional code.
encoder = nn.Sequential(
    nn.Linear(28 * 28, 64),
    nn.ReLU(),
    nn.Linear(64, 3),
)

# Decoder: 3-dimensional code -> 64 hidden units -> 28 * 28 output features.
decoder = nn.Sequential(
    nn.Linear(3, 64),
    nn.ReLU(),
    nn.Linear(64, 28 * 28),
)



# define the LightningModule

class LitAutoEncoder(pl.LightningModule):
    """LightningModule that trains an encoder/decoder pair as an autoencoder.

    The training objective is the mean squared error between the flattened
    input and its reconstruction (decoder applied to the encoder output).
    """

    def __init__(self, encoder, decoder):
        """Store the two sub-networks.

        Args:
            encoder: module mapping a flattened input to a latent code.
            decoder: module mapping a latent code back to the input size.
        """
        super().__init__()
        self.encoder, self.decoder = encoder, decoder

    def training_step(self, batch, batch_idx):
        """Compute and log the reconstruction loss for one batch.

        This defines the training loop body and does not go through
        ``forward``.

        Args:
            batch: a two-element batch; only the first element (the inputs)
                is used, the second is ignored.
            batch_idx: index of the batch (unused).

        Returns:
            The MSE loss between the reconstruction and the flattened inputs.
        """
        inputs, _ = batch
        # Flatten everything except the batch dimension.
        flat = inputs.view(inputs.size(0), -1)
        reconstruction = self.decoder(self.encoder(flat))
        loss = nn.functional.mse_loss(reconstruction, flat)
        # Logged to TensorBoard (if installed) by default.
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        return optim.Adam(self.parameters(), lr=1e-3)



# Wrap the encoder/decoder pair in the LightningModule.
autoencoder = LitAutoEncoder(encoder=encoder, decoder=decoder)

 Get a dataset

# MNIST stored under the current working directory, with download enabled;
# each sample is converted with ToTensor().
dataset = MNIST(root=os.getcwd(), transform=ToTensor(), download=True)

# DataLoader with all default settings.
train_loader = utils.data.DataLoader(dataset=dataset)

Train

# Train for one epoch, limited to 100 training batches.
trainer = pl.Trainer(max_epochs=1, limit_train_batches=100)
trainer.fit(autoencoder, train_dataloaders=train_loader)

 Predict

# `torch` itself is required for `torch.rand` below. The imports at the top
# of the file only bring in `optim`, `nn`, `utils` and `Tensor` via
# `from torch import ...`, which does not bind the name `torch`.
import torch

# Load the checkpoint written by the training run above.
# NOTE: the path assumes this was the first run in this directory
# ("version_0") and that training stopped after 100 steps in epoch 0, which
# matches Trainer(limit_train_batches=100, max_epochs=1). Adjust it otherwise.
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# Pick the trained encoder and switch it to evaluation mode.
encoder = autoencoder.encoder
encoder.eval()

# Embed 4 random "images" (already flattened to 28 * 28 values), created on
# the same device as the loaded model.
fake_image_batch = torch.rand(4, 28 * 28, device=autoencoder.device)
embeddings = encoder(fake_image_batch)

print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)