Deblurring, a Classic Machine Learning Problem

Train a variational autoencoder to deblur images
Experiments
Machine Learning
Published

March 20, 2025

Blade Runner came out in 1982 and is a classic science fiction movie directed by Ridley Scott. One of the iconic scenes in the movie is when the protagonist, Deckard, uses a computer to “enhance” a photograph to reveal hidden details. This scene has become a meme and a reference in popular culture.

In this experiment, we will train a Variational Autoencoder (VAE) to deblur images as a tribute to the “enhance” effect from Blade Runner, where we take a blurry image and reconstruct it to reveal hidden details. We will use the CelebA dataset, which contains images of celebrities, and train the VAE to deblur these images.

This is a continuation of the Mondrian VAE experiment, where we trained a VAE to reconstruct masked images in the style of Piet Mondrian. The VAE architecture will be similar, but we will focus on deblurring instead of reconstructing portions of the image.

Loading the dataset

We will start by creating a class that can load and return samples from CelebA. The dataset can be filtered based on attributes, and we can specify the number of samples to use. CelebA is composed of low-resolution images (218x178) of varying quality, taken in many different settings.

import os
from PIL import Image
from torch.utils.data import Dataset
import random


class CelebADataset(Dataset):
    def __init__(self, root_dir, attr_file, transform=None, filters=None, samples=1000):
        """
        Args:
            root_dir (str): Directory with all the images.
            attr_file (str): Path to the attribute file (list_attr_celeba.txt).
            transform (callable, optional): Optional transform to be applied on an image.
            filters (dict, optional): Dictionary where key is an attribute (e.g., 'Male')
                                      and value is the desired label (1 or -1).
            samples (int, optional): Number of images to use from the filtered dataset.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.filters = filters or {}

        # Read the attribute file
        with open(attr_file, "r") as f:
            lines = f.readlines()

        # The second line contains the attribute names
        attr_names = lines[1].strip().split()

        # Collect all matching samples first
        all_samples = []
        for line in lines[2:]:
            parts = line.strip().split()
            filename = parts[0]
            attributes = list(map(int, parts[1:]))
            attr_dict = dict(zip(attr_names, attributes))

            if all(attr_dict.get(attr) == val for attr, val in self.filters.items()):
                all_samples.append((filename, attr_dict))

        # Shuffle and select a random subset
        random.shuffle(all_samples)
        self.samples = all_samples[:samples]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        filename, attr_dict = self.samples[idx]
        img_path = os.path.join(self.root_dir, filename)
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, attr_dict

Let us show a few random images from the dataset to get an idea of their quality and diversity. Note how we can pass a dictionary of filters to the dataset to select only images with specific attributes (for example, 'Male'=1, 'Goatee'=-1 selects male celebrities without a goatee).

In this case, we are not using any filters, so we will get a random selection of images.

from torchvision import transforms

transform = transforms.Compose([transforms.ToTensor()])

filters = {}

# Instantiate the dataset
dataset = CelebADataset(
    root_dir=f'{os.environ["DATASET"]}/img_align_celeba',
    attr_file=f'{os.environ["DATASET"]}/list_attr_celeba.txt',
    transform=transform,
    filters=filters,
)
import random
import matplotlib.pyplot as plt

# Display a few random images
fig, axs = plt.subplots(2, 2, figsize=(6, 6))
for i, ax in enumerate(axs.flat):
    idx = random.randint(0, len(dataset) - 1)
    image, attributes = dataset[idx]
    ax.imshow(image.permute(1, 2, 0))
    ax.axis("off")
plt.show()

VAE model architecture

As explained in the Mondrian VAE experiment, the VAE architecture consists of an encoder and a decoder. The encoder downsamples the input image into a latent representation, and the decoder upsamples this latent representation to reconstruct the original image. Just as before, we will use skip connections between the encoder and decoder to improve the reconstruction quality (a U-Net style architecture). This approach still makes sense for deblurring since the input and output images are structurally identical: we are not changing content, just removing degradation. The skip connections carry low-level features (edges, contours, textures) from the encoder directly to the decoder, bypassing the bottleneck, which helps reconstruct sharp details that might otherwise be lost.

The decoder doesn’t need to learn how to recreate fine structure from scratch; it can simply reuse it, correcting for the blur. This leads to faster convergence, better visual quality, and fewer artifacts.

The only adaptations needed are supporting a different resolution (218x178 rather than 256x256) and adding a blur transformation to the training loop. We will apply a random Gaussian blur to each image in the batch before feeding it to the model, simulating the blurry input we want to restore. A quick preview of this degradation is shown below.
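
As a minimal illustration (reusing the dataset instantiated above), the following sketch applies a Gaussian blur with a randomly chosen odd kernel size and sigma to a single sample, which is exactly the kind of input the model will see during training:

# Preview the random Gaussian blur used during training on one sample
k = random.choice(range(5, 16, 2))  # odd kernel size between 5 and 15
s = random.uniform(1.5, 3.0)        # blur strength (sigma)
blur = transforms.GaussianBlur(kernel_size=k, sigma=s)

image, _ = dataset[0]
blurred = blur(image)

fig, axs = plt.subplots(1, 2, figsize=(6, 3))
axs[0].imshow(image.permute(1, 2, 0))
axs[0].set_title("Original")
axs[0].axis("off")
axs[1].imshow(blurred.permute(1, 2, 0))
axs[1].set_title("Blurred")
axs[1].axis("off")
plt.show()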

import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Downsampling encoder that captures intermediate features for skip connections."""

    def __init__(self, latent_dim=128):
        super(Encoder, self).__init__()
        self.enc1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),  # 218x178 -> 109x89
            nn.ReLU(),
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 109x89 -> 54x44
            nn.ReLU(),
        )
        self.enc3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 54x44 -> 27x22
            nn.ReLU(),
        )
        self.enc4 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # 27x22 -> 13x11
            nn.ReLU(),
        )
        # Flattened dimension: 256 * 13 * 11 = 36608
        self.fc_mu = nn.Linear(256 * 13 * 11, latent_dim)
        self.fc_logvar = nn.Linear(256 * 13 * 11, latent_dim)

    def forward(self, x):
        f1 = self.enc1(x)  # [B, 32, 109, 89]
        f2 = self.enc2(f1)  # [B, 64, 54, 44]
        f3 = self.enc3(f2)  # [B, 128, 27, 22]
        f4 = self.enc4(f3)  # [B, 256, 13, 11]
        flat = f4.view(f4.size(0), -1)
        mu = self.fc_mu(flat)
        logvar = self.fc_logvar(flat)
        return f1, f2, f3, f4, mu, logvar


class Decoder(nn.Module):
    """Upsampling decoder that uses skip connections from the encoder."""

    def __init__(self, latent_dim=128):
        super(Decoder, self).__init__()
        # Expand latent vector to match encoder's last feature map shape (256 x 13 x 11)
        self.fc_dec = nn.Linear(latent_dim, 256 * 13 * 11)

        # Up 1: f4 -> (B,256,13,11) -> upsample -> (B,256,27,22) to match f3 dimensions.
        # Use output_padding=(1,0) so that:
        # Height: (13-1)*2 -2 +4 +1 = 27 and Width: (11-1)*2 -2 +4 +0 = 22.
        self.up4 = nn.ConvTranspose2d(
            256, 256, kernel_size=4, stride=2, padding=1, output_padding=(1, 0)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(256 + 128, 128, kernel_size=3, padding=1), nn.ReLU()
        )

        # Up 2: (B,128,27,22) -> upsample -> (B,128,54,44) to match f2.
        self.up3 = nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Sequential(
            nn.Conv2d(128 + 64, 64, kernel_size=3, padding=1), nn.ReLU()
        )

        # Up 3: (B,64,54,44) -> upsample -> (B,64,109,89) to match f1.
        # Set output_padding=(1,1) to get:
        # Height: (54-1)*2 -2 +4 +1 = 109 and Width: (44-1)*2 -2 +4 +1 = 89.
        self.up2 = nn.ConvTranspose2d(
            64, 64, kernel_size=4, stride=2, padding=1, output_padding=(1, 1)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64 + 32, 32, kernel_size=3, padding=1), nn.ReLU()
        )

        # Up 4: (B,32,109,89) -> upsample -> (B,32,218,178) -> final conv to 3 channels.
        self.up1 = nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.Sequential(
            nn.Conv2d(32, 3, kernel_size=3, padding=1), nn.Sigmoid()
        )

    def forward(self, z, f1, f2, f3, f4):
        # Expand latent vector to spatial feature map: [B,256,13,11]
        x = self.fc_dec(z).view(-1, 256, 13, 11)

        # Up 1 (with skip connection from f3)
        x = self.up4(x)  # -> [B,256,27,22]
        x = torch.cat([x, f3], dim=1)  # Concatenate with f3: [B,256+128,27,22]
        x = self.conv4(x)  # -> [B,128,27,22]

        # Up 2 (with skip connection from f2)
        x = self.up3(x)  # -> [B,128,54,44]
        x = torch.cat([x, f2], dim=1)  # -> [B,128+64,54,44]
        x = self.conv3(x)  # -> [B,64,54,44]

        # Up 3 (with skip connection from f1)
        x = self.up2(x)  # -> [B,64,109,89]
        x = torch.cat([x, f1], dim=1)  # -> [B,64+32,109,89]
        x = self.conv2(x)  # -> [B,32,109,89]

        # Up 4: final upsampling to original resolution
        x = self.up1(x)  # -> [B,32,218,178]
        x = self.conv1(x)  # -> [B,3,218,178]
        return x


# The VAE model
class VAE_UNet(nn.Module):
    """U-Net style VAE that returns reconstruction, mu, logvar."""

    def __init__(self, latent_dim=128):
        super(VAE_UNet, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        f1, f2, f3, f4, mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decoder(z, f1, f2, f3, f4)
        return recon, mu, logvar
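
For reference, the arithmetic in the decoder comments follows PyTorch’s output-size formula for ConvTranspose2d (with dilation 1):

\[
H_{out} = (H_{in} - 1) \cdot \text{stride} - 2 \cdot \text{padding} + \text{kernel\_size} + \text{output\_padding}
\]

For example, the first upsampling step gives a height of \((13 - 1) \cdot 2 - 2 + 4 + 1 = 27\), matching the spatial size of f3.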

The annealing and loss functions are mostly the same as in the Mondrian VAE experiment.

# The KL annealing function
def kl_anneal_function(epoch, start_epoch=0, end_epoch=10):
    """
    Linearly scales KL weight from 0.0 to 1.0 between start_epoch and end_epoch.
    """
    if epoch < start_epoch:
        return 0.0
    elif epoch > end_epoch:
        return 1.0
    else:
        return (epoch - start_epoch) / (end_epoch - start_epoch)
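
As a quick sanity check, we can plot the schedule this function produces over, say, 20 epochs (reusing the matplotlib import from earlier):

# Visualise the KL annealing schedule
epochs_range = list(range(20))
weights = [kl_anneal_function(e, start_epoch=0, end_epoch=10) for e in epochs_range]

plt.figure(figsize=(6, 3))
plt.plot(epochs_range, weights)
plt.xlabel("Epoch")
plt.ylabel("KL weight")
plt.title("Linear KL annealing")
plt.show()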

One difference is that we previously used the Mean Squared Error (MSE) loss for the reconstruction, whereas for the deblurring task we will use the L1 loss instead. L1 is less sensitive to outliers and tends to produce sharper images, which is desirable for deblurring. As before, the loss function includes a KL divergence term, which regularizes the latent space towards a standard normal distribution.
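
Written out, the objective implemented below is, with \(\beta\) the annealed KL weight and the KL term in its closed form for a diagonal Gaussian posterior:

\[
\mathcal{L} = \lVert \hat{x} - x \rVert_1 + \beta \cdot \left( -\frac{1}{2} \sum_{j} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right) \right)
\]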

Note

As an exercise, you might want to try a perceptual loss function, such as one based on VGG16 features, or LPIPS, to see if it improves the quality of the reconstructions. These loss functions are designed to capture perceptual similarity between images, which can be more effective than pixel-wise losses for tasks like deblurring.
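
As a starting point for that exercise, here is a minimal sketch of a VGG16 feature-based perceptual loss (not used in this experiment; the cut-off at features[:16] is an arbitrary choice, and torchvision’s multi-weight API is assumed):

import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms


class VGGPerceptualLoss(nn.Module):
    """Compares frozen VGG16 feature maps instead of raw pixels."""

    def __init__(self):
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        self.features = vgg.features[:16].eval()
        for p in self.features.parameters():
            p.requires_grad = False  # frozen feature extractor
        # VGG16 expects ImageNet-normalised inputs
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    def forward(self, recon_x, x):
        return F.l1_loss(
            self.features(self.normalize(recon_x)),
            self.features(self.normalize(x)),
        )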

import torch.nn.functional as F


def loss_function(recon_x, x, mu, logvar, kl_weight):
    # Reconstruction loss using L1 instead of MSE
    recon_loss = F.l1_loss(recon_x, x, reduction="sum")
    # KL divergence
    KL_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_weight * KL_loss

The training loop

The training loop is similar to our previous experiment, with the addition of a random blur applied to each image in the batch. We apply a different level of blur to each sample, with the kernel size and sigma randomly chosen from a range of values, to simulate varying degrees of degradation. Another difference from before is that we now measure both training and validation losses in the loop, as we want to ensure that the model generalizes well to unseen data.

from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from tqdm import tqdm


def train_vae_deblur(
    model, train_loader, val_loader, optimizer, device, epochs=20, inferences=10
):
    """
    Trains the model on the deblurring task with validation.
    Applies a different random blur per sample (batch-wise) and uses inference_deblur for visualisation.
    """
    # Create a SummaryWriter for TensorBoard logging
    writer = SummaryWriter(log_dir="/tmp/runs/deblur_experiment")

    model.train()
    interval = max(1, epochs // inferences)
    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        kl_weight = kl_anneal_function(epoch, 0, epochs // 2)
        total_train_loss = 0
        progress = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]", leave=False)

        for img, _ in progress:
            img = img.to(device)

            # Apply random blur to each sample in the batch
            blurred_batch = []
            for sample in img:
                k = random.choice(range(5, 16, 2))  # odd kernel size
                s = random.uniform(1.5, 3.0)  # sigma
                blur = transforms.GaussianBlur(kernel_size=k, sigma=s)
                blurred = blur(sample.unsqueeze(0))
                blurred_batch.append(blurred)

            blurred_img = torch.cat(blurred_batch, dim=0).to(device)

            optimizer.zero_grad()
            recon, mu, logvar = model(blurred_img)
            loss = loss_function(recon, img, mu, logvar, kl_weight)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            progress.set_postfix(
                loss=f"{loss.item():.4f}", KL_Weight=f"{kl_weight:.2f}"
            )

        avg_train_loss = total_train_loss / len(train_loader.dataset)
        train_losses.append(avg_train_loss)

        # Validation
        model.eval()
        total_val_loss = 0

        with torch.no_grad():
            for img, _ in val_loader:
                img = img.to(device)

                blurred_batch = []
                for sample in img:
                    k = random.choice(range(5, 16, 2))
                    s = random.uniform(1.5, 3.0)
                    blur = transforms.GaussianBlur(kernel_size=k, sigma=s)
                    blurred = blur(sample.unsqueeze(0))
                    blurred_batch.append(blurred)

                blurred_img = torch.cat(blurred_batch, dim=0).to(device)
                recon, mu, logvar = model(blurred_img)
                loss = loss_function(recon, img, mu, logvar, kl_weight)
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_loader.dataset)
        val_losses.append(avg_val_loss)

        # Log scalar values to TensorBoard
        writer.add_scalar("Loss/Train", avg_train_loss, epoch)
        writer.add_scalar("Loss/Val", avg_val_loss, epoch)
        writer.add_scalar("KL Weight", kl_weight, epoch)

        # Log images every 'interval' epochs (and at epoch 0)
        if (epoch + 1) % interval == 0 or epoch == 0:
            # Get one batch from the validation set for visual logging
            with torch.no_grad():
                for img, _ in val_loader:
                    img = img.to(device)
                    blurred_batch = []
                    for sample in img:
                        k = random.choice(range(5, 16, 2))
                        s = random.uniform(1.5, 3.0)
                        blur = transforms.GaussianBlur(kernel_size=k, sigma=s)
                        blurred = blur(sample.unsqueeze(0))
                        blurred_batch.append(blurred)
                    blurred_img = torch.cat(blurred_batch, dim=0).to(device)
                    recon, mu, logvar = model(blurred_img)
                    break  # Use the first batch

            # Create grids of images (normalize for visualization)
            original_grid = make_grid(img, normalize=True, scale_each=True)
            blurred_grid = make_grid(blurred_img, normalize=True, scale_each=True)
            recon_grid = make_grid(recon, normalize=True, scale_each=True)

            writer.add_image("Validation/Original", original_grid, epoch)
            writer.add_image("Validation/Blurred", blurred_grid, epoch)
            writer.add_image("Validation/Reconstructed", recon_grid, epoch)

            inference_deblur(model, device, val_loader, epoch)

        # Restore training mode for the next epoch
        # (validation switched the model to eval mode)
        model.train()

    writer.close()
    return train_losses, val_losses

The training loop uses the following function to perform inference on a single image from the dataset. It takes an image from the dataloader, applies a blur, reconstructs it, and then plots the original, blurred, and reconstructed images side by side. This function is useful to visualize the deblurring effect of the model during training so we can see how well the model is performing as training progresses.

import numpy as np


def inference_deblur(model, device, dataloader, epoch=0, blur_transform=None):
    """
    Performs inference on a random image from a random batch in the dataloader.
    It applies the blur, reconstructs it, computes the MSE, and then plots the original,
    blurred, and reconstructed images with the MSE in the title.
    """
    model.eval()
    # Get the total number of batches and choose one at random
    num_batches = len(dataloader)
    random_batch_index = random.randint(0, num_batches - 1)

    # Iterate through the dataloader until the random batch is reached
    for i, (img, _) in enumerate(dataloader):
        if i == random_batch_index:
            # Pick a random image from this batch
            random_image_index = random.randint(0, img.size(0) - 1)
            original = img[random_image_index].unsqueeze(0).to(device)
            break

    if blur_transform is None:
        blur_transform = transforms.GaussianBlur(kernel_size=9, sigma=2.0)
    blurred = blur_transform(original)

    with torch.no_grad():
        recon, _, _ = model(blurred)
        mse = torch.nn.functional.mse_loss(recon, original)  # Compute MSE

    # Convert tensors to NumPy arrays for plotting
    original_np = original.squeeze(0).permute(1, 2, 0).cpu().numpy()
    blurred_np = np.clip(blurred.squeeze(0).permute(1, 2, 0).cpu().numpy(), 0, 1)
    recon_np = np.clip(recon.squeeze(0).permute(1, 2, 0).cpu().numpy(), 0, 1)

    # Plot the original, blurred, and reconstructed images side by side
    fig, axs = plt.subplots(1, 3, figsize=(8, 4))
    fig.suptitle(f"Epoch: {epoch+1}, MSE: {mse.item():.4f}", fontsize=14)

    axs[0].imshow(original_np)
    axs[0].set_title("Original")
    axs[0].axis("off")

    axs[1].imshow(blurred_np)
    axs[1].set_title("Blurred")
    axs[1].axis("off")

    axs[2].imshow(recon_np)
    axs[2].set_title("Reconstructed")
    axs[2].axis("off")

    plt.show()

Finally, let us put it all together: instantiate the model, optimizer, and dataloaders, and train the model. You might have noticed that we are using the same learning rate and batch size as in the Mondrian VAE experiment. To understand the interplay between these two hyperparameters, you could experiment with different values and see how they affect the training dynamics and final results. For example, try a smaller learning rate to see if the model can learn more subtle details, paired with a smaller batch size to prevent the model from getting stuck in local minima.

Note

The choice of learning rate and batch size plays a critical role in the performance, stability, and convergence speed of the model. While these hyperparameters are often tuned experimentally, understanding their individual and combined impact can guide decisions during model development.

The learning rate determines how big a step the optimizer takes in the direction of the gradient at each iteration. A learning rate that is too high can cause the model to overshoot minima in the loss landscape, leading to divergence or oscillating loss. On the other hand, a learning rate that is too low can result in painfully slow training and may cause the model to get stuck in suboptimal solutions. Common practice involves starting with values like \(10^{-3}\) or \(10^{-4}\), then adapting with schedulers or learning rate warm-up strategies depending on the model and task complexity.

Batch size, which defines how many samples are processed before the model updates its weights, also affects training dynamics. Smaller batch sizes introduce more noise into the gradient estimate, which can act as a regularizer and potentially help generalisation, but may also lead to instability if the learning rate isn’t adjusted accordingly. Larger batch sizes, on the other hand, provide smoother and more accurate gradient estimates, often leading to faster convergence, but can risk poorer generalisation.

There’s also a strong interplay between batch size and learning rate. As a general rule, larger batch sizes can support proportionally larger learning rates; this is one of the ideas behind the neural scaling law. Conversely, smaller batches usually require a smaller learning rate to remain stable. When tuned together, these parameters have a significant impact on model performance, generalisation, and training dynamics.
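
To make this rule of thumb concrete, here is a purely illustrative sketch (not used in the training below) that combines linear learning-rate scaling with a short warm-up using PyTorch’s built-in LinearLR scheduler; the reference values are hypothetical:

import torch.optim as optim

base_lr, base_batch_size = 1e-4, 32  # reference configuration
new_batch_size = 128  # hypothetical larger batch
scaled_lr = base_lr * (new_batch_size / base_batch_size)  # -> 4e-4

example_model = VAE_UNet(latent_dim=128)
example_optimizer = optim.Adam(example_model.parameters(), lr=scaled_lr)

# Warm up from 10% of the scaled learning rate over the first 5 epochs
scheduler = optim.lr_scheduler.LinearLR(
    example_optimizer, start_factor=0.1, total_iters=5
)
# Inside the training loop, call scheduler.step() once per epoch.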

import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from torch.utils.data import random_split

# Set a random seed for reproducibility
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = VAE_UNet(latent_dim=128).to(device)

# Set the optimizer, using a small learning rate
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Set the batch size, using a small value balanced by a smaller learning rate
batch_size = 32

# Recreate the dataset
dataset = CelebADataset(
    root_dir=f'{os.environ["DATASET"]}/img_align_celeba',
    attr_file=f'{os.environ["DATASET"]}/list_attr_celeba.txt',
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    ),
    filters={},
    samples=15000,
)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_dataloader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
)
val_dataloader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, num_workers=0
)

epochs = 240

train_losses, val_losses = train_vae_deblur(
    model,
    train_dataloader,
    val_dataloader,
    optimizer,
    device,
    epochs=epochs,
    inferences=6,
)

Notice how by epoch 40 the model is starting to reconstruct detail that is barely visible in the blurred image. By epoch 80, it reconstructs the wire fence behind the person, which is almost entirely lost in the blurred image. At epoch 120, hair details start to be reconstructed, and by epoch 160 hair and facial features are much clearer. The model continues to improve until the end of training, with the example at epoch 200 showing a very clear reconstruction of the original image.

Keep in mind that the inference_deblur function is showing the results of the model on images from the validation set, while the model is trained only on the training set. That is, the results above are on unseen data, with the model inferring details by “guessing” what the original should look like based on the training images alone!

Results

With training finished (note that it will likely take anywhere from a couple of hours to a whole day, depending on your hardware), we can plot the training and validation losses to see how the model performed over time.

# Plot the training and validation losses, on a log scale
plt.figure(figsize=(8, 4))
plt.plot(range(1, epochs + 1), train_losses, label="Train")
plt.plot(range(1, epochs + 1), val_losses, label="Validation")
plt.yscale("log")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True)
plt.title("Training and Validation Losses")
plt.legend()
plt.show()

The training and validation losses track each other closely, indicating that the model is learning effectively over time, with room for improvement. The validation loss is slightly higher than the training loss, which is expected since the model is optimized on the training set. The log scale helps to visualize the losses more clearly, as they vary significantly across epochs.

Finally, we can perform inference on a few random images from the validation set to see how well the model performs generally.

# Perform inference on 4 random images from the validation set, showing the original, blurred, and reconstructed images
for _ in range(4):
    k = random.choice(range(5, 16, 2))
    s = random.uniform(1.5, 3.0)
    blur_transform = transforms.GaussianBlur(kernel_size=k, sigma=s)
    inference_deblur(
        model, device, val_dataloader, epoch=epochs - 1, blur_transform=blur_transform
    )

The second example above is particularly interesting. Notice how the model reconstructed the mouth and the eyes. The original image is very blurry, yet the model managed to infer the facial features quite well. However, neither the eye nor the lip shape is quite right, as the model didn’t have enough information to infer the exact shape or position of these features. This is common in generative models, where the model “averages out” the features it sees in the training set and can’t always infer the exact details of the original image.

Final remarks

In this experiment, we trained a Variational Autoencoder to deblur images from the CelebA dataset. We used an architecture similar to the Mondrian VAE experiment, but applied it to a completely different target task. This shows the flexibility of the variational autoencoder architecture, which can be adapted to many different problems requiring generative capabilities without extensive modifications.

Reuse

This work is licensed under CC BY.