Mondrianiser: Image Inpainting with VAEs

An example of generative models using a variational autoencoder

Categories: Experiments, Machine Learning, Deep Learning

Published: March 17, 2025

Besides classification and regression tasks, deep learning models can also be used for generative tasks, where the goal is to generate new data samples that can mimic the distribution of the training data. Generative models have a wide range of applications, including image generation, style transfer, image inpainting, and more.

In this experiment, we’ll explore the use of a Variational Autoencoder (VAE) to inpaint missing regions in images. Inpainting involves filling in unknown regions of an image, which is useful for image restoration, editing, and other applications. For example, your phone camera likely has a feature that removes unwanted objects from a photo by inpainting the missing regions; it achieves this with generative models like VAEs.

The architecture of a Variational Autoencoder

The VAE is a type of generative model that learns a latent representation of the input data. Latent in this context means that the model transforms the original data into a compressed, hidden form that captures its most important features. This latent space is not directly observable but serves as an internal representation from which the model can reconstruct or generate new data.

Note

A similar concept exists with text embeddings. When processing text, models convert words or sentences into numerical vectors that capture the underlying meaning, relationships, and context. These embeddings are like a latent space for language: they’re not meant to be read directly but provide a simplified and efficient representation of the text’s core information. This parallel shows that whether dealing with images or text, many models rely on hidden representations to manage and manipulate complex data.

Once the model has learned this compressed representation, it can use it to create variations or entirely new examples that resemble the original inputs. By working in this latent space, the VAE is able to simplify complex data into a more manageable format while still retaining the core characteristics needed for effective reconstruction and generation.

[Diagram: a fully connected encoder-decoder VAE. The original image passes through two encoder layers into the latent space, and two decoder layers reconstruct the image.]

Mondrian

Training a VAE for an inpainting task requires a dataset of images for the model to learn from. For this experiment, we will use a simple Mondrian image generator to create a dataset of images with missing regions. The Mondrian images are inspired by the works of Piet Mondrian, a Dutch painter known for his abstract compositions of lines and colors.

These images are created by recursively splitting the canvas into smaller regions, each filled with a random color. The Mondrian images will serve as our training data, with a square mask applied to each image to simulate the missing regions that need to be inpainted. They are simple enough to generate programmatically and for a relatively simple model to learn from, without requiring a large dataset of real images (such as the CelebA dataset).

Let us start by creating a Mondrian image generator, which we will then use to create a dataset for training our VAE model.

Show the code
import random
import numpy as np


# Mondrian image generator
def generate_mondrian(
    width,
    height,
    border_thickness=3,
    split_prob=0.7,
    black_line_thickness=2,
    min_depth=2,
    max_depth=5,
    overall_border_thickness=6,
):
    img = np.ones((height, width, 3), dtype=np.uint8) * 255
    colors = [
        (255, 0, 0),  # red
        (0, 0, 255),  # blue
        (255, 255, 0),  # yellow
        (255, 255, 255),  # white
        (255, 165, 0),  # orange
    ]

    def fill_region(x, y, w, h):
        fill_color = random.choice(colors)
        img[y : y + h, x : x + w] = fill_color
        img[y : y + border_thickness, x : x + w] = 255
        img[y + h - border_thickness : y + h, x : x + w] = 255
        img[y : y + h, x : x + border_thickness] = 255
        img[y : y + h, x + w - border_thickness : x + w] = 255

    def split_region(x, y, w, h, depth=0):
        if w < 50 or h < 50:
            fill_region(x, y, w, h)
            return
        if depth >= max_depth:
            fill_region(x, y, w, h)
            return
        if depth >= min_depth and random.random() > split_prob:
            fill_region(x, y, w, h)
            return

        if w > h:
            split_x = random.randint(x + int(0.3 * w), x + int(0.7 * w))
            white_start = split_x - border_thickness // 2
            white_end = split_x + border_thickness // 2
            img[y : y + h, white_start:white_end] = 255
            bl_start = split_x - black_line_thickness // 2
            bl_end = split_x + black_line_thickness // 2 + (black_line_thickness % 2)
            img[y : y + h, bl_start:bl_end] = (75, 75, 75)
            left_width = white_start - x
            right_width = (x + w) - white_end
            split_region(x, y, left_width, h, depth + 1)
            split_region(white_end, y, right_width, h, depth + 1)
        else:
            split_y = random.randint(y + int(0.3 * h), y + int(0.7 * h))
            white_start = split_y - border_thickness // 2
            white_end = split_y + border_thickness // 2
            img[white_start:white_end, x : x + w] = 255
            bl_start = split_y - black_line_thickness // 2
            bl_end = split_y + black_line_thickness // 2 + (black_line_thickness % 2)
            img[bl_start:bl_end, x : x + w] = (75, 75, 75)
            top_height = white_start - y
            bottom_height = (y + h) - white_end
            split_region(x, y, w, top_height, depth + 1)
            split_region(x, white_end, w, bottom_height, depth + 1)

    split_region(0, 0, width, height)
    img[0:overall_border_thickness, :] = (75, 75, 75)
    img[-overall_border_thickness:, :] = (75, 75, 75)
    img[:, 0:overall_border_thickness] = (75, 75, 75)
    img[:, -overall_border_thickness:] = (75, 75, 75)

    return img

This function creates Mondrian-like images, such as the following examples.

Show the code
import matplotlib.pyplot as plt

# Generate a 2x2 grid of Mondrian images
fig, axs = plt.subplots(2, 2, figsize=(8, 8))
for i in range(2):
    for j in range(2):
        img = generate_mondrian(256, 256)
        axs[i, j].imshow(img)
        axs[i, j].axis("off")
plt.show()

The dataset generator

To train our VAE model, we need a dataset of Mondrian images with masked regions. We will create a PyTorch dataset class that generates images and applies a random square mask to each one. The dataset will return the original image, the masked image, and the mask itself as the training samples. Note that we normalize the pixel values to the range \([0, 1]\) to facilitate training, as neural networks typically perform better with inputs in this range.

The model will use the Dataset class to load the training data in batches, with inputs passed to it as needed. Notice that the masked image is created by setting the pixel values in the masked region to zero, effectively removing that part of the image (this is why the image generator uses a (75, 75, 75) grey for the border and lines, so the model doesn’t confuse the mask with black areas of the original).

Show the code
import torch
from torch.utils.data import Dataset, DataLoader


# Dataset generator
class MondrianDataset(Dataset):
    def __init__(self, num_samples=1000, width=256, height=256, mask_size=64):
        self.num_samples = num_samples
        self.width = width
        self.height = height
        self.mask_size = mask_size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Generate a Mondrian image and normalize to [0,1]
        img = generate_mondrian(self.width, self.height)
        img = img.astype(np.float32) / 255.0

        # Create a random square mask
        x = random.randint(0, self.width - self.mask_size)
        y = random.randint(0, self.height - self.mask_size)
        mask = np.zeros((self.height, self.width, 1), dtype=np.float32)
        mask[y : y + self.mask_size, x : x + self.mask_size] = 1.0

        # Create masked image: set the masked region to 0
        masked_img = np.copy(img)
        masked_img[y : y + self.mask_size, x : x + self.mask_size, :] = 0.0

        # Convert HWC to CHW tensors
        masked_img = torch.from_numpy(masked_img).permute(2, 0, 1)
        img = torch.from_numpy(img).permute(2, 0, 1)
        mask = torch.from_numpy(mask).permute(2, 0, 1)
        return masked_img, mask, img
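
Before moving on, here is a minimal sketch of how the dataset plugs into a PyTorch DataLoader and what shapes a batch carries. The small sample count and the name demo_loader are just for illustration.

Show the code
# Hypothetical smoke test: fetch one batch and inspect its shapes
demo_loader = DataLoader(MondrianDataset(num_samples=4), batch_size=2)
masked_img, mask, img = next(iter(demo_loader))
print(masked_img.shape)  # torch.Size([2, 3, 256, 256])
print(mask.shape)        # torch.Size([2, 1, 256, 256])
print(img.shape)         # torch.Size([2, 3, 256, 256])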

The VAE model

With the dataset and generator in place, let us look at the architecture of the Variational Autoencoder model. Previously we mentioned the encoder and decoder; we will organise our model by separating these two functions clearly, and then bring them together into a single model class.

We will not go too deep into an explanation of how a VAE works, but will provide a high-level overview of the model’s components so that readers can understand the overall architecture without getting bogged down in the details or the math.

Encoding

The encoder is responsible for transforming the input image into a latent representation. In our VAE model, the encoder consists of several convolutional layers that downsample the input image, capturing its features at different levels of abstraction. The encoder outputs the mean (mu) and log variance (logvar) of the latent distribution, which are used to sample the latent vector during training.

Note

For the mathematically inclined, mu (\(\mu\)) and logvar are the parameters of a Gaussian distribution that approximates the true posterior distribution of the latent space. The encoder learns to map the input image to the parameters of this distribution, which allows the model to sample latent vectors during training. The Gaussian distribution is chosen for its simplicity and differentiability, which makes it easier to train the model using backpropagation.

The encoder outputs two vectors, one for the mean \(\mu\) and one for the log-variance \(\log \sigma^2\), typically computed as:

\[ \mu = W_{\mu} \cdot x + b_{\mu} \]

\[ \log \sigma^2 = W_{\log \sigma^2} \cdot x + b_{\log \sigma^2} \]

Here, \(x\) represents the features extracted from the input by earlier layers, and \(W\) and \(b\) are learned parameters.

Once you have \(\mu\) and \(\log \sigma^2\), you can obtain the standard deviation by taking:

\[ \sigma = \exp\left(\frac{1}{2} \log \sigma^2\right) \]

This formulation is critical for the reparameterization trick, which allows for backpropagation through the sampling process. Specifically, a latent vector \(z\) is sampled as:

\[ z = \mu + \sigma \odot \epsilon \]

with \(\epsilon \sim \mathcal{N}(0, I)\).

This approach ensures that the sampling is differentiable, making the VAE training stable.
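
In code, the trick is a one-liner. Here is a minimal, standalone sketch with PyTorch, using hypothetical placeholder values for \(\mu\) and \(\log \sigma^2\):

Show the code
import torch

# Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
mu = torch.zeros(1, 4)           # mean predicted by the encoder
logvar = torch.zeros(1, 4)       # log-variance predicted by the encoder
sigma = torch.exp(0.5 * logvar)  # sigma = exp(0.5 * log sigma^2)
eps = torch.randn_like(sigma)    # random noise, outside the gradient path
z = mu + sigma * eps             # differentiable w.r.t. mu and logvar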

Encoding in practice captures parts of the image that are important for reconstruction, such as edges, textures and shapes. The encoder’s output is a compressed representation of the input image that can be used to reconstruct the original image or generate new samples.

Show the code
from torch import nn
import torch.nn.functional as F


# U-NET style VAE encoder
class Encoder(nn.Module):
    """Downsampling encoder that captures intermediate features for skip connections."""

    def __init__(self, latent_dim=128):
        super(Encoder, self).__init__()
        self.enc1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),  # 256 -> 128
            nn.ReLU(),
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 128 -> 64
            nn.ReLU(),
        )
        self.enc3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 64 -> 32
            nn.ReLU(),
        )
        self.enc4 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # 32 -> 16
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(256 * 16 * 16, latent_dim)
        self.fc_logvar = nn.Linear(256 * 16 * 16, latent_dim)

    def forward(self, x):
        f1 = self.enc1(x)  # [B, 32, 128, 128]
        f2 = self.enc2(f1)  # [B, 64, 64, 64]
        f3 = self.enc3(f2)  # [B, 128, 32, 32]
        f4 = self.enc4(f3)  # [B, 256, 16, 16]
        flat = f4.view(f4.size(0), -1)
        mu = self.fc_mu(flat)
        logvar = self.fc_logvar(flat)
        return f1, f2, f3, f4, mu, logvar

You might be tempted to think of the VAE encoder as similar to a hash function (such as SHA-256), but there are key differences. A hash function is a one-way transformation that maps any input to a fixed-size output. It’s designed for tasks like data integrity checks, where even a tiny change in the input leads to a completely different hash. In contrast, the encoder in a VAE is a learnable function that compresses input data into a latent space. This latent representation retains the core features of the original data so that it can later be used to reconstruct, or even generate new, similar data.

The VAE encoder doesn’t produce a single deterministic output like a hash function does. Instead, it outputs parameters: the mean and variance of a probability distribution over the latent space. This allows for controlled randomness, enabling smooth transitions and meaningful variations when sampling from the latent space. While both methods reduce the dimensionality of data, the VAE encoder is built to preserve the underlying structure and semantics necessary for generating or reconstructing data, rather than just providing a unique fingerprint of the input, like a hash function does.
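
To make the contrast concrete, here is an illustrative sketch. It uses an untrained Encoder instance, so the latents carry no learned meaning yet; the point is only the differing behaviour of the two mappings.

Show the code
import hashlib

x = torch.rand(1, 3, 256, 256)
x_nearby = x + 0.001 * torch.randn_like(x)  # a tiny perturbation

# Hashes of nearby inputs are completely unrelated
print(hashlib.sha256(x.numpy().tobytes()).hexdigest()[:16])
print(hashlib.sha256(x_nearby.numpy().tobytes()).hexdigest()[:16])

# Latent means of nearby inputs stay close
demo_encoder = Encoder(latent_dim=128)
with torch.no_grad():
    _, _, _, _, mu_a, _ = demo_encoder(x)
    _, _, _, _, mu_b, _ = demo_encoder(x_nearby)
print(torch.dist(mu_a, mu_b))  # a small distance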

Decoding

The VAE decoder essentially reverses the encoder’s transformation. It takes the latent vector, sampled from a distribution defined by the encoder’s mean and variance, and maps it back to the original data space. In the case of images, the decoder is usually composed of a series of transposed convolutional (or deconvolutional) layers that gradually upsample the latent representation. This process reconstructs the image by piecing together the key features, such as edges, textures, and colors, that the encoder originally captured in its compressed form.

Show the code
# VAE decoder
class Decoder(nn.Module):
    """Upsampling decoder that uses skip connections from the encoder."""

    def __init__(self, latent_dim=128):
        super(Decoder, self).__init__()
        self.fc_dec = nn.Linear(latent_dim, 256 * 16 * 16)

        # Up 1: f4 -> (B,256,16,16) -> upsample -> (B,256,32,32) + skip f3 -> conv -> (B,128,32,32)
        self.up4 = nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Sequential(
            nn.Conv2d(256 + 128, 128, kernel_size=3, padding=1), nn.ReLU()
        )

        # Up 2: (B,128,32,32) -> upsample -> (B,128,64,64) + skip f2 -> conv -> (B,64,64,64)
        self.up3 = nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Sequential(
            nn.Conv2d(128 + 64, 64, kernel_size=3, padding=1), nn.ReLU()
        )

        # Up 3: (B,64,64,64) -> upsample -> (B,64,128,128) + skip f1 -> conv -> (B,32,128,128)
        self.up2 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Sequential(
            nn.Conv2d(64 + 32, 32, kernel_size=3, padding=1), nn.ReLU()
        )

        # Up 4: (B,32,128,128) -> upsample -> (B,32,256,256) -> final -> (B,3,256,256)
        self.up1 = nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.Sequential(
            nn.Conv2d(32, 3, kernel_size=3, padding=1), nn.Sigmoid()
        )

    def forward(self, z, f1, f2, f3, f4):
        # Expand latent to spatial
        x = self.fc_dec(z).view(-1, 256, 16, 16)

        # Up 1 (skip f3)
        x = self.up4(x)  # -> [B,256,32,32]
        x = torch.cat([x, f3], dim=1)  # -> [B,256+128=384,32,32]
        x = self.conv4(x)  # -> [B,128,32,32]

        # Up 2 (skip f2)
        x = self.up3(x)  # -> [B,128,64,64]
        x = torch.cat([x, f2], dim=1)  # -> [B,128+64=192,64,64]
        x = self.conv3(x)  # -> [B,64,64,64]

        # Up 3 (skip f1)
        x = self.up2(x)  # -> [B,64,128,128]
        x = torch.cat([x, f1], dim=1)  # -> [B,64+32=96,128,128]
        x = self.conv2(x)  # -> [B,32,128,128]

        # Up 4 (no skip)
        x = self.up1(x)  # -> [B,32,256,256]
        x = self.conv1(x)  # -> [B,3,256,256]
        return x

During training, the decoder learns to generate images that are as close as possible to the original inputs, guided by a reconstruction loss. This means that the decoder doesn’t simply replicate the input image but instead creates a version that preserves the most important details. The smooth, continuous nature of the latent space ensures that small changes in the latent vector result in gradual, meaningful variations in the output. As a result, once trained, the decoder is not only capable of accurately reconstructing inputs but also of generating entirely new samples that share similar characteristics with the training data.

In this case, we have used a technique called the U-Net architecture, which enhances the basic encoder-decoder design with skip connections. These connections allow the model to carry over fine-grained spatial details from the encoder directly to the corresponding layers in the decoder. This means that while the encoder compresses the image into a latent space, the U-Net helps preserve important features, ensuring that the reconstructed image retains higher fidelity to the original.

By incorporating these skip connections, we improve the quality of our reconstructions significantly: the architecture captures the global structure through the bottleneck (the latent space) while reintroducing local details along the way.

[Diagram: the same encoder-latent-decoder pipeline as above, now with skip connections carrying features from the encoder layers directly to the corresponding decoder layers.]

The model

With the encoder and decoder in place, we can now define the full VAE model. The model combines the encoder and decoder components, along with a reparametrization function that samples from the latent distribution defined by the encoder’s output. The model’s forward pass takes the input image, encodes it into the latent space, samples a latent vector, and then decodes it back into the image space.

Note the forward pass returns the reconstructed image (recon), the mean (mu), and the log variance (logvar) of the latent distribution. The mean and log variance are used to compute the Kullback-Leibler (KL) divergence loss, which helps regularize the latent space during training. The KL divergence ensures that the latent distribution remains close to a standard normal distribution, which aids in generating realistic samples and controlling the model’s capacity.
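
For reference, the KL divergence between the learned diagonal Gaussian and a standard normal has a simple closed form, which is the expression the loss function later in this post computes:

\[ D_{\mathrm{KL}}\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, I)\right) = -\frac{1}{2} \sum_{j} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right) \]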

Show the code
# The VAE model
class VAE_UNet(nn.Module):
    """U-Net style VAE that returns reconstruction, mu, logvar."""

    def __init__(self, latent_dim=128):
        super(VAE_UNet, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        f1, f2, f3, f4, mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decoder(z, f1, f2, f3, f4)
        return recon, mu, logvar

Try not to get too lost in the mathematics of the model. It is not an easy topic to grapple with; for now, what matters is that you understand the high-level architecture of the model, the main parameters involved, and how the encoder and decoder work together to learn a compressed representation of the input data. If you are curious, you can read the original paper by Kingma and Welling that introduced the VAE concept, or a slightly gentler introduction.
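
As a quick sanity check, here is a minimal sketch that runs one forward pass on the CPU with a random batch (demo_model is just an illustrative name; no training happens here):

Show the code
# Hypothetical smoke test: one forward pass with random data
demo_model = VAE_UNet(latent_dim=128)
x = torch.rand(2, 3, 256, 256)  # a random batch of two "images"
recon, mu, logvar = demo_model(x)
print(recon.shape)   # torch.Size([2, 3, 256, 256])
print(mu.shape)      # torch.Size([2, 128])
print(logvar.shape)  # torch.Size([2, 128])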

Annealing the KL divergence

During training, the VAE model minimizes a loss function that consists of two components: a reconstruction loss and a Kullback-Leibler (KL) divergence loss. The reconstruction loss measures the difference between the input and the reconstructed output, while the KL divergence loss ensures that the latent distribution remains close to a standard normal distribution.

The KL divergence loss is weighted by a parameter kl_weight, which controls the importance of the KL divergence term during training. An annealing schedule is often used to gradually increase the KL weight over the course of training. This helps the model first focus on learning a good reconstruction, before enforcing a more structured latent space.

Here we define a simple linear annealing function that scales the KL weight from \(0\) to \(1\) between a start and end epoch. This function will be used during training to adjust the KL weight over time.

Show the code
# The KL annealing function
def kl_anneal_function(epoch, start_epoch=0, end_epoch=10):
    """
    Linearly scales KL weight from 0.0 to 1.0 between start_epoch and end_epoch.
    """
    if epoch < start_epoch:
        return 0.0
    elif epoch > end_epoch:
        return 1.0
    else:
        return (epoch - start_epoch) / (end_epoch - start_epoch)

Here’s what the KL annealing schedule looks like over the course of training for 100 epochs. The KL weight starts at \(0.0\) and gradually increases to \(1.0\) between epochs 0 and 50.

Show the code
# Plot the KL annealing schedule

epochs = 100
kl_weights = [kl_anneal_function(epoch, 0, epochs // 2) for epoch in range(epochs)]
plt.figure(figsize=(8, 4))
plt.plot(range(epochs), kl_weights, marker="o")
plt.xlabel("Epoch")
plt.ylabel("KL Weight")
plt.title("KL Annealing Schedule")
plt.grid(True)
plt.show()

The loss function

The loss function of the VAE, as we have touched upon before, is composed of two main components: the reconstruction loss (recon_loss) and the KL divergence term (KL_loss). The reconstruction loss measures how close the reconstructed output is to the original input, but in our function it is computed only over the masked region, using Mean Squared Error (MSE). This ensures that the model focuses on accurately recreating the parts of the image that matter most.

By balancing these two components, the overall loss ensures that the model not only produces high-quality reconstructions but also learns a well-structured latent space. The kl_weight parameter lets you adjust the emphasis on the KL divergence term relative to the reconstruction loss. A higher kl_weight will force the latent space to be more closely aligned with a normal distribution, potentially at the expense of reconstruction accuracy, whereas a lower weight will prioritize accurate reconstructions. The annealing scheduler we defined earlier lets the model focus first on accurate reconstructions, gradually shifting the emphasis towards structuring the latent space as training progresses.

In this function, recon_x is the reconstructed output, x is the original input, mu is the mean of the latent distribution, logvar is the log variance (all of which are computed in the forward pass), and mask is the binary mask indicating the missing region.

Show the code
def loss_function(recon_x, x, mu, logvar, mask, kl_weight):
    # MSE only over the masked region
    recon_loss = nn.functional.mse_loss(recon_x * mask, x * mask, reduction="sum")
    # KL divergence
    KL_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_weight * KL_loss
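
An illustrative call with random tensors, just to show the interface. The shapes match our dataset, and with \(\mu = 0\) and \(\log \sigma^2 = 0\) the KL term is exactly zero:

Show the code
# Hypothetical usage of the loss function with random tensors
recon_x = torch.rand(2, 3, 256, 256)
x = torch.rand(2, 3, 256, 256)
mu, logvar = torch.zeros(2, 128), torch.zeros(2, 128)
mask = torch.zeros(2, 1, 256, 256)
mask[:, :, 96:160, 96:160] = 1.0  # a 64x64 square mask
print(loss_function(recon_x, x, mu, logvar, mask, kl_weight=0.5))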

The training loop

We now have all the components needed to train our VAE model for the inpainting task. The training loop consists of the following steps:

  1. Iterate over the dataset in batches.
  2. Compute the forward pass through the model to get the reconstructed output, mean, and log variance.
  3. Compute the loss function using the reconstructed output, original input, mean, log variance, and mask.
  4. Backpropagate the gradients and update the model parameters.
  5. Periodically run an inference step to visualize the inpainting results.
Show the code
from tqdm import tqdm


# A training loop with periodic inference
def train_vae_unet(model, dataloader, optimizer, device, epochs=20, inferences=10):
    model.train()
    interval = max(1, epochs // inferences)
    losses = []
    for epoch in range(epochs):
        kl_weight = kl_anneal_function(epoch, 0, epochs // 2)
        total_loss = 0
        progress = tqdm(dataloader, desc=f"Epoch {epoch+1}", leave=False)
        for masked_img, mask, img in progress:
            masked_img, mask, img = (
                masked_img.to(device),
                mask.to(device),
                img.to(device),
            )
            optimizer.zero_grad()
            recon, mu, logvar = model(masked_img)
            loss = loss_function(recon, img, mu, logvar, mask, kl_weight)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            progress.set_postfix(
                loss=f"{loss.item():.4f}", KL_Weight=f"{kl_weight:.2f}"
            )
        avg_loss = total_loss / len(dataloader.dataset)
        losses.append(avg_loss)

        if (epoch + 1) % interval == 0 or epoch == 0:
            model.eval()  # Switch to evaluation mode
            inference(model, device, epoch)
            model.train()  # Switch back to training mode

    return losses

Inference

The training loop above calls an inference function at regular intervals to visualize the inpainting results, so we can monitor how well the model is performing. The inference function generates a new Mondrian image, applies a random square mask to it, and inpaints the missing region using the model. It then plots the original image, the masked image, the inpainted image, and the reconstructed patch. Note that the model is given a whole new image to inpaint, not one from the training set.

Show the code
# Run an inference loop
def inference(model, device, epoch=0):
    model.eval()
    width, height, mask_size = 256, 256, 64
    # Generate a new Mondrian image and mask it
    img = generate_mondrian(width, height)
    img = img.astype(np.float32) / 255.0
    x = random.randint(0, width - mask_size)
    y = random.randint(0, height - mask_size)
    mask = np.zeros((height, width, 1), dtype=np.float32)
    mask[y : y + mask_size, x : x + mask_size] = 1.0
    masked_img = np.copy(img)
    masked_img[y : y + mask_size, x : x + mask_size, :] = 0.0

    # Convert to tensor
    masked_tensor = (
        torch.from_numpy(masked_img).permute(2, 0, 1).unsqueeze(0).to(device)
    )
    with torch.no_grad():
        recon, _, _ = model(masked_tensor)
    recon = recon.squeeze(0).permute(1, 2, 0).cpu().numpy()

    # Extract the reconstructed patch
    patch_recon = recon[y : y + mask_size, x : x + mask_size, :]

    # Combine the reconstructed patch into the masked image
    inpainted = np.copy(masked_img)
    inpainted[y : y + mask_size, x : x + mask_size, :] = patch_recon

    # Compute the MSE loss between the original and inpainted regions
    mse_loss = np.mean(
        (
            img[y : y + mask_size, x : x + mask_size]
            - inpainted[y : y + mask_size, x : x + mask_size]
        )
        ** 2
    )

    # Plot
    fig, axs = plt.subplots(1, 4, figsize=(8, 3))
    fig.suptitle(
        f"Epoch: {epoch}, MSE Loss: {mse_loss}", x=0.0, ha="left", fontsize=14
    )  # Left-aligned title
    axs[0].imshow(img)
    axs[0].set_title("Original Image")
    axs[0].axis("off")

    axs[1].imshow(masked_img)
    axs[1].set_title("Masked Image")
    axs[1].axis("off")

    axs[2].imshow(inpainted)
    axs[2].set_title("Inpainted Image")
    axs[2].axis("off")

    axs[3].imshow(patch_recon)
    axs[3].set_title("Reconstructed Patch")
    axs[3].axis("off")

    plt.show()

Putting it all together

We now have all the necessary pieces and can put them together to train our VAE model on the Mondrian dataset. We will use the Adam optimiser to update the model parameters during training. Training will run for 150 epochs, with periodic inference steps to visualize the inpainting results. We will also plot the training loss over time to monitor the model’s progress.

Show the code
from torch import optim

# Initialize random seeds for reproducibility
random.seed(123)
torch.manual_seed(123)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Create dataset & dataloader
dataset = MondrianDataset(num_samples=10000)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

# Create U-Net style VAE model
model = VAE_UNet(latent_dim=128).to(device)

# The Adam optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Train (with periodic inference)
max_epoch = 150
losses = train_vae_unet(
    model, dataloader, optimizer, device, epochs=max_epoch, inferences=5
)

Notice how the model learns and generalizes across epochs. By the end of epoch 0, it can only very roughly interpret colours and colour boundaries. By epoch 59 it has improved significantly: it can now draw boundary lines, and it is confident with boundaries and colours. At epoch 89 the boundary lines are nearly perfect, and by epoch 149 it has pretty much mastered the task.

Epoch 119 threw the model a curveball with an edge case: the mask fell on a corner, and the model clearly wasn’t very confident in its inpainting. This is a good example of how it is learning to generalize, but is not yet perfect at this stage.

Finally, let us run one last inference at the end of training.

Show the code
inference(model, device, epoch=max_epoch)

By now we can see that the VAE model has learned to inpaint the missing regions quite well. The reconstructed patch is very close to the original, including edge cases such as where border lines meet. The model has learned to capture the structure and colors of the Mondrian images.

Finally, let us plot the training loss over time to see how the model’s performance improved during training.

Show the code
# Plot losses
plt.figure(figsize=(8, 4))
plt.plot(losses)
plt.title("Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True)
plt.show()

Producing an entirely new image

We discussed before that the VAE produces a compressed representation of the input data in the latent space. This latent representation can be sampled to generate new images that are similar to the training data. By sampling from the latent space and passing the resulting vector through the decoder, we can create new Mondrian-style images that were not part of the training set.

Let us write a function which does precisely this. It will sample a latent vector from a standard normal distribution, pass it through the decoder, and return the generated image.

Show the code
# Generate a new Mondrian-style image
def generate_synthetic_mondrian(model, device):
    model.eval()
    with torch.no_grad():
        latent_dim = 128  # Adjust if your latent dimension differs
        # Sample a latent vector from a standard normal distribution
        z = torch.randn(1, latent_dim, device=device)
        # Create a dummy input (e.g., a tensor of zeros) with the same shape as a real input image
        dummy_input = torch.zeros(1, 3, 256, 256, device=device)
        # Obtain skip connections from the encoder using the dummy input
        f1, f2, f3, f4, _, _ = model.encoder(dummy_input)
        # Generate the image using the decoder with the sampled latent vector and the dummy skip connections
        img = model.decoder(z, f1, f2, f3, f4)
    # Rearrange from [C, H, W] to [H, W, C] for visualization and convert to NumPy
    img = img.squeeze(0).permute(1, 2, 0).cpu().numpy()
    return img

The function generates a random latent vector \(z\) from a standard normal distribution and passes it through the decoder. The output should be a new, synthetic Mondrian style image that the model has learned to generate based on the training data.

Show the code
# Plot a 2x2 grid of synthetic Mondrian images
fig, axs = plt.subplots(2, 2, figsize=(8, 8))
for i in range(2):
    for j in range(2):
        img = generate_synthetic_mondrian(model, device)
        axs[i, j].imshow(img)
        axs[i, j].axis("off")
plt.show()

If the model had been trained with a more complex dataset, such as photos of birds, trains or cars, the generated images would reflect the characteristics of that dataset. The VAE model learns to capture the underlying structure and patterns of the training data, allowing it to generate new samples that share similar features.

Notice how the synthetic images generated by the model lack detail. This is because we used a dummy input, which leads to a set of skip connections that are not based on any real input data. In practice, you would use a real input image to extract the skip connections, which would provide more meaningful features for generating new images.
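
As a sketch of that alternative, here is a hypothetical helper (generate_with_real_skips is not part of the original code) that samples a fresh latent vector but borrows skip connections from a real Mondrian image:

Show the code
# Hypothetical variant: decode a new latent vector using skip
# connections extracted from a real Mondrian image
def generate_with_real_skips(model, device, latent_dim=128):
    model.eval()
    real = generate_mondrian(256, 256).astype(np.float32) / 255.0
    real = torch.from_numpy(real).permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        f1, f2, f3, f4, _, _ = model.encoder(real)
        z = torch.randn(1, latent_dim, device=device)  # fresh latent sample
        img = model.decoder(z, f1, f2, f3, f4)
    return img.squeeze(0).permute(1, 2, 0).cpu().numpy()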

Final remarks

We have explored the concept of Variational Autoencoders (VAEs) and how they can be used for image inpainting. A VAE is a type of generative model that learns a compressed representation of the input data, which can be used to generate new samples or reconstruct the original data. There are other generative models, such as Generative Adversarial Networks (GANs), which adopt a fundamentally different training paradigm. Unlike VAEs, which explicitly model the latent space by learning a probabilistic distribution, GANs consist of two networks, a generator and a discriminator, that are trained in an adversarial framework. The generator’s goal is to produce realistic images, while the discriminator’s task is to distinguish between real and generated images.

One major advantage of GANs is their ability to generate sharp, high-quality images. However, they do not naturally provide an interpretable latent space, which can be a limitation for tasks like image inpainting where controlling specific aspects of the generated content is beneficial. GANs can be more challenging to train due to issues such as mode collapse and unstable training dynamics. To address these challenges, researchers have explored hybrid approaches like VAE-GANs. These models combine the structured latent space of VAEs with the adversarial loss of GANs, aiming to achieve both meaningful representations and high-quality image generation.

We have implemented a VAE model with a U-Net architecture and trained it on a dataset of Mondrian images with masked regions. The model learned to inpaint the missing regions by reconstructing the original image from the compressed latent space representation.

As an exercise, we could further train the model on a more complex dataset, for example CelebA. We will leave that for another time.

Reuse

This work is licensed under CC BY.