flowchart TD
    classDef cpu fill:#b3d9ff,stroke:#333
    classDef gpu fill:#ffb3b3,stroke:#333
    classDef ne fill:#b3ffb3,stroke:#333
    classDef other fill:#ffffb3,stroke:#333
    classDef uma fill:#e6f2ff,stroke:#333
    classDef features fill:#f0f0f0,stroke:#333
    CPU("CPU Cores"):::cpu <--> UMA
    GPU("GPU Cores"):::gpu <--> UMA
    NE("Neural Engine"):::ne <--> UMA
    UMA(["Unified Memory Pool<br>(VRAM)"]):::uma
Modern GPUs come with dedicated memory, which is separate from the CPU’s memory. This means that when training large models, data has to be copied from the CPU’s memory to the GPU’s memory, which can be slow and inefficient. It is particularly problematic when training large language models (LLMs), as the data and weights can be too large to fit into the GPU’s memory.
With Apple Silicon, shared memory between the CPU and GPU has opened up a lot of possibilities for machine learning, as the GPU can access the CPU’s memory directly. This is a huge advantage for training large models, as it removes the GPU RAM limitation, even if the GPU itself is not as powerful as a dedicated one.
Apple also released the MLX framework, its own take on PyTorch and NumPy, built to take full advantage of the Unified Memory Architecture (UMA) of Apple Silicon.
Here we will see how to fine-tune a pre-trained LLM with the MLX framework using the LoRA approach.
LoRA (Low-Rank Adaptation) is a technique for fine-tuning large machine learning models, like language models or image generators, without retraining the entire model from scratch. Instead of updating all the model’s parameters—which can be slow, expensive, and require massive computational resources—LoRA freezes the original model and adds small, trainable “adapters” to specific parts of the network (like the attention layers in a transformer model). These adapters are designed using low-rank matrices, which are pairs of smaller, simpler matrices that approximate how the original model’s weights would need to change for a new task.
The core idea is to avoid retraining a massive neural network with billions of parameters for every new task, such as adapting to a specialized domain or style. LoRA modifies only a tiny fraction of the model by training two smaller matrices for each targeted layer. These matrices work together to capture the most important adjustments needed for the task. The size of these matrices is controlled by a “rank” hyperparameter, which balances efficiency and accuracy. This approach reduces the number of trainable parameters by thousands of times, making fine-tuning feasible on hardware with limited resources.
Once trained, the adapter matrices can be merged back into the original model during inference, adding almost no computational overhead. This makes the adapted model as fast as the original during deployment. The benefits include significant memory and computational savings, flexibility in training multiple lightweight adapters for different tasks (e.g., coding, translation, or art styles), and performance that often matches full fine-tuning. By focusing on low-rank updates, LoRA efficiently captures critical task-specific adjustments without altering the bulk of the pre-trained model’s knowledge.
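In equations, the idea is simple: for a pre-trained weight matrix \(W_0\), LoRA learns the update as a product of two small matrices. This is the standard formulation from the LoRA paper, where \(r\) is the rank and \(\alpha\) a scaling hyperparameter:

\[
h = W_0 x + \frac{\alpha}{r} B A x, \qquad B \in \mathbb{R}^{d \times r},\quad A \in \mathbb{R}^{r \times k},\quad r \ll \min(d, k)
\]

Only \(A\) and \(B\) are trained; \(W_0\) stays frozen. For a \(4096 \times 4096\) layer and \(r = 8\), that is \(r(d + k) = 65{,}536\) trainable parameters instead of roughly \(16.8\) million for a full update of that layer.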
graph LR
    subgraph Input Layer
        A1((Input))
    end
    subgraph Hidden Layer 1
        B1(("Layer Parameters"))
    end
    subgraph Hidden Layer 2
        C1(("Layer Parameters"))
    end
    subgraph Output Layer
        D1((Output))
    end

    %% LoRA Additions (colored differently)
    L1(("LoRA Adapter")):::loraStyle
    L2(("LoRA Adapter")):::loraStyle

    %% Connections in Pre-trained Model
    A1 --> B1
    B1 --> C1
    C1 --> D1

    %% LoRA Connections (colored differently)
    L1 --> B1:::loraConnection
    L2 --> C1:::loraConnection

    %% Style Definitions
    classDef loraStyle fill:#f9d5e5,stroke:#c81d7a,stroke-width:2px,color:#000;
    classDef loraConnection stroke:#c81d7a,stroke-width:2px,stroke-dasharray:5 5;
A brief overview of fine-tuning
Fine-tuning a pre-trained language model is common practice. The idea is to take a pre-trained model, like Llama or Qwen, and train it on a specific dataset to adapt it to a particular task. This is often done by freezing the weights of the pre-trained model and adding a small number of trainable parameters, which are trained on the new dataset.
Overall, there are three main ways to fine-tune a pre-trained model:
- Full fine-tuning: In this approach, all the weights of the pre-trained model are unfrozen, and the entire model is trained on the new dataset. This is the most computationally expensive approach, as every parameter in the model is updated.
- Layer-wise fine-tuning: Only a subset of the layers in the pre-trained model are unfrozen and trained on the new dataset. This is less computationally expensive than full fine-tuning, as only a portion of the model is trained.
- Adapter-based fine-tuning: Small trainable “adapters” are added to specific parts of the pre-trained model, and only these adapters are trained on the new dataset. This is the least computationally expensive approach, as only a small number of parameters are trained - this is the LoRA approach (see the sketch after this list).
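To make the adapter idea concrete, here is a minimal sketch of a LoRA-style linear layer in plain NumPy. The class name, initialization constants, and shapes are illustrative assumptions, not MLX’s actual implementation:

import numpy as np

class LoRALinear:
    """A frozen linear layer plus a small trainable low-rank adapter."""

    def __init__(self, W, rank=8, alpha=16.0):
        self.W = W                                     # pre-trained weights, kept frozen
        out_dim, in_dim = W.shape
        self.A = np.random.randn(rank, in_dim) * 0.01  # trainable low-rank factor
        self.B = np.zeros((out_dim, rank))             # trainable, zero-initialized
        self.scale = alpha / rank

    def __call__(self, x):
        # Frozen projection plus the scaled low-rank correction; with B = 0 the
        # adapter is a no-op, so training starts from the original model.
        return self.W @ x + self.scale * (self.B @ (self.A @ x))

# Usage: wrap a hypothetical 512x512 pre-trained weight matrix
layer = LoRALinear(np.random.randn(512, 512))
y = layer(np.random.randn(512))

During fine-tuning only A and B would receive gradient updates, which is why the number of trainable parameters stays so small.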
Additionally, there are two main types of fine-tuning based on supervision:
- Unsupervised fine-tuning: In this approach, the pre-trained model is fine-tuned on a new dataset without any labels. In other words, we offer the model a new corpus of text, and it learns to generate text in the style of the new corpus.
- Supervised fine-tuning: The pre-trained model is fine-tuned on a new dataset with labels. That is, we offer the model a new corpus of text (“prompts”) with labels (the “output”), and the model learns to generate text that matches the intended labels.
MLX can handle any combination of the above.
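To make this concrete, here is a sketch of what training records look like in practice. At the time of writing, mlx_lm’s LoRA tooling accepts JSONL datasets in a chat format ({"messages": ...}), a prompt/completion format, or a plain-text format ({"text": ...}); check the mlx-lm documentation for the current list. A supervised record pairs a prompt with the desired output, while an unsupervised record is just raw text:

import json

# Supervised (chat format): the model learns to produce the assistant turn.
chat_record = {
    "messages": [
        {"role": "user", "content": "Summarize the plot of Hamlet in one sentence."},
        {"role": "assistant", "content": "A Danish prince avenges his father's murder, at great cost."},
    ]
}

# Unsupervised (text format): the model simply learns to continue the corpus.
text_record = {"text": "Hamlet is a tragedy written by William Shakespeare around 1600."}

# Each .jsonl file must stick to a single format, one JSON object per line.
with open("example.jsonl", "w") as f:
    f.write(json.dumps(chat_record) + "\n")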
Starting with the MLX framework
To begin, we need to install the MLX framework on an Apple Silicon Mac. MLX is a Python library, so it can be installed in a variety of ways depending on your Python environment. For example, with Conda:
conda install -c conda-forge mlx mlx-lm
Or with pip:
pip install mlx mlx-lm
Once installed, you will have the core MLX library plus the mlx_lm command-line tools (such as mlx_lm.generate, mlx_lm.lora, and mlx_lm.fuse), which can be used to download models, generate text, and run fine-tuning jobs.
MLX can download models directly from the Hugging Face hub - just keep in mind that not all models are optimized for the MLX framework. You can find many MLX-optimized models, and there is an active community working on adding more.
As an example, let’s generate some text using a very small Qwen model with just \(1/2\) billion parameters and 8-bit quantization:
mlx_lm.generate \
--model lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit \
--prompt 'When did Michael Jackson die?'
In my case, I use LM Studio to manage models, so I point at the model’s local path rather than letting mlx_lm download it from the Hugging Face hub.
!mlx_lm.generate \
--model $HOME/.lmstudio/models/lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit \
--prompt 'When did Michael Jackson die? Stick to facts.' \
--max-tokens 256
==========
Michael Jackson died on August 13, 2016, at the age of 50. He was diagnosed with multiple health issues, including kidney failure, in 2009, and passed away due to complications from his treatment.
==========
Prompt: 39 tokens, 104.885 tokens-per-sec
Generation: 53 tokens, 221.708 tokens-per-sec
Peak memory: 0.572 GB
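The same can be done from Python. Below is a minimal sketch following the pattern in the mlx-lm README; the exact load and generate signatures can vary between mlx-lm versions, so treat it as illustrative rather than definitive:

from mlx_lm import load, generate

# Load (or download) the model and tokenizer from the Hugging Face hub
model, tokenizer = load("lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit")

# Wrap the question in the model's chat template
messages = [{"role": "user", "content": "When did Michael Jackson die?"}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# Generate a completion
text = generate(model, tokenizer, prompt=prompt, max_tokens=256, verbose=True)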
Fine-tuning with MLX and LoRA
MLX removes the need to write custom Python code for fine-tuning: it provides a set of commands that implement the whole fine-tuning pipeline out of the box. The toolset can also pull datasets from the Hugging Face hub, which is exactly what we will do here, since we are only illustrating the process. In most cases you will want to use your own dataset.
Supervised fine-tuning
Let’s start with supervised fine-tuning. We will use HuggingFaceH4/no_robots, a high-quality dataset designed to fine-tune LLMs so they follow instructions more precisely. It contains a set of prompts and the corresponding output text. It is split into train and test sets, but MLX requires a validation set as well, so we will first split the train set into train and validation sets.
For the purposes of this exercise, we don’t need to worry about the specifics of the dataset, or whether the model improves or not - we are only interested in the process of fine-tuning.
from datasets import load_dataset
import tqdm as notebook_tqdm

dataset = load_dataset("HuggingFaceH4/no_robots")

# Split train into train and validation
train = dataset["train"].train_test_split(test_size=0.15, seed=42)
dataset["train"] = train["train"]
dataset["validation"] = train["test"]

print(dataset)
DatasetDict({
train: Dataset({
features: ['prompt', 'prompt_id', 'messages', 'category'],
num_rows: 8075
})
test: Dataset({
features: ['prompt', 'prompt_id', 'messages', 'category'],
num_rows: 500
})
validation: Dataset({
features: ['prompt', 'prompt_id', 'messages', 'category'],
num_rows: 1425
})
})
print(dataset["train"][0])
{'prompt': 'Pretend you are a dog. Send out a text to all your dog friends inviting them to the dog park. Specify that everyone should meet at 2pm today.', 'prompt_id': '4b474f9f59c64e8e32ad346051bb4f8d9b864110c2dda0d481e8f13898dc4511', 'messages': [{'content': 'Pretend you are a dog. Send out a text to all your dog friends inviting them to the dog park. Specify that everyone should meet at 2pm today.', 'role': 'user'}, {'content': "Hello, my dog friends!\n\nIt is such a beautiful day today! Does anyone want to go to the dog park to play catch and chase each other's tails with me? I will be there at 2 pm today. \n\nLet me know if you will be there! I'm looking forward to playing with you all!", 'role': 'assistant'}], 'category': 'Generation'}
Now let’s save the split dataset to JSONL files, one per split.
import json
import os

output_dir = "no_robots"
os.makedirs(output_dir, exist_ok=True)

# Rename 'validation' to 'valid', as MLX expects {train, valid, test}.jsonl
dataset["valid"] = dataset.pop("validation")

for split in ["train", "test", "valid"]:
    dataset[split].to_json(f"{output_dir}/{split}.jsonl", lines=True)
And finally let us run the fine-tuning process. For the training we will set the number of layers to fine-tune to \(16\) (num_layers: 16), the batch size to \(6\) (batch_size: 6), and the number of iterations to \(1000\) (iters: 1000); we will also save the adapter weights every \(100\) iterations (save_every: 100) and enable gradient checkpointing (grad_checkpoint: true). You could pass these parameters directly to the mlx_lm.lora command as flags, but in our case we will keep them in a yaml configuration file.
!cat no_robots-train-params.yaml
# The path to the local model directory or Hugging Face repo.
model: "/Users/NLeitao/.lmstudio/models/lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit"
# Whether or not to train (boolean)
train: true
# The fine-tuning method: "lora", "dora", or "full".
fine_tune_type: lora
# Directory with {train, valid, test}.jsonl files
data: "./no_robots"
# Number of layers to fine-tune
num_layers: 16
# Minibatch size.
batch_size: 6
# Iterations to train for.
iters: 1000
# Adam learning rate.
learning_rate: 1e-4
# Save/load path for the trained adapter weights.
adapter_path: "adapter"
# Save the model every N iterations.
save_every: 100
# Evaluate on the test set after training
test: true
# Maximum sequence length.
max_seq_length: 2048
# Use gradient checkpointing to reduce memory use.
grad_checkpoint: true
!mlx_lm.lora \
--config no_robots-train-params.yaml \
--train \
--test
Loading configuration file no_robots-train-params.yaml
Loading pretrained model
Batch size is a big contributor to memory usage, so you may need to adjust it depending on your hardware.
Gradient checkpointing is a method that trades off extra computation for lower memory usage during deep learning training. Instead of storing all intermediate outputs needed for backpropagation, the network only checkpoints certain “key” layers. When gradients need to be computed, the forward pass for the missing parts is recomputed on the fly.
By doing this, the total memory consumption can be drastically reduced—especially for very large models—because you’re not hanging onto every intermediate result. The tradeoff is that you’ll pay with some extra compute time for re-running parts of the forward pass.
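As a toy illustration of the idea, here is a conceptual sketch in plain Python (not how MLX implements it internally): backpropagation through a chain of scalar functions that stores only every third activation and recomputes the rest on demand.

import math

# A toy chain of 12 scalar "layers" with known derivatives.
fns  = [math.sin, math.tanh, math.exp] * 4
dfns = [math.cos,                          # d/dx sin(x)  = cos(x)
        lambda x: 1 - math.tanh(x) ** 2,   # d/dx tanh(x) = 1 - tanh(x)^2
        math.exp] * 4                      # d/dx exp(x)  = exp(x)

def grad_with_checkpoints(x0, every=3):
    # Forward pass: store only every `every`-th activation (5 values here
    # instead of all 13).
    checkpoints = {0: x0}
    x = x0
    for i, f in enumerate(fns):
        x = f(x)
        if (i + 1) % every == 0:
            checkpoints[i + 1] = x
    # Backward pass: recompute missing activations from the nearest checkpoint.
    g = 1.0
    for i in reversed(range(len(fns))):
        start = (i // every) * every       # nearest stored activation at or before i
        x = checkpoints[start]
        for j in range(start, i):          # redo a short stretch of the forward pass
            x = fns[j](x)
        g *= dfns[i](x)                    # chain rule: multiply the local derivative
    return g

print(grad_with_checkpoints(0.5))          # matches full backprop, with less stored state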
We just fine-tuned the model, and we can now see the adapter matrices in the adapter directory!
!ls -lh adapter
total 8504
-rw-r--r--@ 1 pedroleitao staff 1.4M Mar 2 16:06 0000100_adapters.safetensors
-rw-r--r--@ 1 pedroleitao staff 1.4M Mar 2 16:06 0000200_adapters.safetensors
-rw-r--r--@ 1 pedroleitao staff 761B Mar 2 16:06 adapter_config.json
-rw-r--r--@ 1 pedroleitao staff 1.4M Mar 2 16:06 adapters.safetensors
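Before fusing, you can sanity-check the adapter directly: mlx_lm.generate accepts an adapter path, which applies the LoRA weights on top of the base model at generation time (flag name per the mlx-lm documentation at the time of writing, so double-check it against your installed version):

!mlx_lm.generate \
    --model $HOME/.lmstudio/models/lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit \
    --adapter-path ./adapter \
    --prompt 'When did Michael Jackson die? Stick to facts.'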
To get a standalone fine-tuned model, we need to merge (or “fuse”) the adapter matrices from the fine-tuning back into the original weights. This can be done with the mlx_lm.fuse command.
!mlx_lm.fuse \
--model $HOME/.lmstudio/models/lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit \
--adapter-path ./adapter \
--save-path $HOME/.lmstudio/models/lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit-tuned
Loading pretrained model
And finally we can generate text using the fine-tuned model as before.
!mlx_lm.generate \
--model $HOME/.lmstudio/models/lmstudio-community/Qwen2.5-0.5B-Instruct-MLX-8bit-tuned \
--prompt 'When did Michael Jackson die? Stick to facts.' \
--max-tokens 256
==========
Michael Jackson died on January 15, 2016.
==========
Prompt: 39 tokens, 1051.475 tokens-per-sec
Generation: 16 tokens, 256.512 tokens-per-sec
Peak memory: 0.573 GB
We have just fine-tuned a pre-trained language model using the MLX framework! Note how before fine-tuning, instructing the model to “stick to facts” did not produce the desired output, but after fine-tuning on the no_robots dataset the model generates text that is much more in line with the instruction (albeit the date is clearly still wrong).