How GPUs work, an explainer using the Mandelbrot set

Mandelbrot sets and parallelism

Experiments
GPU
Published

February 9, 2025

Pretty much all of us use or hear about the mythical GPU, the Graphics Processing Unit, every day. It’s the thing that makes your games run, your videos render, and your machine learning models train faster. But how does it achieve that? What makes it different from a CPU?

We will do a quick explainer using the Mandelbrot set, which should give you a good intuition. The Mandelbrot set is the set of complex numbers \(c\) for which repeatedly applying a simple function, starting from zero, stays bounded instead of diverging to infinity; its boundary is the famous fractal. It is a great example to use because the function is simple and the computation can be parallelized easily.

A quick diversion into parallelism

During my university days I had a quick course on parallelism, where we were asked how we could parallelize the computation of the Mandelbrot set. The answer is simple: we can calculate each pixel independently, because at its core calculating the set means applying the same function to each complex number to determine whether it belongs to the set.

Back then, the approach I followed was to divide the image into an \(n \times n\) grid and assign each cell to a separate networked computer, which would calculate the pixels in that cell and write the results to a file on a network share (this was back in the day when RPC was barely a thing). This is a simple form of parallelism, known as embarrassingly parallel: the pieces of work need no communication with each other.
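The grid decomposition described above can be sketched as follows. This is a minimal illustration, not the original coursework: the helper names (tile_bounds, escape_count, render_tile, render_tiled) are made up for this example, and a thread pool on one machine stands in for the networked computers. In CPython, threads will not actually speed up this pure Python arithmetic; the point is only that each tile is computed without any knowledge of the others.

```python
from concurrent.futures import ThreadPoolExecutor


def tile_bounds(width, height, n):
    """Split a width x height image into an n x n grid of (x0, x1, y0, y1) tiles."""
    xs = [width * k // n for k in range(n + 1)]
    ys = [height * k // n for k in range(n + 1)]
    return [
        (xs[i], xs[i + 1], ys[j], ys[j + 1]) for j in range(n) for i in range(n)
    ]


def escape_count(c, max_iter):
    """Count the iterations of z -> z*z + c that stay within |z| <= 2."""
    z = 0j
    for count in range(max_iter):
        z = z * z + c
        if abs(z) > 2:
            return count
    return max_iter


def render_tile(bounds, width, height, max_iter):
    """Compute one tile on its own; it needs nothing from any other tile."""
    x0, x1, y0, y1 = bounds
    rows = []
    for j in range(y0, y1):
        imag = -1.5 + 3.0 * j / (height - 1)
        rows.append(
            [escape_count(complex(-2.0 + 3.0 * i / (width - 1), imag), max_iter)
             for i in range(x0, x1)]
        )
    return bounds, rows


def render_tiled(width, height, n, max_iter):
    """Hand each tile to a worker, then stitch the results into one image."""
    image = [[0] * width for _ in range(height)]
    with ThreadPoolExecutor() as pool:
        jobs = [
            pool.submit(render_tile, b, width, height, max_iter)
            for b in tile_bounds(width, height, n)
        ]
        for job in jobs:
            (x0, x1, y0, _), rows = job.result()
            for offset, row in enumerate(rows):
                image[y0 + offset][x0:x1] = row
    return image
```

Swapping the thread pool for separate processes or machines changes how the tiles are dispatched and collected, but not the decomposition itself.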

Many GPU computations are embarrassingly parallel, which is why GPUs are so good at this kind of workload. They have thousands of simple cores, and each one can compute a small fragment of the workload at the same time as the others.

They are equally adept at machine learning and AI workloads, because most of the computation in these fields is matrix multiplication, which can be parallelized easily.
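To see why matrix multiplication parallelizes so well, note that every element of the output is an independent dot product. The sketch below is plain, sequential Python written only for illustration (matmul_element and matmul are names invented here); on a GPU, each call to matmul_element could in principle be handed to a different core.

```python
def matmul_element(a, b, i, j):
    """One output element: the dot product of row i of a and column j of b."""
    return sum(a[i][k] * b[k][j] for k in range(len(b)))


def matmul(a, b):
    """Every (i, j) below is independent of every other, so the order is irrelevant."""
    return [
        [matmul_element(a, b, i, j) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
print(matmul(a, b))  # [[19, 22], [43, 50]]
```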

The Mandelbrot set

The Mandelbrot set is defined by the following function:

\[ f(z) = z^2 + c \]

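where \(z\) is a complex number that starts at \(0\), and \(c\) is a constant complex number (the point being tested). We iterate this function, feeding each result back in as the next \(z\). If the magnitude of \(z\) ever exceeds 2, the sequence is guaranteed to diverge to infinity; if it never does, the sequence stays bounded. The Mandelbrot set is the set of values of \(c\) for which the sequence stays bounded, and the fractal we draw is the boundary between these two regions.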
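As a small worked example of the iteration, here is a sketch that follows the sequence for two values of \(c\). The helper name orbit is made up for this illustration.

```python
def orbit(c, steps):
    """Return the first `steps` values of z -> z*z + c, starting from z = 0."""
    z = 0j
    values = []
    for _ in range(steps):
        z = z * z + c
        values.append(z)
    return values


# c = 1 escapes quickly: 1, 2, 5, 26, ... so it is not in the set.
print([abs(z) for z in orbit(1, 4)])   # [1.0, 2.0, 5.0, 26.0]

# c = -1 cycles between -1 and 0 forever, so it stays bounded and is in the set.
print([abs(z) for z in orbit(-1, 4)])  # [1.0, 0.0, 1.0, 0.0]
```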

Further below we will show a rendered image of the Mandelbrot set, but first, let’s write a simple Python function to calculate it.

Calculating the Mandelbrot set with no parallelism

Let’s start by writing a simple, naive function to calculate the Mandelbrot set. It uses no parallelism: two nested for loops walk over the rows and columns of the image and, using the CPU only, compute for each pixel how many iterations the corresponding point survives before diverging.

For a width and height of 500, this function computes \(250000\) pixels; for a width and height of 1000, it computes \(1000000\) pixels, and so on. The time taken for this function to run increases as \(O(n^2)\), where \(n\) is the width or height of the image, and each pixel costs up to max_iter iterations.

def compute_mandelbrot_iterations(width, height, max_iter):
    real_min, real_max = -2, 1
    imag_min, imag_max = -1.5, 1.5
    real_step = (real_max - real_min) / (width - 1)
    imag_step = (imag_max - imag_min) / (height - 1)

    # Initialize a 2D list to hold iteration counts.
    iter_counts = [[0 for _ in range(width)] for _ in range(height)]

    for j in range(height):
        imag = imag_min + j * imag_step
        for i in range(width):
            real = real_min + i * real_step
            c = complex(real, imag)
            z = 0j
            count = 0
            while count < max_iter:
                z = z * z + c
                # Check divergence: if |z|^2 > 4 then break.
                if (z.real * z.real + z.imag * z.imag) > 4:
                    break
                count += 1
            iter_counts[j][i] = count

    return iter_counts

Calculating the Mandelbrot set with parallelism

While the above function is simple, it is not the most efficient. Because it is basically a big matrix operation (an image is a matrix), we can parallelize it easily using a number of frameworks which offer matrix operations. Let’s investigate how we would do this using a number of libraries.

For this example, we will show how to achieve it using numpy, pytorch and Apple’s mlx. They all offer a similar API and can be used almost interchangeably. They offer a set of functionality which allows you to perform matrix operations on the CPU and, in the case of pytorch and mlx, on the GPU:

  • Vectorized Operations: They all let you perform elementwise operations on entire arrays/tensors without explicit loops, which boosts performance.
  • Broadcasting: NumPy, PyTorch, and MLX support broadcasting, allowing operations on arrays of different shapes — great for aligning matrices without manual reshaping.
  • Optimized Backends: Under the hood, they rely on highly optimized C/C++ libraries (like BLAS/LAPACK or Apple’s Accelerate framework for MLX) to perform computations quickly.
  • Multi-dimensional Data Handling: They all offer robust support for multi-dimensional arrays (or tensors), making them well-suited for tasks ranging from basic linear algebra to complex machine learning computations.
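
To make the first two bullets concrete, here is a small NumPy sketch (the variable names and sizes are illustrative). It builds a tiny complex grid by broadcasting a row of real parts against a column of imaginary parts, then updates every element without a loop.

```python
import numpy as np

real = np.linspace(-2, 1, 4)      # shape (4,): one value per column
imag = np.linspace(-1.5, 1.5, 3)  # shape (3,): one value per row

# Broadcasting: a (1, 4) row combined with a (3, 1) column gives a (3, 4) grid.
c = real[np.newaxis, :] + 1j * imag[:, np.newaxis]
print(c.shape)  # (3, 4)

# Vectorized operation: one expression updates all 12 elements at once.
z = c * c + c
print(z.shape)  # (3, 4)
```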

Numpy

Let’s start with an implementation of the compute_mandelbrot_iterations function using numpy, entirely with array (or matrix) operations.

import numpy as np


def compute_mandelbrot_numpy(
    width: int = 500, height: int = 500, max_iter: int = 30
) -> np.ndarray:
    # Create linearly spaced real and imaginary parts and generate a complex grid.
    real = np.linspace(-2, 1, width)
    imag = np.linspace(-1.5, 1.5, height)
    X, Y = np.meshgrid(real, imag, indexing="xy")
    c = X + 1j * Y

    # Initialize z and an array to hold the iteration counts
    z = np.zeros_like(c)
    iter_counts = np.zeros(c.shape, dtype=np.int32)

    for _ in range(max_iter):
        # Create a mask for points that have not yet escaped (|z| <= 2),
        # the same test as the pure Python loop above
        mask = np.abs(z) <= 2
        if not mask.any():
            break

        # Update z and iteration counts only where |z| <= 2
        z = np.where(mask, z * z + c, z)
        iter_counts = np.where(mask, iter_counts + 1, iter_counts)

    return iter_counts

This function starts by generating a grid of complex numbers using matrix operations, which is key for parallel computation. It first creates two linearly spaced arrays for the real and imaginary parts using np.linspace, and then builds two 2D grids with np.meshgrid: one for the real values and one for the imaginary values. These grids are combined into a single complex grid c (where each element is of the form x + 1j * y), and this process happens all at once without the need for explicit loops, leveraging NumPy’s vectorized operations.

Next, the code initializes two arrays of the same shape as c: one for the iterative values z (starting at zero) and one to keep track of the iteration counts. The main computation occurs in a loop where, in each iteration, the code computes the absolute value of every element in z simultaneously using np.abs(z) and creates a boolean mask that identifies the elements where |z| <= 2. This mask is then used to update z and iter_counts in one go via np.where, ensuring that only the elements that haven’t diverged (i.e., where the condition holds) are updated.

Because these operations (creating the grid, computing absolute values, applying the mask, and updating arrays) are all performed on entire arrays at once, they run as optimized, compiled C loops under the hood rather than in the Python interpreter. This eliminates the need for slow, explicit Python loops, which is why such an approach is efficient for intensive computations like generating the Mandelbrot set. The combination of vectorized operations and conditional updates not only makes the code concise but also lets the library make better use of the underlying hardware, resulting in faster computation.
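
The masked update can be traced by hand on just three points. This is a minimal sketch using the escape test |z| <= 2 from the definition of the set; the choice of points and of five iterations is arbitrary.

```python
import numpy as np

c = np.array([0, 1, -1], dtype=np.complex128)
z = np.zeros_like(c)
iter_counts = np.zeros(c.shape, dtype=np.int32)

for _ in range(5):
    mask = np.abs(z) <= 2             # True where the point has not escaped
    z = np.where(mask, z * z + c, z)  # escaped points keep their old value
    iter_counts = np.where(mask, iter_counts + 1, iter_counts)

print(iter_counts.tolist())  # [5, 3, 5]
```

The point c = 1 reaches 5 on its third update and is frozen from then on, while 0 and -1 stay bounded and are updated on every pass.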

PyTorch

Now let us do the same with the pytorch library. Except for the line c = X.t() + 1j * Y.t(), the code mirrors the numpy implementation. The t() function transposes the matrix, and the + and * operators are overloaded to perform elementwise addition and multiplication, respectively. This allows us to create the complex grid c in a single line, just like in the numpy version.

Tip

The transpose changes the dimension ordering of the grid. With indexing='xy', both np.meshgrid and torch.meshgrid return arrays whose first dimension corresponds to the y-axis and whose second corresponds to the x-axis, matching common image coordinate conventions. With indexing='ij', which is what torch.meshgrid uses when no indexing argument is given, the two dimensions are swapped. The code below passes indexing='xy' and then transposes with .t(), so the complex grid c ends up laid out as (width, height). That is why the results are transposed back with .T before plotting, so that each element lands on the intended coordinate in the complex plane.
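
The effect of the indexing argument is easy to check on a small grid. The sketch below uses NumPy rather than PyTorch, since np.meshgrid accepts the same 'xy' and 'ij' values; the array sizes are arbitrary.

```python
import numpy as np

x = np.arange(3)  # 3 points along the x-axis
y = np.arange(2)  # 2 points along the y-axis

X_xy, Y_xy = np.meshgrid(x, y, indexing="xy")
X_ij, Y_ij = np.meshgrid(x, y, indexing="ij")

print(X_xy.shape)  # (2, 3): rows follow y, columns follow x
print(X_ij.shape)  # (3, 2): rows follow x, columns follow y

# The two layouts are transposes of each other.
print(np.array_equal(X_xy, X_ij.T))  # True
```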

import torch


def compute_mandelbrot_torch(
    width: int = 500, height: int = 500, max_iter: int = 30, device: str = "cpu"
) -> torch.Tensor:
    real = torch.linspace(-2, 1, steps=width, device=device)
    imag = torch.linspace(-1.5, 1.5, steps=height, device=device)
    X, Y = torch.meshgrid(real, imag, indexing="xy")
    c = X.t() + 1j * Y.t()

    z = torch.zeros_like(c)
    iter_counts = torch.zeros(c.shape, device=device, dtype=torch.int32)

    for _ in range(max_iter):
        # Points that have not yet escaped (|z| <= 2, the same test as the pure Python loop)
        mask = torch.abs(z) <= 2
        if not mask.any():
            break

        z = torch.where(mask, z * z + c, z)
        iter_counts = torch.where(mask, iter_counts + 1, iter_counts)

    return iter_counts

Because we are using pytorch tensors, we can offload the workload onto a GPU by setting the device parameter to 'cuda' (NVIDIA GPUs) or 'mps' (Apple’s Metal backend). This tells pytorch to create the tensors on the GPU and run all subsequent operations on them there, which can significantly speed up the computation. The rest of the code remains the same, with the same vectorized operations and conditional updates as in the NumPy version.

The difference is that on a GPU the operations run concurrently across the GPU cores, which are optimized for parallel computation. For example, when running z = torch.where(mask, z * z + c, z) on a GPU, each element of z, mask, and c can be processed simultaneously by different cores, allowing for massive speedups compared to sequential execution on a CPU. Effectively, each iteration “paints” the whole Mandelbrot image at once rather than pixel by pixel.
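
A common pattern is to pick the device at runtime rather than hard-coding it. The helper below is a sketch with a made-up name (pick_device); it relies only on torch.cuda.is_available() and torch.backends.mps.is_available(), and falls back to 'cpu' when PyTorch itself is not installed.

```python
def pick_device() -> str:
    """Return "cuda", "mps" or "cpu", preferring a GPU when one is usable."""
    try:
        import torch
    except ImportError:
        return "cpu"  # no PyTorch at all, so nothing to offload to
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


device = pick_device()
print(device)
```

The returned string can be passed as the device argument of compute_mandelbrot_torch.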

Apple’s MLX

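Apple’s MLX offers an API which is very similar to those of PyTorch and NumPy. The main difference is that it is designed for and optimized on Apple Silicon. Note that the implementation below does not build a complex-valued grid: it keeps the real and imaginary parts in separate float32 arrays and writes out the complex arithmetic by hand.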

try:
    import mlx.core as mx
except ImportError:
    mx = None

if mx is not None:

    def compute_mandelbrot_mlx(
        width: int = 500, height: int = 500, max_iter: int = 30
    ) -> mx.array:
        real = mx.linspace(-2, 1, width).astype(mx.float32)
        imag = mx.linspace(-1.5, 1.5, height).astype(mx.float32)
        X, Y = mx.meshgrid(real, imag, indexing="xy")
        c_real = mx.transpose(X)
        c_imag = mx.transpose(Y)

        z_real = mx.zeros_like(c_real)
        z_imag = mx.zeros_like(c_imag)
        iter_counts = mx.zeros(c_real.shape, dtype=mx.int32)

        for _ in range(max_iter):
            mag_sq = mx.multiply(z_real, z_real) + mx.multiply(z_imag, z_imag)
            mask = mag_sq < 4.0
            if not mx.any(mask):
                break

            new_z_real = mx.where(
                mask, z_real * z_real - z_imag * z_imag + c_real, z_real
            )
            new_z_imag = mx.where(mask, 2 * z_real * z_imag + c_imag, z_imag)
            z_real = new_z_real
            z_imag = new_z_imag

            iter_counts = mx.where(mask, iter_counts + 1, iter_counts)

        return iter_counts

else:

    def compute_mandelbrot_mlx(*args, **kwargs):
        raise ImportError("MLX is not available on this system.")
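
The MLX version keeps the real and imaginary parts of z in separate arrays, so the update z = z*z + c is written out by hand using (a + bi)^2 = (a^2 - b^2) + 2abi. The scalar sketch below (step_split is a name invented for this example) checks that expansion against Python’s built-in complex arithmetic.

```python
def step_split(z_real, z_imag, c_real, c_imag):
    """One update z -> z*z + c with the real and imaginary parts kept apart."""
    new_real = z_real * z_real - z_imag * z_imag + c_real
    new_imag = 2 * z_real * z_imag + c_imag
    return new_real, new_imag


z, c = complex(0.5, -0.25), complex(-0.75, 0.1)
expected = z * z + c
got = step_split(z.real, z.imag, c.real, c.imag)
print(abs(complex(*got) - expected) < 1e-12)  # True
```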

Putting it all together

Let’s put all the above together and render the Mandelbrot set with each method. Each of the compute_mandelbrot_* functions returns a 2D grid of integers, where each integer is the iteration count for that pixel: how many iterations the point survived before diverging, capped at max_iter. We will then use matplotlib to render the image.

mps_available = torch.backends.mps.is_available()
cuda_available = torch.cuda.is_available()
mlx_available = mx is not None

width, height = 500, 500
max_iter = 30

iter_counts = {}

iter_counts["iterations"] = compute_mandelbrot_iterations(width, height, max_iter)
iter_counts["numpy"] = compute_mandelbrot_numpy(width, height, max_iter)
iter_counts["torch_cpu"] = compute_mandelbrot_torch(
    width, height, max_iter, "cpu"
).T.cpu()

if mps_available:
    iter_counts["mps"] = compute_mandelbrot_torch(
        width, height, max_iter, "mps"
    ).T.cpu()
if cuda_available:
    iter_counts["cuda"] = compute_mandelbrot_torch(
        width, height, max_iter, "cuda"
    ).T.cpu()
if mlx_available:
    iter_counts["mlx"] = compute_mandelbrot_mlx(width, height, max_iter).T

Plotting the set

Now let’s create a function to plot the Mandelbrot images in the iter_counts dictionary above so we can compare them visually. They should all look the same.

import math
import matplotlib.pyplot as plt


def plot_mandelbrot_grid(iter_counts_dict, titles: list = None):
    n = len(iter_counts_dict)
    iter_counts_list = list(iter_counts_dict.values())
    if titles is not None and len(titles) != n:
        raise ValueError("Number of titles must match the number of images.")
    if n == 0:
        print("No Mandelbrot sets to plot.")
        return
    n_cols = math.ceil(math.sqrt(n))
    n_rows = math.ceil(n / n_cols)
    # squeeze=False always returns a 2D array of axes, even for a single row or column.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 10), squeeze=False)
    axes = axes.flatten()
    for i, (key, iter_counts) in enumerate(iter_counts_dict.items()):
        ax = axes[i]
        ax.imshow(iter_counts, cmap="hot", extent=(-2, 1, -1.5, 1.5))
        ax.set_title(titles[i] if titles else key)
        ax.axis("off")
    for j in range(i + 1, len(axes)):
        axes[j].axis("off")
    plt.tight_layout()
    plt.show()
titles = ["Python (iterations)", "NumPy", "PyTorch (CPU)"]
if mps_available:
    titles.append("PyTorch (MPS)")
if cuda_available:
    titles.append("PyTorch (CUDA)")
if mlx_available:
    titles.append("Apple MLX")

plot_mandelbrot_grid(iter_counts, titles=titles)

The renders match, which suggests that each of the methods above computes the same set. The only difference is the speed at which they compute it!

Timing the functions

So that we can easily compare the speed of each of the above functions, let’s time them using the time module in Python at different resolutions (width and height). This will show the difference between methods which rely entirely on the CPU and those which offload the computation to the GPU. All timings are in seconds.

import time

resolutions = [1000, 2000, 3000, 4000]
max_iter = 1000

header = [
    "Resolution",
    "Iterations",
    "NumPy",
    "PyTorch/cpu",
    "PyTorch/mps" if mps_available else "MPS - N/A",
    "PyTorch/cuda" if cuda_available else "CUDA - N/A",
    "MLX" if mlx_available else "MLX - N/A",
]

table_data = []

for n in resolutions:
    width = height = n
    timings = {}

    start_time = time.time()
    compute_mandelbrot_iterations(width, height, max_iter)
    timings["Iterations"] = time.time() - start_time

    start_time = time.time()
    compute_mandelbrot_numpy(width, height, max_iter)
    timings["NumPy"] = time.time() - start_time

    start_time = time.time()
    compute_mandelbrot_torch(width, height, max_iter, "cpu")
    timings["PyTorch/cpu"] = time.time() - start_time

    if mps_available:
        start_time = time.time()
        compute_mandelbrot_torch(width, height, max_iter, "mps")
        timings["PyTorch/mps"] = time.time() - start_time
    else:
        timings["PyTorch/mps"] = None

    if cuda_available:
        start_time = time.time()
        compute_mandelbrot_torch(width, height, max_iter, "cuda")
        timings["PyTorch/cuda"] = time.time() - start_time
    else:
        timings["PyTorch/cuda"] = None

    if mlx_available:
        start_time = time.time()
        compute_mandelbrot_mlx(width, height, max_iter)
        timings["MLX"] = time.time() - start_time
    else:
        timings["MLX"] = None

    row = [
        f"{n}x{n}",
        f"{timings['Iterations']:.3f}",
        f"{timings['NumPy']:.3f}",
        f"{timings['PyTorch/cpu']:.3f}",
        (
            f"{timings['PyTorch/mps']:.3f}"
            if timings["PyTorch/mps"] is not None
            else "N/A"
        ),
        (
            f"{timings['PyTorch/cuda']:.3f}"
            if timings["PyTorch/cuda"] is not None
            else "N/A"
        ),
        f"{timings['MLX']:.3f}" if timings["MLX"] is not None else "N/A",
    ]
    table_data.append(row)

print("{:<12} {:<12} {:<12} {:<15} {:<15} {:<12} {:<12}".format(*header))
for row in table_data:
    print("{:<12} {:<12} {:<12} {:<15} {:<15} {:<12} {:<12}".format(*row))
Resolution   Iterations   NumPy        PyTorch/cpu     MPS - N/A       PyTorch/cuda MLX - N/A   
1000x1000    18.568       6.508        2.299           N/A             0.205        N/A         
2000x2000    74.228       27.851       15.054          N/A             0.652        N/A         
3000x3000    167.244      66.770       65.809          N/A             1.349        N/A         
4000x4000    297.467      114.647      109.472         N/A             2.338        N/A         
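
One caveat when timing GPU code: frameworks such as PyTorch queue CUDA work asynchronously, so a wall-clock timer can stop before the device has actually finished. In the functions above, the `if not mask.any()` check forces the CPU to wait for the GPU on every iteration, so the effect on this table is probably small. The helper below is a hedged sketch (timed and its sync parameter are names invented here) that uses time.perf_counter and accepts an optional synchronization callback; for CUDA that callback would be torch.cuda.synchronize.

```python
import time


def timed(fn, *args, sync=None, **kwargs):
    """Run fn once and return (result, elapsed seconds).

    sync is an optional zero-argument callable that blocks until queued
    device work is done, for example torch.cuda.synchronize on CUDA.
    """
    if sync is not None:
        sync()  # do not bill earlier, still-pending work to this call
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if sync is not None:
        sync()  # wait for the device before stopping the clock
    return result, time.perf_counter() - start


# CPU-only usage; no synchronization is needed here.
result, seconds = timed(sum, range(1_000_000))
print(result, f"{seconds:.4f}s")
```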

And finally let us put the above results in an intuitive visual representation, so we can see the difference in speed between the different methods.

import matplotlib.ticker as ticker
import numpy as np  # Make sure NumPy is imported

resolutions_numeric = [int(row[0].split("x")[0]) for row in table_data]
methods = header[1:]

num_res = len(resolutions_numeric)
num_methods = len(methods)
x = np.arange(num_res)
bar_width = 0.15

fig, ax = plt.subplots(figsize=(8, 6))

for i, method in enumerate(methods):
    times = []
    for row in table_data:
        value = row[i + 1]  # skip resolution column
        times.append(float(value) if value != "N/A" else 0)
    offset = (i - num_methods / 2) * bar_width + bar_width / 2
    ax.bar(x + offset, times, bar_width, label=method)

ax.set_xlabel("Resolution (pixels)")
ax.set_ylabel("Time (s)")
ax.set_xticks(x)
ax.set_xticklabels([f"{res}x{res}" for res in resolutions_numeric])

formatter = ticker.ScalarFormatter()
formatter.set_scientific(False)
ax.yaxis.set_major_formatter(formatter)

ax.legend()
plt.show()

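You can see that the GPU is significantly faster than the CPU-only methods because of its inherent parallelism: at 4000x4000, PyTorch on CUDA takes about 2.3 seconds, against roughly 110 to 115 seconds for NumPy and PyTorch on the CPU and almost 300 seconds for the pure Python loops. Broadly, the more cores that can work on independent pixels at once, the faster the computation will be. This is why GPUs are so good at parallel workloads, and why they are so well suited to machine learning and AI workloads: deep down, those are just matrix operations, which are embarrassingly parallel.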
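
To put numbers on that, the relative cost of each method can be computed directly from the timing table. The snippet below simply copies the 4000x4000 row; the dictionary name is illustrative.

```python
# Timings in seconds for the 4000x4000 run, copied from the timing table.
timings_4000 = {
    "Iterations": 297.467,
    "NumPy": 114.647,
    "PyTorch/cpu": 109.472,
    "PyTorch/cuda": 2.338,
}

baseline = timings_4000["PyTorch/cuda"]
for method, seconds in timings_4000.items():
    print(f"{method:<14} {seconds / baseline:6.1f}x the CUDA time")
```

On these figures, the pure Python loops take roughly 127 times as long as the CUDA run, and the two CPU array versions take a little under 50 times as long.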

Reuse

This work is licensed under CC BY (View License)