Fine-tuning an LLM with Apple’s MLX Framework

Fine-tuning pre-trained language models on Apple Silicon

HowTo
AI
Language Models
Published: December 11, 2024

Modern discrete GPUs come with their own dedicated memory, which is separate from the CPU’s memory. This means that when training large models, data has to be copied back and forth between the CPU’s memory and the GPU’s memory, which is slow and inefficient. It is particularly problematic when training large language models (LLMs), as the model weights and data can easily be too large to fit into the GPU’s memory.

With Apple Silicon, shared memory between the CPU and GPU opens up a lot of possibilities for machine learning, as the GPU can access the same memory pool as the CPU directly, with no copying. This is a huge advantage for training large models, as it removes the separate GPU RAM limitation, even if the GPU itself is not as powerful as a dedicated one.

flowchart TD
    classDef cpu fill:#b3d9ff,stroke:#333
    classDef gpu fill:#ffb3b3,stroke:#333
    classDef ne fill:#b3ffb3,stroke:#333
    classDef other fill:#ffffb3,stroke:#333
    classDef uma fill:#e6f2ff,stroke:#333
    classDef features fill:#f0f0f0,stroke:#333

    CPU("CPU Cores"):::cpu <--> UMA
    GPU("GPU Cores"):::gpu <--> UMA
    NE("Neural Engine"):::ne <--> UMA
    UMA(["Unified Memory Pool<br>(VRAM)"]):::uma

Apple also released the MLX framework, Apple’s take on PyTorch and NumPy, built to take full advantage of the Unified Memory Architecture (UMA) of Apple Silicon.

Here we will see how to fine-tune a pre-trained LLM using the MLX framework and the LoRA approach.

About LoRA

LoRA (Low-Rank Adaptation) is a technique for fine-tuning large machine learning models, like language models or image generators, without retraining the entire model from scratch. Instead of updating all the model’s parameters—which can be slow, expensive, and require massive computational resources—LoRA freezes the original model and adds small, trainable “adapters” to specific parts of the network (like the attention layers in a transformer model). These adapters are designed using low-rank matrices, which are pairs of smaller, simpler matrices that approximate how the original model’s weights would need to change for a new task.

The core idea is to avoid retraining a massive neural network with billions of parameters for every new task, such as adapting to a specialized domain or style. LoRA modifies only a tiny fraction of the model by training two smaller matrices for each targeted layer. These matrices work together to capture the most important adjustments needed for the task. The size of these matrices is controlled by a “rank” hyperparameter, which balances efficiency and accuracy. This approach reduces the number of trainable parameters by thousands of times, making fine-tuning feasible on hardware with limited resources.
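
To make this concrete, here is a minimal sketch in plain NumPy (with illustrative, assumed dimensions) of the low-rank update LoRA applies to a single frozen weight matrix:

import numpy as np

# Illustrative dimensions: one 1024x1024 weight matrix, LoRA rank 8.
d, k = 1024, 1024      # shape of the frozen weight matrix W
r, alpha = 8, 16       # LoRA rank and scaling factor

W = np.random.randn(d, k)           # frozen pre-trained weights (never updated)
A = np.random.randn(r, k) * 0.01    # trainable low-rank factor
B = np.zeros((d, r))                # trainable low-rank factor, zero-initialised
                                    # so that initially W_adapted == W

# Effective weights during fine-tuning, and what gets fused back at the end:
W_adapted = W + (alpha / r) * (B @ A)

print(f"Full matrix parameters: {W.size:,}")            # 1,048,576
print(f"LoRA adapter parameters: {A.size + B.size:,}")  # 16,384 (64x fewer)

Only A and B are trained; repeated across every targeted layer, this is where the dramatic reduction in trainable parameters comes from.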

Once trained, the adapter matrices can be merged back into the original model during inference, adding almost no computational overhead. This makes the adapted model as fast as the original during deployment. The benefits include significant memory and computational savings, flexibility in training multiple lightweight adapters for different tasks (e.g., coding, translation, or art styles), and performance that often matches full fine-tuning. By focusing on low-rank updates, LoRA efficiently captures critical task-specific adjustments without altering the bulk of the pre-trained model’s knowledge.

graph LR
    subgraph Input Layer
        A1((Input))
    end

    subgraph Hidden Layer 1
        B1(("Layer Parameters"))
    end

    subgraph Hidden Layer 2
        C1(("Layer Parameters"))
    end

    subgraph Output Layer
        D1((Output))
    end

    %% LoRA Additions (colored differently)
    L1(("LoRA Adapter")):::loraStyle
    L2(("LoRA Adapter")):::loraStyle

    %% Connections in Pre-trained Model
    A1 --> B1
    B1 --> C1
    C1 --> D1

    %% LoRA Connections (colored differently)
    L1 --> B1:::loraConnection
    L2 --> C1:::loraConnection

    %% Style Definitions
    classDef loraStyle fill:#f9d5e5,stroke:#c81d7a,stroke-width:2px,color:#000;
    classDef loraConnection stroke:#c81d7a,stroke-width:2px,stroke-dasharray:5 5;

A brief overview of fine-tuning

Fine-tuning a pre-trained language model is common practice. The idea is to take a pre-trained model, like Llama or Qwen, and train it on a specific dataset to adapt it to a specific task. This is often done by freezing most of the pre-trained model’s weights and training only a small number of parameters on the new dataset, although other approaches exist.

Overall, there are three main ways to fine-tune a pre-trained model:

  • Full fine-tuning: In this approach, all the weights of the pre-trained model are unfrozen, and the entire model is trained on the new dataset. This is the most computationally expensive approach, as every weight in the model is updated (starting from the pre-trained weights, not from scratch).
  • Layer-wise fine-tuning: Only a subset of the layers in the pre-trained model are unfrozen and trained on the new dataset. This is less computationally expensive than full fine-tuning, as only a portion of the model is trained.
  • Adapter-based fine-tuning: Small trainable “adapters” are added to specific parts of the pre-trained model, and only these adapters are trained on the new dataset. This is the least computationally expensive approach, as only a small number of parameters are trained (this is the LoRA approach).

Additionally, there are two main types of fine-tuning based on supervision:

  • Unsupervised fine-tuning: In this approach, the pre-trained model is fine-tuned on a new dataset without any labels. We give the model a new corpus of raw text, and it learns to generate text in the style of that corpus.
  • Supervised fine-tuning: The pre-trained model is fine-tuned on a new dataset with labels. That is, we offer the model a new corpus of text (“prompts”) with labels (the “output”), and the model learns to generate text that matches the intended labels.

MLX can handle any combination of the above.
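
For concreteness, here is a sketch of what individual training records can look like in each case, written as JSON objects from Python. The exact keys mlx_lm’s LoRA tooling accepts (such as messages, prompt/completion or text) can vary between versions, so treat these as illustrative and check the documentation for your install:

import json

# Supervised (chat-style) record: a prompt plus the desired response.
supervised = {
    "messages": [
        {"role": "user", "content": "When did Michael Jackson die?"},
        {"role": "assistant", "content": "Michael Jackson died on June 25, 2009."},
    ]
}

# Unsupervised record: raw text for the model to imitate.
unsupervised = {"text": "Michael Jackson died on June 25, 2009, at his home in Los Angeles."}

# Each line of a {train, valid, test}.jsonl file is one such JSON object.
print(json.dumps(supervised))
print(json.dumps(unsupervised))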

Starting with the MLX framework

To begin, we need to install the MLX framework on an Apple Silicon Mac. MLX is a Python library, so it can be installed in a variety of ways depending on your Python environment. For example, with Conda:

conda install -c conda-forge mlx mlx-lm

Or with pip:

pip install mlx mlx-lm

Once installed, you will have the core MLX library available, along with the mlx_lm command-line tools (such as mlx_lm.generate, mlx_lm.lora and mlx_lm.fuse), which can be used to download models, generate text, and run fine-tuning jobs.

MLX can directly download models from the Hugging Face model hub - just keep in mind that not all models are optimized for the MLX framework. You can find many MLX-optimized models on the hub, and there is an active community working on adding more to the list.

As an example, let’s generate some text using a very small Qwen model with just \(1/2\) billion parameters and 8-bit quantization:

mlx_lm.generate \
    --model lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit \
    --prompt 'When did Michael Jackson die?'

In my case, I use LMStudio to manage models, so I point at the model in a specific local location rather than letting mlx_lm download it from the Hugging Face model hub.

Show the code
!mlx_lm.generate \
    --model $HOME/.lmstudio/models/lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit \
    --prompt 'When did Michael Jackson die? Stick to facts.' \
    --max-tokens 256
==========
Michael Jackson died on August 13, 2016, at the age of 50. He was diagnosed with multiple health issues, including kidney failure, in 2009, and passed away due to complications from his treatment.
==========
Prompt: 39 tokens, 104.885 tokens-per-sec
Generation: 53 tokens, 221.708 tokens-per-sec
Peak memory: 0.572 GB
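
If you prefer to stay in Python rather than use the shell, mlx_lm also exposes a small Python API. A minimal sketch, assuming the load and generate helpers shipped with recent mlx_lm releases (exact signatures may differ in your version):

from mlx_lm import load, generate

# Download (or load from the local cache) the model and its tokenizer.
model, tokenizer = load("lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit")

# Note: the mlx_lm.generate CLI applies the model's chat template automatically;
# here the raw prompt string is passed as-is.
response = generate(
    model,
    tokenizer,
    prompt="When did Michael Jackson die? Stick to facts.",
    max_tokens=256,
)
print(response)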

Fine-tuning with MLX and LoRA

MLX removes the need to write custom Python code to fine-tune: it provides a set of commands which implement the whole fine-tuning pipeline out of the box. The toolset can also use datasets straight from the Hugging Face hub - this is exactly what we will do, as we are only illustrating the fine-tuning process with MLX. In most cases you will want to use your own dataset.

Note

In other articles we will cover how to perform fine-tuning with your own dataset using the Hugging Face PEFT library, rather than a prescribed tool such as MLX, Axolotl or Unsloth.

Supervised fine-tuning

Let’s start with supervised fine-tuning. We will use HuggingFaceH4/no_robots, a high-quality dataset designed to fine-tune LLMs so they follow instructions more precisely. It contains a set of prompts and the corresponding output text. It is split into train and test sets, but MLX requires a validation set as well, so we will first split the train set into train and validation sets.

Note

For the purposes of this exercise, we don’t need to worry about the specifics of the dataset, or whether the model improves or not - we are only interested in the process of fine-tuning.

Show the code
from datasets import load_dataset
import tqdm as notebook_tqdm

dataset = load_dataset("HuggingFaceH4/no_robots")

# Split train into train and validation
train = dataset["train"].train_test_split(test_size=0.15, seed=42)
dataset["train"] = train["train"]
dataset["validation"] = train["test"]

print(dataset)
DatasetDict({
    train: Dataset({
        features: ['prompt', 'prompt_id', 'messages', 'category'],
        num_rows: 8075
    })
    test: Dataset({
        features: ['prompt', 'prompt_id', 'messages', 'category'],
        num_rows: 500
    })
    validation: Dataset({
        features: ['prompt', 'prompt_id', 'messages', 'category'],
        num_rows: 1425
    })
})
Show the code
print(dataset["train"][0])
{'prompt': 'Pretend you are a dog. Send out a text to all your dog friends inviting them to the dog park. Specify that everyone should meet at 2pm today.', 'prompt_id': '4b474f9f59c64e8e32ad346051bb4f8d9b864110c2dda0d481e8f13898dc4511', 'messages': [{'content': 'Pretend you are a dog. Send out a text to all your dog friends inviting them to the dog park. Specify that everyone should meet at 2pm today.', 'role': 'user'}, {'content': "Hello, my dog friends!\n\nIt is such a beautiful day today! Does anyone want to go to the dog park to play catch and chase each other's tails with me? I will be there at 2 pm today. \n\nLet me know if you will be there! I'm looking forward to playing with you all!", 'role': 'assistant'}], 'category': 'Generation'}

Now let’s save the split dataset to disk as JSONL files.

Show the code
import json
import os

output_dir = "no_robots"
os.makedirs(output_dir, exist_ok=True)

# MLX expects splits named {train, valid, test}, so rename 'validation' to 'valid'
dataset["valid"] = dataset.pop("validation")

for split in ["train", "test", "valid"]:
    dataset[split].to_json(f"{output_dir}/{split}.jsonl", lines=True)

And finally let us run the fine-tuning process. For the training we will fine-tune \(16\) layers (num_layers: 16), use a batch size of \(6\) (batch_size: 6), train for \(1000\) iterations (iters: 1000), save the adapter weights every \(100\) iterations (save_every: 100), and enable gradient checkpointing (grad_checkpoint: true). You can pass these parameters directly to the mlx_lm.lora command as flags, but in our case we want to save them into a configuration yaml file.

Show the code
!cat no_robots-train-params.yaml
# The path to the local model directory or Hugging Face repo.
model: "/Users/NLeitao/.lmstudio/models/lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit"

# Whether or not to train (boolean)
train: true

# The fine-tuning method: "lora", "dora", or "full".
fine_tune_type: lora

# Directory with {train, valid, test}.jsonl files
data: "./no_robots"

# Number of layers to fine-tune
num_layers: 16

# Minibatch size.
batch_size: 6

# Iterations to train for.
iters: 1000

# Adam learning rate.
learning_rate: 1e-4

# Save/load path for the trained adapter weights.
adapter_path: "adapter"

# Save the model every N iterations.
save_every: 100

# Evaluate on the test set after training
test: true

# Maximum sequence length.
max_seq_length: 2048

# Use gradient checkpointing to reduce memory use.
grad_checkpoint: true
Show the code
!mlx_lm.lora \
    --config no_robots-train-params.yaml \
    --train \
    --test
Loading configuration file no_robots-train-params.yaml
Loading pretrained model
Traceback (most recent call last):
  File "/Volumes/Home/pedroleitao/miniconda3/envs/pedroleitao.nl/bin/mlx_lm.lora", line 10, in <module>
    sys.exit(main())
             ^^^^^^
  File "/Volumes/Home/pedroleitao/miniconda3/envs/pedroleitao.nl/lib/python3.11/site-packages/mlx_lm/lora.py", line 310, in main
    run(types.SimpleNamespace(**args))
  File "/Volumes/Home/pedroleitao/miniconda3/envs/pedroleitao.nl/lib/python3.11/site-packages/mlx_lm/lora.py", line 270, in run
    model, tokenizer = load(args.model)
                       ^^^^^^^^^^^^^^^^
  File "/Volumes/Home/pedroleitao/miniconda3/envs/pedroleitao.nl/lib/python3.11/site-packages/mlx_lm/utils.py", line 782, in load
    model_path = get_model_path(path_or_hf_repo)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/Home/pedroleitao/miniconda3/envs/pedroleitao.nl/lib/python3.11/site-packages/mlx_lm/utils.py", line 200, in get_model_path
    raise ModelNotFoundError(
mlx_lm.utils.ModelNotFoundError: Model not found for path or HF repo: /Users/NLeitao/.lmstudio/models/lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit.
Please make sure you specified the local path or Hugging Face repo id correctly.
If you are trying to access a private or gated Hugging Face repo, make sure you are authenticated:
https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login

Batch size is a big contributor to memory usage, so you may need to adjust it depending on your hardware.

About Gradient Checkpointing

Gradient checkpointing is a method that trades off extra computation for lower memory usage during deep learning training. Instead of storing all intermediate outputs needed for backpropagation, the network only checkpoints certain “key” layers. When gradients need to be computed, the forward pass for the missing parts is recomputed on the fly.

By doing this, the total memory consumption can be drastically reduced—especially for very large models—because you’re not hanging onto every intermediate result. The tradeoff is that you’ll pay with some extra compute time for re-running parts of the forward pass.

We just fine-tuned the model, and we can now see the adapter matrices in the adapter directory!

Show the code
!ls -lh adapter
total 8504
-rw-r--r--@ 1 pedroleitao  staff   1.4M Mar  2 16:06 0000100_adapters.safetensors
-rw-r--r--@ 1 pedroleitao  staff   1.4M Mar  2 16:06 0000200_adapters.safetensors
-rw-r--r--@ 1 pedroleitao  staff   761B Mar  2 16:06 adapter_config.json
-rw-r--r--@ 1 pedroleitao  staff   1.4M Mar  2 16:06 adapters.safetensors

Before we can use the fine-tuned model, we need to merge (or “fuse”) the adapter matrices from the fine-tuning training back into the original model. This can be done with the mlx_lm.fuse command.

Show the code
!mlx_lm.fuse \
    --model $HOME/.lmstudio/models/lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit \
    --adapter-path ./adapter \
    --save-path $HOME/.lmstudio/models/lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit-tuned
Loading pretrained model

And finally we can generate text using the fine-tuned model as before.

Show the code
!mlx_lm.generate \
    --model $HOME/.lmstudio/models/lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit-tuned \
    --prompt 'When did Michael Jackson die? Stick to facts.' \
    --max-tokens 256
==========
Michael Jackson died on January 15, 2016.
==========
Prompt: 39 tokens, 1051.475 tokens-per-sec
Generation: 16 tokens, 256.512 tokens-per-sec
Peak memory: 0.573 GB

We have just fine-tuned a pre-trained language model using the MLX framework! Note how before fine-tuning, instructing the model to “stick to facts” still produced a rambling, factually wrong answer, whereas after fine-tuning on the no_robots dataset the model responds concisely, more in line with the instruction - although with only half a billion parameters its factual recall is still poor, and the date it gives remains wrong.
