import os
from PIL import Image
from torch.utils.data import Dataset
import random


class CelebADataset(Dataset):
    def __init__(self, root_dir, attr_file, transform=None, filters=None, samples=1000):
        """
        Args:
            root_dir (str): Directory with all the images.
            attr_file (str): Path to the attribute file (list_attr_celeba.txt).
            transform (callable, optional): Optional transform to be applied on an image.
            filters (dict, optional): Dictionary where key is an attribute (e.g., 'Male')
                and value is the desired label (1 or -1).
            samples (int, optional): Number of images to use from the filtered dataset.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.filters = filters or {}

        # Read the attribute file
        with open(attr_file, "r") as f:
            lines = f.readlines()

        # The second line contains the attribute names
        attr_names = lines[1].strip().split()

        # Collect all matching samples first
        all_samples = []
        for line in lines[2:]:
            parts = line.strip().split()
            filename = parts[0]
            attributes = list(map(int, parts[1:]))
            attr_dict = dict(zip(attr_names, attributes))

            if all(attr_dict.get(attr) == val for attr, val in self.filters.items()):
                all_samples.append((filename, attr_dict))

        # Shuffle and select a random subset
        random.shuffle(all_samples)
        self.samples = all_samples[:samples]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        filename, attr_dict = self.samples[idx]
        img_path = os.path.join(self.root_dir, filename)
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, attr_dict
Blade Runner came out in 1982 and is a classic science fiction movie directed by Ridley Scott. One of the iconic scenes in the movie is when the protagonist, Deckard, uses a computer to “enhance” a photograph to reveal hidden details. This scene has become a meme and a reference in popular culture.
In this experiment, we will train a Variational Autoencoder (VAE) to deblur images as a tribute to the “enhance” effect from Blade Runner, where we take a blurry image and reconstruct it to reveal hidden details. We will use the CelebA dataset, which contains images of celebrities, and train the VAE to deblur these images.
This is a continuation of the Mondrian VAE experiment, where we trained a VAE to reconstruct masked images in the style of Piet Mondrian. The VAE architecture will be similar, but we will focus on deblurring instead of reconstructing portions of the image.
Loading the dataset
We will start by creating a class which can load and return samples from CelebA. The dataset is filtered based on attributes, and we can specify the number of samples to use. CelebA is composed of low-resolution images (218x178), with varying degrees of quality, and in many different settings.
Let us show a few random images from the dataset to get an idea of their quality and diversity. Note how we can pass a dictionary of filters to the dataset to select only images with specific attributes (for example, 'Male'=1, 'Goatee'=-1 to select only male celebrities without a goatee).
In this case, we are not using any filters, so we will get a random selection of images.
from torchvision import transforms

transform = transforms.Compose([transforms.ToTensor()])

filters = {}

# Instantiate the dataset
dataset = CelebADataset(
    root_dir=f'{os.environ["DATASET"]}/img_align_celeba',
    attr_file=f'{os.environ["DATASET"]}/list_attr_celeba.txt',
    transform=transform,
    filters=filters,
)
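For illustration, a filtered instantiation might look like the following sketch. The filter combination is the hypothetical example from above, not one used in this experiment:
# Hypothetical example: keep only male celebrities without a goatee
filtered_dataset = CelebADataset(
    root_dir=f'{os.environ["DATASET"]}/img_align_celeba',
    attr_file=f'{os.environ["DATASET"]}/list_attr_celeba.txt',
    transform=transform,
    filters={"Male": 1, "Goatee": -1},
)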
VAE model architecture
As explained in the Mondrian VAE experiment, the VAE architecture consists of an encoder and a decoder. The encoder downsamples the input image into a latent representation, and the decoder upsamples this latent representation to reconstruct the original image. Just as before, we will use skip connections between the encoder and decoder to improve the reconstruction quality (a U-Net style architecture). This approach still makes sense for deblurring, since the input and output images are structurally identical - we’re not changing content, just removing degradation. The skip connections carry low-level features (like edges, contours, textures) from the encoder directly to the decoder, which helps reconstruct sharp details that would otherwise be lost in the bottleneck.
The decoder doesn’t need to learn how to recreate fine structure from scratch; it can re-use it, correcting only for the blur. This leads to faster convergence, better visual quality, and fewer artifacts.
The only necessary changes are to adapt the model to a different resolution (218x178 rather than 256x256) and to add a blur transformation to the training loop. We will apply a random Gaussian blur to each image in the batch before feeding it to the model, simulating the blurry input that we want to deblur.
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Downsampling encoder that captures intermediate features for skip connections."""

    def __init__(self, latent_dim=128):
        super(Encoder, self).__init__()
        self.enc1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),  # 218x178 -> 109x89
            nn.ReLU(),
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 109x89 -> 54x44
            nn.ReLU(),
        )
        self.enc3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 54x44 -> 27x22
            nn.ReLU(),
        )
        self.enc4 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # 27x22 -> 13x11
            nn.ReLU(),
        )
        # Flattened dimension: 256 * 13 * 11 = 36608
        self.fc_mu = nn.Linear(256 * 13 * 11, latent_dim)
        self.fc_logvar = nn.Linear(256 * 13 * 11, latent_dim)

    def forward(self, x):
        f1 = self.enc1(x)   # [B, 32, 109, 89]
        f2 = self.enc2(f1)  # [B, 64, 54, 44]
        f3 = self.enc3(f2)  # [B, 128, 27, 22]
        f4 = self.enc4(f3)  # [B, 256, 13, 11]
        flat = f4.view(f4.size(0), -1)
        mu = self.fc_mu(flat)
        logvar = self.fc_logvar(flat)
        return f1, f2, f3, f4, mu, logvar


class Decoder(nn.Module):
    """Upsampling decoder that uses skip connections from the encoder."""

    def __init__(self, latent_dim=128):
        super(Decoder, self).__init__()
        # Expand latent vector to match encoder's last feature map shape (256 x 13 x 11)
        self.fc_dec = nn.Linear(latent_dim, 256 * 13 * 11)

        # Up 1: f4 -> (B,256,13,11) -> upsample -> (B,256,27,22) to match f3 dimensions.
        # Use output_padding=(1,0) so that:
        # Height: (13-1)*2 - 2 + 4 + 1 = 27 and Width: (11-1)*2 - 2 + 4 + 0 = 22.
        self.up4 = nn.ConvTranspose2d(
            256, 256, kernel_size=4, stride=2, padding=1, output_padding=(1, 0)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(256 + 128, 128, kernel_size=3, padding=1), nn.ReLU()
        )

        # Up 2: (B,128,27,22) -> upsample -> (B,128,54,44) to match f2.
        self.up3 = nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Sequential(
            nn.Conv2d(128 + 64, 64, kernel_size=3, padding=1), nn.ReLU()
        )

        # Up 3: (B,64,54,44) -> upsample -> (B,64,109,89) to match f1.
        # Set output_padding=(1,1) to get:
        # Height: (54-1)*2 - 2 + 4 + 1 = 109 and Width: (44-1)*2 - 2 + 4 + 1 = 89.
        self.up2 = nn.ConvTranspose2d(
            64, 64, kernel_size=4, stride=2, padding=1, output_padding=(1, 1)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64 + 32, 32, kernel_size=3, padding=1), nn.ReLU()
        )

        # Up 4: (B,32,109,89) -> upsample -> (B,32,218,178) -> final conv to 3 channels.
        self.up1 = nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.Sequential(
            nn.Conv2d(32, 3, kernel_size=3, padding=1), nn.Sigmoid()
        )

    def forward(self, z, f1, f2, f3, f4):
        # Expand latent vector to spatial feature map: [B,256,13,11]
        x = self.fc_dec(z).view(-1, 256, 13, 11)

        # Up 1 (with skip connection from f3)
        x = self.up4(x)                # -> [B,256,27,22]
        x = torch.cat([x, f3], dim=1)  # Concatenate with f3: [B,256+128,27,22]
        x = self.conv4(x)              # -> [B,128,27,22]

        # Up 2 (with skip connection from f2)
        x = self.up3(x)                # -> [B,128,54,44]
        x = torch.cat([x, f2], dim=1)  # -> [B,128+64,54,44]
        x = self.conv3(x)              # -> [B,64,54,44]

        # Up 3 (with skip connection from f1)
        x = self.up2(x)                # -> [B,64,109,89]
        x = torch.cat([x, f1], dim=1)  # -> [B,64+32,109,89]
        x = self.conv2(x)              # -> [B,32,109,89]

        # Up 4: final upsampling to original resolution
        x = self.up1(x)                # -> [B,32,218,178]
        x = self.conv1(x)              # -> [B,3,218,178]
        return x


# The VAE model
class VAE_UNet(nn.Module):
    """U-Net style VAE that returns reconstruction, mu, logvar."""

    def __init__(self, latent_dim=128):
        super(VAE_UNet, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        f1, f2, f3, f4, mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decoder(z, f1, f2, f3, f4)
        return recon, mu, logvar
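As a quick sanity check (not part of the original listing), we can verify that a dummy batch passes through the model and comes out at the input resolution:
# Minimal smoke test: shapes should round-trip through the U-Net VAE
test_model = VAE_UNet(latent_dim=128)
dummy = torch.randn(2, 3, 218, 178)
recon, mu, logvar = test_model(dummy)
print(recon.shape)  # torch.Size([2, 3, 218, 178])
print(mu.shape)     # torch.Size([2, 128])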
The annealing and loss functions are mostly the same as in the Mondrian VAE experiment.
# The KL annealing function
def kl_anneal_function(epoch, start_epoch=0, end_epoch=10):
    """
    Linearly scales KL weight from 0.0 to 1.0 between start_epoch and end_epoch.
    """
    if epoch < start_epoch:
        return 0.0
    elif epoch > end_epoch:
        return 1.0
    else:
        return (epoch - start_epoch) / (end_epoch - start_epoch)
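To make the schedule concrete, here is what the weight looks like at a few epochs with the defaults of start_epoch=0 and end_epoch=10:
# Weight ramps linearly from 0.0 to 1.0, then stays at 1.0
for e in (0, 5, 10, 20):
    print(e, kl_anneal_function(e))
# 0 0.0
# 5 0.5
# 10 1.0
# 20 1.0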
One difference for this use case is the reconstruction loss: previously we used Mean Squared Error (MSE), but for the deblurring task we will use the L1 loss instead. L1 is less sensitive to outliers and tends to produce sharper images, which is desirable here. As before, the loss function includes a KL divergence term, which regularizes the latent space to follow a standard normal distribution.
As an exercise, you might want to try a perceptual loss, such as one based on VGG16 features or LPIPS, to see if it improves the quality of the reconstructions. These losses are designed to capture perceptual similarity between images, which can be more effective than pixel-wise losses for tasks like deblurring (a minimal sketch follows the loss function below).
import torch.nn.functional as F


def loss_function(recon_x, x, mu, logvar, kl_weight):
    # Reconstruction loss using L1 instead of MSE
    recon_loss = F.l1_loss(recon_x, x, reduction="sum")
    # KL divergence
    KL_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_weight * KL_loss
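As a starting point for the perceptual-loss exercise, here is a minimal, untested sketch of a VGG16 feature loss; VGGPerceptualLoss is a hypothetical helper, not part of the original experiment, and in practice you would also normalise the inputs with ImageNet statistics before feeding them to VGG:
from torchvision.models import vgg16, VGG16_Weights


class VGGPerceptualLoss(nn.Module):
    """Compares images in VGG16 feature space instead of pixel space."""

    def __init__(self, num_layers=16):
        super().__init__()
        # Frozen slice of a pretrained VGG16, used only as a feature extractor
        self.features = vgg16(weights=VGG16_Weights.DEFAULT).features[:num_layers].eval()
        for p in self.features.parameters():
            p.requires_grad = False

    def forward(self, recon_x, x):
        return F.l1_loss(self.features(recon_x), self.features(x), reduction="sum")
This term would be added to (or swapped in for) the pixel-wise reconstruction loss in loss_function, typically with a small weighting factor tuned experimentally.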
The training loop
The training loop is similar to our previous experiment, with the addition of a random blur applied to each image in the batch. Each sample gets its own kernel size and sigma, randomly chosen from a range of values, to simulate varying degrees of blur. Another difference is that this time we measure both training and validation losses in the loop, as we want to ensure that the model generalizes well to unseen data.
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid


def train_vae_deblur(
    model, train_loader, val_loader, optimizer, device, epochs=20, inferences=10
):
    """
    Trains the model on the deblurring task with validation.
    Applies a different random blur per sample (batch-wise) and uses inference_deblur for visualisation.
    """
    # Create a SummaryWriter for TensorBoard logging
    writer = SummaryWriter(log_dir="/tmp/runs/deblur_experiment")

    model.train()
    interval = max(1, epochs // inferences)
    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        kl_weight = kl_anneal_function(epoch, 0, epochs // 2)
        total_train_loss = 0
        progress = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]", leave=False)

        for img, _ in progress:
            img = img.to(device)

            # Apply random blur to each sample in the batch
            blurred_batch = []
            for sample in img:
                k = random.choice(range(5, 16, 2))  # odd kernel size
                s = random.uniform(1.5, 3.0)        # sigma
                blur = transforms.GaussianBlur(kernel_size=k, sigma=s)
                blurred = blur(sample.unsqueeze(0))
                blurred_batch.append(blurred)
            blurred_img = torch.cat(blurred_batch, dim=0).to(device)

            optimizer.zero_grad()
            recon, mu, logvar = model(blurred_img)
            loss = loss_function(recon, img, mu, logvar, kl_weight)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

            progress.set_postfix(
                loss=f"{loss.item():.4f}", KL_Weight=f"{kl_weight:.2f}"
            )

        avg_train_loss = total_train_loss / len(train_loader.dataset)
        train_losses.append(avg_train_loss)

        # Validation
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for img, _ in val_loader:
                img = img.to(device)

                blurred_batch = []
                for sample in img:
                    k = random.choice(range(5, 16, 2))
                    s = random.uniform(1.5, 3.0)
                    blur = transforms.GaussianBlur(kernel_size=k, sigma=s)
                    blurred = blur(sample.unsqueeze(0))
                    blurred_batch.append(blurred)
                blurred_img = torch.cat(blurred_batch, dim=0).to(device)

                recon, mu, logvar = model(blurred_img)
                loss = loss_function(recon, img, mu, logvar, kl_weight)
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_loader.dataset)
        val_losses.append(avg_val_loss)

        # Log scalar values to TensorBoard
        writer.add_scalar("Loss/Train", avg_train_loss, epoch)
        writer.add_scalar("Loss/Val", avg_val_loss, epoch)
        writer.add_scalar("KL Weight", kl_weight, epoch)

        # Log images every 'interval' epochs (and at epoch 0)
        if (epoch + 1) % interval == 0 or epoch == 0:
            # Get one batch from the validation set for visual logging
            with torch.no_grad():
                for img, _ in val_loader:
                    img = img.to(device)
                    blurred_batch = []
                    for sample in img:
                        k = random.choice(range(5, 16, 2))
                        s = random.uniform(1.5, 3.0)
                        blur = transforms.GaussianBlur(kernel_size=k, sigma=s)
                        blurred = blur(sample.unsqueeze(0))
                        blurred_batch.append(blurred)
                    blurred_img = torch.cat(blurred_batch, dim=0).to(device)
                    recon, mu, logvar = model(blurred_img)
                    break  # Use the first batch

            # Create grids of images (normalize for visualization)
            original_grid = make_grid(img, normalize=True, scale_each=True)
            blurred_grid = make_grid(blurred_img, normalize=True, scale_each=True)
            recon_grid = make_grid(recon, normalize=True, scale_each=True)

            writer.add_image("Validation/Original", original_grid, epoch)
            writer.add_image("Validation/Blurred", blurred_grid, epoch)
            writer.add_image("Validation/Reconstructed", recon_grid, epoch)

            inference_deblur(model, device, val_loader, epoch)

        model.train()

    writer.close()
    return train_losses, val_losses
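Since the loop logs scalars and image grids through a SummaryWriter, you can follow training live with the standard TensorBoard CLI, e.g. tensorboard --logdir /tmp/runs/deblur_experiment (assuming TensorBoard is installed).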
The training loop uses the following function to perform inference on a single image from the dataset. It takes an image from the dataloader, applies a blur, reconstructs it, and then plots the original, blurred, and reconstructed images side by side, so we can see how well the model is performing as training progresses.
import numpy as np
import matplotlib.pyplot as plt


def inference_deblur(model, device, dataloader, epoch=0, blur_transform=None):
    """
    Performs inference on a random image from a random batch in the dataloader.
    It applies the blur, reconstructs it, computes the MSE, and then plots the original,
    blurred, and reconstructed images with the MSE in the title.
    """
    model.eval()

    # Get the total number of batches and choose one at random
    num_batches = len(dataloader)
    random_batch_index = random.randint(0, num_batches - 1)

    # Iterate through the dataloader until the random batch is reached
    for i, (img, _) in enumerate(dataloader):
        if i == random_batch_index:
            # Pick a random image from this batch
            random_image_index = random.randint(0, img.size(0) - 1)
            original = img[random_image_index].unsqueeze(0).to(device)
            break

    if blur_transform is None:
        blur_transform = transforms.GaussianBlur(kernel_size=9, sigma=2.0)
    blurred = blur_transform(original)

    with torch.no_grad():
        recon, _, _ = model(blurred)
        mse = torch.nn.functional.mse_loss(recon, original)  # Compute MSE

    # Convert tensors to NumPy arrays for plotting
    original_np = original.squeeze(0).permute(1, 2, 0).cpu().numpy()
    blurred_np = np.clip(blurred.squeeze(0).permute(1, 2, 0).cpu().numpy(), 0, 1)
    recon_np = np.clip(recon.squeeze(0).permute(1, 2, 0).cpu().numpy(), 0, 1)

    # Plot the original, blurred, and reconstructed images side by side
    fig, axs = plt.subplots(1, 3, figsize=(8, 4))
    fig.suptitle(f"Epoch: {epoch+1}, MSE: {mse.item():.4f}", fontsize=14)

    axs[0].imshow(original_np)
    axs[0].set_title("Original")
    axs[0].axis("off")

    axs[1].imshow(blurred_np)
    axs[1].set_title("Blurred")
    axs[1].axis("off")

    axs[2].imshow(recon_np)
    axs[2].set_title("Reconstructed")
    axs[2].axis("off")

    plt.show()
Finally, let us put it all together: instantiate the model, optimizer, and dataloaders, and train the model. You might have noticed that we are using the same learning rate and batch size as in the Mondrian VAE experiment. To understand the interplay between these two hyperparameters, you could experiment with different values and see how they affect the training dynamics and final results. For example, try a smaller learning rate to see if the model can learn more subtle details, accompanied by a smaller batch size to prevent the model from getting stuck in local minima.
The choice of learning rate and batch size plays a critical role in the performance, stability, and convergence speed of the model. While these hyperparameters are often tuned experimentally, understanding their individual and combined impact can guide decisions during model development.
The learning rate determines how big a step the optimizer takes in the direction of the gradient at each iteration. A learning rate that is too high can cause the model to overshoot minima in the loss landscape, leading to divergence or oscillating loss. On the other hand, a learning rate that is too low can result in painfully slow training and may cause the model to get stuck in suboptimal solutions. Common practice involves starting with values like \(10^{-3}\) or \(10^{-4}\), then adapting with schedulers or learning rate warm-up strategies depending on the model and task complexity.
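As an illustration of the warm-up idea (not something used in this experiment), a linear warm-up can be sketched with PyTorch's LambdaLR scheduler; the names and values below are arbitrary placeholders:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

# Placeholder parameter so the sketch is self-contained
demo_params = [torch.nn.Parameter(torch.zeros(1))]
demo_optimizer = optim.Adam(demo_params, lr=1e-3)

warmup_epochs = 5  # hypothetical warm-up length
# The lambda multiplies the base lr, ramping from 1/warmup_epochs up to 1.0
scheduler = LambdaLR(demo_optimizer, lr_lambda=lambda e: min(1.0, (e + 1) / warmup_epochs))

for epoch in range(8):
    demo_optimizer.step()  # training step would go here
    scheduler.step()
    print(epoch, demo_optimizer.param_groups[0]["lr"])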
Batch size, which defines how many samples are processed before the model updates its weights, also affects training dynamics. Smaller batch sizes introduce more noise into the gradient estimate, which can act as a regularizer and potentially help generalisation, but may also lead to instability if the learning rate isn’t adjusted accordingly. Larger batch sizes, on the other hand, provide smoother and more accurate gradient estimates, often leading to faster convergence, but can risk poorer generalisation.
There’s also a strong interplay between batch size and learning rate. As a general rule, larger batch sizes can support proportionally larger learning rates - this is one of the ideas behind neural scaling laws. Conversely, smaller batches usually require a smaller learning rate to remain stable. When tuned together, these parameters have a significant impact on model performance, generalisation, and training.
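A common heuristic capturing this interplay is the linear scaling rule: when you multiply the batch size by k, multiply the learning rate by k as well. A back-of-the-envelope sketch, where the base values are just examples:
base_lr = 1e-4   # learning rate tuned for the base batch size
base_batch = 32  # batch size the base_lr was tuned for


def scaled_lr(batch_size, base_lr=base_lr, base_batch=base_batch):
    """Linear scaling heuristic: lr grows proportionally with batch size."""
    return base_lr * (batch_size / base_batch)


print(scaled_lr(32))   # 1e-4  (unchanged)
print(scaled_lr(128))  # 4e-4  (4x batch -> 4x lr)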
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from torch.utils.data import random_split

# Set a random seed for reproducibility
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = VAE_UNet(latent_dim=128).to(device)

# Set the optimizer, using a small learning rate
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Set the batch size, using a small value balanced by a smaller learning rate
batch_size = 32

# Recreate the dataset
dataset = CelebADataset(
    root_dir=f'{os.environ["DATASET"]}/img_align_celeba',
    attr_file=f'{os.environ["DATASET"]}/list_attr_celeba.txt',
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    ),
    filters={},
    samples=15000,
)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_dataloader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
)
val_dataloader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, num_workers=0
)

epochs = 240

train_losses, val_losses = train_vae_deblur(
    model,
    train_dataloader,
    val_dataloader,
    optimizer,
    device,
    epochs=epochs,
    inferences=6,
)
Notice how by epoch 40 the model is starting to reconstruct detail which is barely visible in the blurred image. By epoch 80, it reconstructs the wire fence behind the person, which is pretty much lost in the blurred image. At 120, hair details start to be reconstructed, and by epoch 160 hair and facial features are much clearer. The model continues to improve until the end of training, with the final images showing a significant improvement over the blurred inputs; the example at epoch 200 shows a very clear reconstruction of the original image.
Keep in mind that the inference_deblur function is showing the results of the model on images from the validation set, while the model is trained only on the training set. That is, the results above are on unseen data, with the model inferring details by “guessing” what the original should look like based on the training images alone!
Results
With training finished (note that it will likely take between a couple of hours and a whole day, depending on your hardware), we can plot the training and validation losses to see how the model performed over time.
# Plot the training and validation losses, on a log scale
plt.figure(figsize=(8, 4))
plt.plot(range(1, epochs + 1), train_losses, label="Train")
plt.plot(range(1, epochs + 1), val_losses, label="Validation")
plt.yscale("log")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True)
plt.title("Training and Validation Losses")
plt.legend()
plt.show()
The training and validation losses track each other closely, indicating that the model is learning effectively over time, with room for further improvement. The validation loss is slightly higher than the training loss, which is expected, as the model is optimized on the training set. The log scale helps to visualize the losses more clearly, as they vary significantly across epochs.
Finally, we can perform inference on a few random images from the validation set to see how well the model performs generally.
# Perform inference on 4 random images from the validation set, showing the
# original, blurred, and reconstructed images
for _ in range(4):
    k = random.choice(range(5, 16, 2))
    s = random.uniform(1.5, 3.0)
    blur_transform = transforms.GaussianBlur(kernel_size=k, sigma=s)
    inference_deblur(
        model, device, val_dataloader, epoch=epochs - 1, blur_transform=blur_transform
    )
The second example above is particularly interesting. Notice how the model reconstructed the mouth and the eyes. The original image is very blurry, and the model managed to infer the facial features quite well. However, the eye and lip shapes aren’t quite right, as the model didn’t have enough information to infer their exact shape or position. This is common in generative models: the model “averages out” the features it sees in the training set, and can’t always infer the exact details of the original image.
Final remarks
In this experiment, we trained a Variational Autoencoder to deblur images from the CelebA dataset. We used a similar architecture to the Mondrian VAE experiment, but applied it to a completely different target task. This shows the flexibility of the variational autoencoder architecture, which can be adapted to many problems requiring generative capabilities without extensive modifications.