A key component of training deep learning models is the choice of optimisation algorithm. There are several approaches, ranging from simple stochastic gradient descent (SGD) to more advanced methods like Adam. In this experiment, we’ll try to give an intuitive understanding of what optimisation means in the context of machine learning, briefly discussing the Adam algorithm.
What is optimisation?
Optimisation, in its broadest sense, is the process of finding the best solution among many possibilities, adjusting variables to maximize or minimize an objective function. Think of it like tuning a car: you adjust various settings to achieve the best performance, whether the objective is faster acceleration or higher fuel efficiency. This concept applies across fields, from engineering to economics, where you often balance trade-offs to reach an optimal outcome.
In machine learning, optimisation takes on a more specific role. The objective function is typically the loss (or cost), which quantifies how far off a model’s predictions are from the actual data. The goal is to adjust the model’s parameters (like weights and biases) to minimize this loss. Because the loss landscapes in machine learning can be highly complex and non-linear, algorithms like gradient descent, and variants such as Adam, are employed. These algorithms iteratively tweak the model’s parameters, gradually moving the model toward better performance.
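To make this concrete, plain gradient descent nudges all parameters \(\theta\) a small step against the gradient of the loss \(L\):
\[
\theta_{t+1} = \theta_t - \alpha \, \nabla_\theta L(\theta_t)
\]
where \(\alpha\) is the learning rate. Adam, discussed in detail below, builds on this basic rule by adapting the step size of each parameter individually.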
Note
Machine learning models typically have many parameters, so the loss is defined over a high-dimensional space and the optimisation algorithm has to navigate many local minima and saddle points. The choice of algorithm is crucial, as it determines how efficiently the model converges to a good solution.
Visualising Adam in action
To illustrate the optimisation process, let us take a classical function used to test optimisation algorithms: the Rosenbrock function. This function is known for its narrow, curved valley, which makes it challenging for optimisation algorithms to converge to the global minimum. The function is typically depicted in 2D, with the \(x\) and \(y\) axes representing the parameters to be optimized. We will instead visualise the optimisation process in 3D, with the \(x\) and \(y\) axes representing the spatial coordinates and the \(z\)-axis representing the function value.
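With the standard parameter choices \(a = 1\) and \(b = 100\) (the values used in the code below), the function is
\[
f(x, y) = (a - x)^2 + b\,(y - x^2)^2
\]
and its global minimum sits at \((x, y) = (1, 1)\), where \(f(x, y) = 0\).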
In the code below we define the rosenbrock_2d function, set up the optimisation process using PyTorch and the Adam optimizer (torch.optim.Adam), and track the path taken by the optimizer. We then create a 3D surface plot of the function and animate the optimisation process to see how the optimiser navigates the landscape.
Show the code
import torch
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d import Axes3D  # Ensures 3D projection is recognized


def rosenbrock_2d(x, y, a=1.0, b=100.0):
    return (a - x) ** 2 + b * (y - x**2) ** 2


# PyTorch setup: we'll optimize x, y to find the minimum of the Rosenbrock function
params = torch.tensor([-0.8, 2.0], requires_grad=True)
optimizer = torch.optim.Adam([params], lr=0.01)

# Track the path: (x, y, f(x,y)) each iteration
path = []
tolerance = 1e-4
max_iterations = 6000

for i in range(max_iterations):
    optimizer.zero_grad()
    loss = rosenbrock_2d(params[0], params[1])
    loss.backward()
    optimizer.step()

    x_val = params[0].item()
    y_val = params[1].item()
    z_val = loss.item()
    path.append([x_val, y_val, z_val])

    # Stop if loss is below tolerance
    if z_val < tolerance:
        print("Converged at iteration", i)
        break

path = np.array(path)
num_frames = len(path)

# Create a 3D surface for the function
X = np.linspace(-2, 2, 200)
Y = np.linspace(-1, 3, 200)
X_mesh, Y_mesh = np.meshgrid(X, Y)
Z_mesh = rosenbrock_2d(X_mesh, Y_mesh)

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.set_title("Adam Optimizer on 2D Rosenbrock (3D Surface)")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("f(x,y)")

# Initial axis limits (from our grid)
init_xlim = (-2, 2)
init_ylim = (-1, 3)
init_zlim = (np.min(Z_mesh), np.max(Z_mesh))
center_x, center_y, center_z = 1, 1, 0

# Set initial limits
ax.set_xlim(init_xlim)
ax.set_ylim(init_ylim)
ax.set_zlim(init_zlim)

ax.plot_surface(X_mesh, Y_mesh, Z_mesh, alpha=0.6)
ax.plot([1], [1], [0], marker="o", markersize=5)  # Global minimum reference

# Animation: plot the path and adjust axis limits to zoom
(point,) = ax.plot([], [], [], "ro")  # Current position marker
(line,) = ax.plot([], [], [], "r-")  # Path line


def init():
    point.set_data([], [])
    point.set_3d_properties([])
    line.set_data([], [])
    line.set_3d_properties([])
    return point, line


def update(frame):
    # Update point and path
    x_val = path[frame, 0]
    y_val = path[frame, 1]
    z_val = path[frame, 2]
    point.set_data([x_val], [y_val])
    point.set_3d_properties([z_val])
    line.set_data(path[: frame + 1, 0], path[: frame + 1, 1])
    line.set_3d_properties(path[: frame + 1, 2])
    return point, line


ani = FuncAnimation(
    fig, update, frames=num_frames, init_func=init, interval=100, blit=True
)
ani.save("adam_rosenbrock.mp4", writer="ffmpeg", fps=48)
plt.close(fig)
Adam adapts the learning rate of each parameter individually, which often lets it converge faster than plain gradient descent. This is why, in the animation, you see the optimizer move at different speeds in different directions: where progress is slow, the optimizer is effectively “exploring” the landscape, searching for a path toward the global minimum. This per-parameter adaptability is one of Adam’s key strengths, making it robust across a wide range of optimisation problems.
The mathematics of Adam
Adam (Adaptive Moment Estimation) combines ideas from momentum and RMSProp to adaptively adjust the learning rates of model parameters. At its core, Adam computes two moving averages: one of the gradients (the first moment) and one of the squared gradients (the second moment). Given the gradient \(g_t\) at iteration \(t\), these are updated as:
\[
m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t
\]
\[
v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2
\]
Here, \(\beta_1\) and \(\beta_2\) are decay rates (typically around 0.9 and 0.999, respectively) that determine how much of the past gradients and squared gradients are retained.
Since the moving averages \(m_t\) and \(v_t\) are initialized at zero, they are biased toward zero in the initial steps. To correct this bias, Adam computes bias-corrected estimates:
\[
\hat{m}_t = \frac{m_t}{1 - \beta_1^t}
\]
\[
\hat{v}_t = \frac{v_t}{1 - \beta_2^t}
\]
Finally, the parameters \(\theta\) are updated using these bias-corrected estimates according to the rule:
\[
\theta_t = \theta_{t-1} - \alpha \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\]
In this formula, \(\alpha\) is the learning rate and \(\epsilon\) is a small constant (such as \(10^{-8}\)) to avoid division by zero. This update rule allows Adam to automatically adjust the step size for each parameter, effectively handling sparse gradients and noisy objectives, which often results in faster convergence and improved performance over traditional stochastic gradient descent methods.
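As a sanity check of the formulas above, here is a minimal NumPy sketch of a single Adam step for a parameter vector. It follows the equations directly rather than the internals of torch.optim.Adam, but with default settings it should closely match that optimizer’s behaviour.

import numpy as np


def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # Update biased first and second moment estimates
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad**2
    # Bias correction (t starts at 1)
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)
    # Parameter update
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v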
Teaching a neural network to paint with Adam
Another great way to show Adam in action is by training a neural network to paint an image. We’ll use a simple Multi-Layer Perceptron (MLP) and a more advanced architecture called Sinusoidal Representation Networks (SIREN) to illustrate this. The goal is to predict the RGB values of each pixel in an image based on its spatial coordinates. We’ll use my favourite painting, “The Arnolfini Portrait” by Jan van Eyck, as our target image.
First we need to set up a few hyperparameters and load the image. We are setting up a network with 4 hidden layers, each with 512 hidden units. We’ll train the model, saving display frames every 100 epochs and animation frames every 10 epochs. We’ll use the Adam optimizer with a learning rate of \(10^{-4}\) and an early stopping patience of 500 epochs.
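The hyperparameter block itself is not reproduced here, so below is a hedged sketch based on the values quoted above; num_epochs, create_animation and image_path are assumptions, not the original values.

# Hyperparameters (values from the text above; the last three are assumptions)
hidden_layers = 4            # number of hidden layers
hidden_features = 512        # hidden units per layer
learning_rate = 1e-4         # Adam learning rate
patience = 500               # early-stopping patience (epochs)
display_interval = 100       # save a display frame every 100 epochs
animation_interval = 10      # save an animation frame every 10 epochs
num_epochs = 20000           # assumed upper bound on training epochs
create_animation = True      # whether to save animation frames
image_path = "the_arnolfini_portrait.jpg"  # assumed filename of the target image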
Let us load the image and display it to see what the model is working with.
Show the code
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image


def load_and_preprocess_image(image_path):
    img = Image.open(image_path).convert("RGB")
    img = np.array(img) / 255.0
    H, W, _ = img.shape
    return img, H, W


# Load and display image
img, H, W = load_and_preprocess_image(image_path)
print(f"Image shape: {img.shape}")

plt.figure(figsize=(8, 8))
plt.imshow(img)
plt.axis("off")
plt.show()
Image shape: (800, 585, 3)
In my case, I am using Apple Silicon to run this code, so we check for the presence of MPS so PyTorch uses the GPU. If you are running this on a different platform, you may need to adjust the device accordingly.
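A minimal device-selection sketch, assuming we simply fall back to CUDA or the CPU when MPS is not available:

import torch

# Prefer MPS (Apple Silicon), then CUDA, otherwise fall back to the CPU
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device}")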
We also need to create a coordinate grid that represents the spatial coordinates of each pixel in the image. This grid will be the input to our neural network, and the target will be the RGB values of the corresponding pixels in the image. We’ll use the coordinate grid to train the model to predict the RGB values based on spatial location.
The grid looks like the following: notice that it is inverted along the y-axis compared to the usual image representation. This is because the origin \((0,0)\) is at the top-left corner of an image, while in the Cartesian coordinate system it is at the bottom-left corner.
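The grid construction is not shown above, so here is a minimal sketch, assuming the coordinates are normalised to \([-1, 1]\) (a common convention for coordinate-based networks):

# Build an (H*W, 2) tensor of (x, y) coordinates and an (H*W, 3) tensor of RGB targets.
# The [-1, 1] normalisation is an assumption about the original code.
ys = torch.linspace(-1, 1, H)
xs = torch.linspace(-1, 1, W)
grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
coords = torch.stack([grid_x, grid_y], dim=-1).reshape(-1, 2).to(device)
target = torch.tensor(img, dtype=torch.float32).reshape(-1, 3).to(device)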
We also need to create directories to store the display and animation frames, so we don’t have to keep all the frames in memory. We’ll use them to save the model’s predictions at different epochs during training, which we will later use to create an animation of the training process.
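The directory setup (and the cleanup_frames helper used later) is not shown in the post; a simple version might look like this, with the folder names being assumptions:

import os
import shutil

# Hypothetical folder names; the original post may use different ones
display_dir = "display_frames"
anim_dir = "animation_frames"
for d in (display_dir, anim_dir):
    os.makedirs(d, exist_ok=True)


def cleanup_frames(folder):
    # Delete the saved frame images once the video has been written,
    # then recreate the empty folder so it can be reused
    shutil.rmtree(folder, ignore_errors=True)
    os.makedirs(folder, exist_ok=True)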
As mentioned before, we will use a Multi-Layer Perceptron (MLP) model. It features an input layer that accepts \((x,y)\) coordinates, hidden layers with ReLU activation functions, and an output layer that produces the predicted RGB values. While an MLP is a basic neural network that may not capture complex spatial patterns as well as more advanced architectures, this very limitation helps visually highlight the optimizer’s struggle to learn the image, and how Adam adapts as it traverses the loss landscape.
Note that the model doesn’t have enough parameters to fully memorize the image and will struggle to capture the details of the painting pixel by pixel. This limitation will be evident in the animation, where the model’s predictions will be a blurry approximation of the original image. You can think of it as the model having to compress information into a lower-dimensional space and then reconstruct it, losing detail in the process. To produce an image that closely resembles the original, we would need a more complex architecture, a different approach, or many more epochs to capture enough detail.
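The MLP class definition is folded away in the original post; the following is a minimal sketch consistent with the description above (the exact layer layout and the absence of an output activation are assumptions):

import torch.nn as nn


class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, hidden_layers, out_features):
        super().__init__()
        # A stack of Linear + ReLU blocks, followed by a linear output layer
        layers = [nn.Linear(in_features, hidden_features), nn.ReLU()]
        for _ in range(hidden_layers - 1):
            layers += [nn.Linear(hidden_features, hidden_features), nn.ReLU()]
        layers.append(nn.Linear(hidden_features, out_features))
        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        return self.net(coords)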
Show the code
model_mlp = MLP(
    in_features=2,
    hidden_features=hidden_features,
    hidden_layers=hidden_layers,
    out_features=3,
).to(device)
print(model_mlp)
print(
    "Number of parameters:",
    sum(p.numel() for p in model_mlp.parameters() if p.requires_grad),
)
Finally we define the Adam optimiser and the Mean Squared Error (MSE) loss function. Remember that the optimiser is responsible for updating the model’s parameters to minimise the loss, while the MSE loss measures the difference between the model’s predictions and the target values (the original pixels).
Show the code
import torch.optim as optim

optimizer = optim.Adam(model_mlp.parameters(), lr=learning_rate)
criterion = nn.MSELoss()
With this out of the way, let us train the model and save the display and animation frames. We’ll also implement early stopping based on the patience hyperparameter, which stops training if the loss does not improve for a certain number of epochs. If you decide to try this yourself, keep in mind that depending on your hardware, training may take a while (hours) due to the large number of epochs required and the complexity of the model.
Show the code
from tqdm.notebook import tqdm


def save_frame(frame, folder, prefix, epoch):
    """Save a frame (as an image) to the given folder."""
    frame_path = os.path.join(folder, f"{prefix}_{epoch:04d}.png")
    # If frame is grayscale, use cmap; otherwise, display as color
    if frame.ndim == 2 or frame.shape[-1] == 1:
        plt.imsave(frame_path, frame.astype(np.float32), cmap="gray")
    else:
        plt.imsave(frame_path, frame.astype(np.float32))


def train_model(
    model, coords, target, H, W,
    num_epochs, display_interval, animation_interval, patience,
    optimizer, criterion, display_dir, anim_dir, create_animation,
):
    best_loss = float("inf")
    patience_counter = 0
    display_epochs = []
    display_losses = []

    for epoch in tqdm(range(num_epochs), desc="Training"):
        optimizer.zero_grad()
        pred = model(coords)
        loss = criterion(pred, target)
        loss.backward()
        optimizer.step()

        if loss.item() < best_loss:
            best_loss = loss.item()
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}, best loss: {best_loss:.6f}")
                break

        with torch.no_grad():
            # Reshape prediction to (H, W, 3) for a color image
            pred_img = pred.detach().cpu().numpy().astype(np.float16).reshape(H, W, 3)
            frame = np.clip(pred_img, 0, 1)
            if create_animation and epoch % animation_interval == 0:
                save_frame(frame, anim_dir, "frame", epoch)
            if epoch % display_interval == 0:
                save_frame(frame, display_dir, "display", epoch)
                display_epochs.append(epoch)
                display_losses.append(loss.item())
        del pred

    return best_loss, display_epochs, display_losses


best_loss_mlp, display_epochs_mlp, display_losses_mlp = train_model(
    model_mlp, coords, target, H, W,
    num_epochs, display_interval, animation_interval, patience,
    optimizer, criterion, display_dir, anim_dir, create_animation,
)
With the training complete, we can display the saved frames to get a sense of how the model’s predictions evolved over time. They show the model’s output at different epochs, with the epoch number and loss value displayed with each image. This visualisation helps us understand how the model learns to approximate the original image pixel by pixel.
Show the code
import glob
import math
import re


def extract_number(f):
    s = os.path.basename(f)
    match = re.search(r"(\d+)", s)
    return int(match.group(1)) if match else -1


def grid_display(display_dir, display_epochs, display_losses, num_cols=5):
    # Use the custom key for natural sorting of filenames
    display_files = sorted(
        glob.glob(os.path.join(display_dir, "*.png")), key=extract_number
    )
    num_images = len(display_files)
    num_rows = math.ceil(num_images / num_cols)

    fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols * 3, num_rows * 3))
    axes = axes.flatten() if num_images > 1 else [axes]

    for i, ax in enumerate(axes):
        if i < num_images:
            img_disp = plt.imread(display_files[i])
            ax.imshow(
                img_disp if img_disp.ndim == 3 else img_disp,
                cmap=None if img_disp.ndim == 3 else "gray",
            )
            ax.set_title(f"Epoch {display_epochs[i]}\nLoss: {display_losses[i]:.6f}")
            ax.axis("off")
        else:
            ax.axis("off")

    plt.tight_layout()
    plt.show()


grid_display(display_dir, display_epochs_mlp, display_losses_mlp)
To get an even better intuition, let us create an animation which shows predictions at different epochs. This animation will give us a dynamic view of the training process, illustrating how the model’s output evolves over time. We’ll use the imageio library to create an MP4 video from the saved frames.
Show the code
from PIL import Image, ImageDraw, ImageFont
import imageio.v2 as imageio
import glob
import os
import numpy as np


def create_mp4_from_frames(anim_dir, mp4_filename, fps=10):
    # Use the custom sort key to ensure natural sorting of filenames
    anim_files = sorted(glob.glob(os.path.join(anim_dir, "*.png")), key=extract_number)
    frames = []
    font_size = 32
    try:
        font = ImageFont.truetype(r"OpenSans-Bold.ttf", font_size)
    except IOError:
        font = ImageFont.load_default()

    for file in anim_files:
        base = os.path.basename(file)
        try:
            parts = base.split("_")
            iteration = parts[-1].split(".")[0]
        except Exception:
            iteration = "N/A"

        frame_array = imageio.imread(file)
        image = Image.fromarray(frame_array)
        # Ensure image is in RGB mode for drawing colored text
        if image.mode != "RGB":
            image = image.convert("RGB")

        draw = ImageDraw.Draw(image)
        text = str(iteration)
        textwidth = draw.textlength(text, font)
        textheight = font_size
        width, height = image.size
        x = width - textwidth - 10
        y = height - textheight - 10
        # For RGB images, white is (255, 255, 255)
        draw.text((x, y), text, font=font, fill=(255, 255, 255))
        frames.append(np.array(image))

    # Write frames to an MP4 video file with the ffmpeg writer
    writer = imageio.get_writer(mp4_filename, fps=fps, codec="libx264", format="ffmpeg")
    for frame in frames:
        writer.append_data(frame)
    writer.close()
Show the code
if create_animation:
    create_mp4_from_frames(anim_dir, "The_Arnolfini_portrait_RGB_MLP.mp4", fps=24)
    cleanup_frames(anim_dir)
We can clearly see the model slowly learn the details of the painting over time, starting from a very blurry approximation and gradually refining its predictions. The role of the optimiser is to guide the model towards “guessing” the details of the painting, such as textures, colours, and shapes. The “wiggles” in the animation represent the model’s attempts to find the parameters that minimize the loss function, which in turn helps it produce more accurate predictions, much like a person finding the optimal path through a complex maze by trial and error.
The SIREN model
MLPs, when used with standard activation functions like ReLU, tend to create piecewise linear approximations of the target function. This works well for many problems, but it can lead to over-smoothing when modeling complex spatial patterns, especially in images. Essentially, an MLP struggles to capture high-frequency details or subtle variations in an image because its architecture is inherently limited by its smooth, global parameterization.
On the other hand, a SIREN model, short for Sinusoidal Representation Networks, employs periodic activation functions (typically sine functions) instead of ReLU. The sinusoidal activations allow the network to naturally capture high-frequency details, as they can represent oscillatory patterns much more effectively. This means that it will be better suited for representing complex, detailed signals with fine variations, making it a strong candidate for tasks such as image reconstruction or any problem where precise spatial detail is critical. It will also help the optimizer converge much faster and more accurately to the target image.
Here’s how SIREN is defined in PyTorch. The key difference is the use of the SineLayer class, which replaces the standard linear layers in the MLP. The SineLayer applies a sine function to the output of a linear layer, with a frequency controlled by the omega_0 parameter. The SIREN class then stacks multiple SineLayer instances to create a deep network with sinusoidal activations. The choice of omega_0 determines the frequency of the sine functions and can be tuned to capture different spatial frequencies in the data.
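The class definitions are not reproduced here, so below is a condensed sketch in the spirit of the original SIREN paper (Sitzmann et al.); the weight initialisation, the linear final layer, and the default omega_0 = 30 are assumptions about this post’s code.

class SineLayer(nn.Module):
    def __init__(self, in_features, out_features, is_first=False, omega_0=30.0):
        super().__init__()
        self.omega_0 = omega_0
        self.linear = nn.Linear(in_features, out_features)
        # SIREN-specific weight initialisation keeps the sine activations well behaved
        with torch.no_grad():
            if is_first:
                bound = 1 / in_features
            else:
                bound = np.sqrt(6 / in_features) / omega_0
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, x):
        return torch.sin(self.omega_0 * self.linear(x))


class SIREN(nn.Module):
    def __init__(self, in_features, hidden_features, hidden_layers, out_features, omega_0=30.0):
        super().__init__()
        # Stack sine layers, ending with a plain linear layer for the RGB output
        layers = [SineLayer(in_features, hidden_features, is_first=True, omega_0=omega_0)]
        for _ in range(hidden_layers - 1):
            layers.append(SineLayer(hidden_features, hidden_features, omega_0=omega_0))
        layers.append(nn.Linear(hidden_features, out_features))
        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        return self.net(coords)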
Finally, we run training and save the display and animation frames. This time we should see convergence being achieved faster, and more detailed predictions compared to the MLP, thanks to SIREN’s ability to capture high-frequency spatial patterns more effectively.
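A hedged sketch of that training run, reusing the same train_model helper and the MLP hyperparameters (whether the original post reuses exactly these values is an assumption):

model_siren = SIREN(
    in_features=2,
    hidden_features=hidden_features,
    hidden_layers=hidden_layers,
    out_features=3,
).to(device)

optimizer_siren = optim.Adam(model_siren.parameters(), lr=learning_rate)

best_loss_siren, display_epochs_siren, display_losses_siren = train_model(
    model_siren, coords, target, H, W,
    num_epochs, display_interval, animation_interval, patience,
    optimizer_siren, criterion, display_dir, anim_dir, create_animation,
)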
And finally we stitch the animation frames together to create a video that shows the training process of the SIREN model.
Show the code
if create_animation:
    create_mp4_from_frames(anim_dir, "The_Arnolfini_portrait_RGB_SIREN.mp4", fps=12)
    cleanup_frames(anim_dir)
Notice how this time SIREN captures fine details much faster and more accurately than the MLP. It has almost memorized the training data, showing the effectiveness of SIREN at representing high-frequency functions. The Adam optimizer, in this case, has an easier time navigating the loss landscape, converging to the target image much faster and with more precision.
Final remarks
Hopefully this experiment has given you a better understanding of the role of optimisation in machine learning. It is a crucial aspect that affects how well models perform and converge during training, and despite the somewhat complex nature of the algorithms involved, anyone can build a rough intuition for how they work.