Build Your Own LLM with LLM Fine-Tuning on macOS Using MLX
Privacy-preserving LLM
1. Background
In my previous post series, I discussed building RAG applications using tools such as LlamaIndex, LangChain, GPT4All, Ollama, etc. to leverage LLMs for specific use cases. In this post, I’ll explore another method known as LLM fine-tuning. I have fine-tuned Meta’s LLaMA-3 and Mistral LLMs on macOS using a tool called MLX, an array framework tailored for machine learning research on Apple silicon. This fine-tuning was accomplished using a technique called Low-Rank Adaptation (LoRA). Subsequently, these fine-tuned LLMs were run with Ollama. All source code related to this post has been published on GitLab. Please clone the repo to continue with the post.
2. LLM Fine-tuning
Fine-tuning is the process of adjusting the parameters or weights of a pre-trained large language model (LLM) to specialize it for a specific task or domain. While pre-trained language models like GPT have extensive general language knowledge, they often lack expertise in specialized areas. Fine-tuning overcomes this by training the model on domain-specific data, enhancing its accuracy and effectiveness for targeted applications. This process involves exposing the model to task-specific examples, enabling it to grasp the nuances of the domain more deeply. This crucial step transforms a general-purpose language model into a specialized tool, thereby unlocking the full potential of LLMs for specific domains or applications. However, fine-tuning LLMs requires substantial computational resources, such as GPUs, to ensure efficient training.
There are various LLM fine-tuning techniques and tools available. These include Low-Rank Adaptation (LoRA), Quantized LoRA (QLoRA), the broader family of Parameter-Efficient Fine-Tuning (PEFT) methods, and training-efficiency tools such as DeepSpeed and its ZeRO optimizer. For more information about LLM fine-tuning, read more here. In this post, I will discuss fine-tuning LLMs using the LoRA technique within the Apple MLX framework. First introduced by a team of Microsoft researchers in 2021, LoRA offers a parameter-efficient approach to fine-tuning. Traditional methods fine-tune the entire base model, which can be extensive and costly. LoRA instead adds a small number of trainable parameters while keeping the original model parameters frozen.
The essence of LoRA lies in adding small trainable adapters to the model, which boosts its efficiency and adaptability. Rather than incorporating entirely new layers, LoRA modifies the behavior of existing layers by introducing low-rank matrices. For a frozen weight matrix W, it learns an update expressed as the product of two small matrices, B and A. Their inner dimension (the rank r) is much smaller than the dimensions of W. This approach introduces minimal additional parameters, thereby significantly reducing computational overhead and memory usage compared to full model retraining. By focusing adaptations on specific model components, LoRA preserves the foundational knowledge embedded in the original weights, reducing the risk of catastrophic forgetting. This targeted adaptation maintains the model’s general capabilities and enables quick iterations and task-specific enhancements. That makes LoRA a flexible and scalable solution for fine-tuning large pre-trained models. For more information about LoRA and examples, read more here.
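The low-rank idea can be made concrete with a minimal NumPy sketch of a LoRA-wrapped linear layer. This is an illustration only, not MLX’s implementation. The class name `LoRALinear`, the rank of 8, the alpha of 16 and the 512x512 layer size are all hypothetical choices. The sketch shows two things. First, the frozen weight W is never touched. Second, because B starts at zero, the adapted layer initially behaves exactly like the base layer.

```python
import numpy as np

class LoRALinear:
    """Illustrative LoRA-wrapped linear layer (hypothetical class, not MLX's implementation)."""

    def __init__(self, W, rank=8, alpha=16.0, seed=0):
        d_out, d_in = W.shape
        rng = np.random.default_rng(seed)
        self.W = W                                          # frozen pre-trained weight (d_out x d_in)
        self.A = rng.normal(0.0, 0.01, size=(rank, d_in))   # trainable, small random init
        self.B = np.zeros((d_out, rank))                    # trainable, zero init
        self.scale = alpha / rank                           # LoRA scaling factor alpha / r

    def __call__(self, x):
        # Frozen base path plus the scaled low-rank update B @ A; W is never modified.
        return self.W @ x + self.scale * (self.B @ (self.A @ x))

    def trainable_params(self):
        # Only A and B are trained: r * d_in + d_out * r values.
        return self.A.size + self.B.size

rng = np.random.default_rng(1)
W = rng.normal(size=(512, 512))
layer = LoRALinear(W, rank=8)
x = rng.normal(size=512)

# Because B starts at zero, the adapted layer initially behaves exactly like the base layer.
print(np.allclose(layer(x), W @ x))      # True
print(layer.trainable_params(), W.size)  # 8192 262144
```

Even in this toy layer, the adapter trains about 3% of the values a full update of W would touch. For real model dimensions the ratio is far smaller.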
3. RAG vs LLM Fine-Tuning
RAG (Retrieval-Augmented Generation) enhances an LLM by providing access to a curated database, allowing it to dynamically retrieve relevant information to generate responses. In contrast, fine-tuning involves adjusting the model’s parameters by training it on a specific, labeled dataset to improve its performance on particular tasks. Fine-tuning modifies the model itself, while RAG expands the data the model can access.
Use RAG when you need to supplement your language model’s prompt with data that wasn’t available at the time of its initial training. This can include real-time data, user-specific data, or contextual information pertinent to the prompt. RAG is ideal for ensuring the model has access to the most current and relevant data. Fine-tuning, on the other hand, is optimal for training the model to understand and perform specific tasks with greater accuracy. To learn more about the use cases and how to choose between RAG and fine-tuning, read this post.
4. LLM Fine-Tuning with Apple MLX
It was long assumed that serious ML training and inference required Nvidia GPUs. This perspective has changed with the release of the ML framework MLX, which enables ML training and inference on Apple silicon CPUs and GPUs. The MLX library, developed by Apple, is akin to TensorFlow and PyTorch and supports GPU-backed computation. It allows LLMs to be fine-tuned on Apple silicon (M-series) chips, and its mlx-lm package supports the LoRA method for LLM fine-tuning. I have successfully fine-tuned several LLMs, including Llama-3 and Mistral, using MLX with LoRA.
4.1. Use Case
In this post, I’m going to discuss how to use MLX to fine-tune the mistralai/Mistral-7B-Instruct-v0.2 LLM with LoRA for the specific task of Text-to-SQL. The goal is to generate SQL queries from a user’s natural-language prompt. I have used the gretelai/synthetic_text_to_sql dataset, a rich collection of high-quality synthetic Text-to-SQL samples. Each sample pairs a natural-language question and a table schema with the corresponding SQL query. Using this dataset, I have trained the Mistral-7B LLM to generate SQL queries from user input (Text-to-SQL).
4.2 Set Up MLX and Other Tools
First and foremost, I need to install MLX along with a set of required tools. Below is a list of the tools I have installed, and how I have set up and configured the MLX environment.
# clone the mlxo repository
❯❯ git clone https://gitlab.com/rahasak-labs/mlxo.git
❯❯ cd mlxo
# create and activate a virtual environment
❯❯ python -m venv .venv
❯❯ source .venv/bin/activate
# install mlx-lm (installs mlx as a dependency)
❯❯ pip install -U mlx-lm
# install other required python packages
❯❯ pip install pandas
❯❯ pip install pyarrow
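After installing, an optional quick sanity check is to confirm that you are on an Apple silicon Mac and that the packages resolved. Below is a small sketch using only the standard library. The helper name `check_environment` is hypothetical. The check assumes that Apple silicon reports itself as `Darwin`/`arm64`, which is what native (non-Rosetta) Python on M-series Macs shows.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

def check_environment(packages=("mlx", "mlx-lm", "pandas", "pyarrow")):
    """Report whether this looks like Apple silicon and which package versions are installed."""
    report = {
        # MLX targets Apple silicon; native Python there reports Darwin / arm64.
        "apple_silicon": platform.system() == "Darwin" and platform.machine() == "arm64",
    }
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None  # package not installed in this environment
    return report

for key, value in check_environment().items():
    print(f"{key}: {value}")
```

A `None` next to a package means it is missing from the active virtual environment, usually because `.venv` was not activated.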
4.3 Set Up Huggingface-CLI
I’m sourcing the LLMs (base models) and datasets from Hugging Face. To do this, I need to set up an account on Hugging Face and configure the huggingface-cli command-line tool.
# set up an account on hugging face here
https://huggingface.co/welcome
# create an access token to read/write data from hugging face through the cli
# this token is required when logging in to huggingface-cli
https://huggingface.co/settings/tokens
# set up huggingface-cli
❯❯ pip install huggingface_hub
❯❯ pip install "huggingface_hub[cli]"
# log in to hugging face through the cli
# it will ask for the access token created previously
❯❯ huggingface-cli login
A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.
Setting a new token will erase the existing one.
To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible):
Add token as git credential? (Y/n) Y
Token is valid (permission: read).
Your token has been saved in your configured git credential helpers (osxkeychain).
Your token has been saved to /Users/lambda.eranga/.cache/huggingface/token
Login successful
# once logged in, the token is saved in ~/.cache/huggingface
❯❯ ls ~/.cache/huggingface
datasets
hub
token
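On machines where an interactive login is inconvenient (CI runners, for example), huggingface_hub can also read the token from the `HF_TOKEN` environment variable. The sketch below approximates the lookup order. It is a simplification: the hypothetical `resolve_hf_token` helper only checks `HF_TOKEN` and then the `token` file under `HF_HOME`. It defaults to `~/.cache/huggingface` and ignores other overrides the real library supports.

```python
import os
from pathlib import Path

def resolve_hf_token():
    """Simplified sketch of where a Hugging Face token is typically found.

    Assumed order: HF_TOKEN environment variable, then the `token` file under
    HF_HOME (default ~/.cache/huggingface). Returns (token, source) or (None, None).
    """
    env_token = os.environ.get("HF_TOKEN")
    if env_token:
        return env_token, "HF_TOKEN environment variable"
    hf_home = Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
    token_file = hf_home / "token"
    if token_file.is_file():
        return token_file.read_text().strip(), str(token_file)
    return None, None
```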
4.4 Prepare Data Set
MLX requires data to be in specific formats. Three main formats are discussed within MLX: chat, completion, and text. You can read more about these data formats here. For this use case, I will be using the completion format, which follows the prompt and completion structure shown below. This means I need to generate a dataset with both prompt and completion fields. Dataset generation plays a crucial role in fine-tuning an LLM, as it directly affects the accuracy of the fine-tuned model. Various techniques can be employed to generate datasets for fine-tuning LLMs. For example, this post discusses using an LLM with prompt engineering to generate a dataset.
{
"prompt": "What is the capital of France?",
"completion": "Paris."
}
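For comparison, the sketch below expresses one question/answer pair in each of the three record shapes. The key names (`messages`, `prompt`/`completion`, `text`) follow the mlx-lm LoRA documentation as I understand it. The helper name `to_mlx_formats` is hypothetical, and so is the way the text variant joins question and answer, which is purely illustrative.

```python
import json

def to_mlx_formats(question, answer):
    """Express one Q/A pair in the chat, completion and text JSONL record shapes."""
    return {
        # chat: a list of role/content messages
        "chat": {"messages": [
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer},
        ]},
        # completion: separate prompt and completion fields (the format used in this post)
        "completion": {"prompt": question, "completion": answer},
        # text: a single raw text field; how to join the parts is your choice (illustrative here)
        "text": {"text": f"{question} {answer}"},
    }

for name, record in to_mlx_formats("What is the capital of France?", "Paris.").items():
    print(name, json.dumps(record))
```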
The original dataset on Hugging Face is provided as .parquet files, with records structured as follows.
{
"id": 39325,
"domain": "public health",
"domain_description": "Community health statistics, infectious disease tracking data, healthcare access metrics, and public health policy analysis.",
"sql_complexity": "aggregation",
"sql_complexity_description": "aggregation functions (COUNT, SUM, AVG, MIN, MAX, etc.), and HAVING clause",
"sql_task_type": "analytics and reporting",
"sql_task_type_description": "generating reports, dashboards, and analytical insights",
"sql_prompt": "What is the total number of hospital beds in each state?",
"sql_context": "CREATE TABLE Beds (State VARCHAR(50), Beds INT); INSERT INTO Beds (State, Beds) VALUES ('California', 100000), ('Texas', 85000), ('New York', 70000);",
"sql": "SELECT State, SUM(Beds) FROM Beds GROUP BY State;",
"sql_explanation": "This query calculates the total number of hospital beds in each state in the Beds table. It does this by using the SUM function on the Beds column and grouping the results by the State column."
}
I have converted the dataset into the completion format using the sql_prompt, sql_context, and sql fields. I combined sql_prompt and sql_context into a single prompt field and used the sql field as the completion. Additionally, MLX expects three dataset splits: train, test, and valid. Since the Hugging Face dataset ships only train and test files, the script splits the test data into test (the first two-thirds) and valid (the remaining third). The data files should be in JSONL format. Below is the script to convert the data from Hugging Face into this format.
import pandas as pd


def prepare_train():
    df = pd.read_parquet('train.parquet')
    # combine the question and its schema into a single prompt
    df['prompt'] = df['sql_prompt'] + " with given SQL schema " + df['sql_context']
    df.rename(columns={'sql': 'completion'}, inplace=True)
    df = df[['prompt', 'completion']]
    print(df.head(10))
    # Convert the DataFrame to JSON Lines (one record per line) and save as .jsonl
    df.to_json('train.jsonl', orient='records', lines=True)


def prepare_test_valid():
    df = pd.read_parquet('test.parquet')
    df['prompt'] = df['sql_prompt'] + " with given SQL schema " + df['sql_context']
    df.rename(columns={'sql': 'completion'}, inplace=True)
    df = df[['prompt', 'completion']]
    # Calculate the split index for two-thirds
    split_index = int(len(df) * 2 / 3)
    # Split the DataFrame: first two-thirds for test, remaining third for valid
    test_df = df[:split_index]
    valid_df = df[split_index:]
    print(test_df.head(10))
    print(valid_df.head(10))
    # Save the subsets to their respective JSONL files
    test_df.to_json('test.jsonl', orient='records', lines=True)
    valid_df.to_json('valid.jsonl', orient='records', lines=True)


if __name__ == "__main__":
    prepare_train()
    prepare_test_valid()
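Before training, it can save time to confirm that every line in the generated files is valid JSON with non-empty `prompt` and `completion` fields. A malformed record otherwise only surfaces once the training run loads the data. Below is a minimal standard-library sketch; the function name `validate_jsonl` is hypothetical.

```python
import json
from pathlib import Path

def validate_jsonl(path, required_keys=("prompt", "completion")):
    """Check that each line of a JSONL file is a JSON object with non-empty required keys.

    Returns (number_of_valid_records, list_of_error_messages).
    """
    valid, errors = 0, []
    with Path(path).open(encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                errors.append(f"{path}:{lineno}: invalid JSON ({exc.msg})")
                continue
            if not isinstance(record, dict):
                errors.append(f"{path}:{lineno}: not a JSON object")
                continue
            missing = [k for k in required_keys if not record.get(k)]
            if missing:
                errors.append(f"{path}:{lineno}: missing or empty {missing}")
            else:
                valid += 1
    return valid, errors
```

Run from the data directory, `validate_jsonl("train.jsonl")` (and likewise for `valid.jsonl` and `test.jsonl`) returns the count of good records plus a list of problems.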
I downloaded the data files from Hugging Face and placed them in the data directory. Then I ran the script to generate the JSONL files for the train, test, and valid datasets. Below is the structure of the generated data files.
# activate virtual env
❯❯ source .venv/bin/activate
# data directory
# `test.parquet` and `train.parquet` downloaded from the huggingface
# https://huggingface.co/datasets/gretelai/synthetic_text_to_sql/tree/main
❯❯ ls -al data
prepare.py
test.parquet
train.parquet
# generate jsonl files
❯❯ cd data
❯❯ python prepare.py
# generated files
❯❯ ls
test.jsonl
train.jsonl
valid.jsonl
# train.jsonl
{"prompt":"What is the total volume of timber sold by each salesperson, sorted by salesperson? with given SQL schema CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150, '2021-02-01'), (3, 2, 180, '2021-01-01');","completion":"SELECT salesperson_id, name, SUM(volume) as total_volume FROM timber_sales JOIN salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id GROUP BY salesperson_id, name ORDER BY total_volume DESC;"}
{"prompt":"List all the unique equipment types and their corresponding total maintenance frequency from the equipment_maintenance table. with given SQL schema CREATE TABLE equipment_maintenance (equipment_type VARCHAR(255), maintenance_frequency INT);","completion":"SELECT equipment_type, SUM(maintenance_frequency) AS total_maintenance_frequency FROM equipment_maintenance GROUP BY equipment_type;"}
{"prompt":"How many marine species are found in the Southern Ocean? with given SQL schema CREATE TABLE marine_species (name VARCHAR(50), common_name VARCHAR(50), location VARCHAR(50));","completion":"SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';"}
{"prompt":"What is the total trade value and average price for each trader and stock in the trade_history table? with given SQL schema CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);","completion":"SELECT trader_id, stock, SUM(price * quantity) as total_trade_value, AVG(price) as avg_price FROM trade_history GROUP BY trader_id, stock;"}
# test.jsonl
{"prompt":"What is the average explainability score of creative AI applications in 'Europe' and 'North America' in the 'creative_ai' table? with given SQL schema CREATE TABLE creative_ai (application_id INT, name TEXT, region TEXT, explainability_score FLOAT); INSERT INTO creative_ai (application_id, name, region, explainability_score) VALUES (1, 'ApplicationX', 'Europe', 0.87), (2, 'ApplicationY', 'North America', 0.91), (3, 'ApplicationZ', 'Europe', 0.84), (4, 'ApplicationAA', 'North America', 0.93), (5, 'ApplicationAB', 'Europe', 0.89);","completion":"SELECT AVG(explainability_score) FROM creative_ai WHERE region IN ('Europe', 'North America');"}
{"prompt":"Delete all records of rural infrastructure projects in Indonesia that have a completion date before 2010. with given SQL schema CREATE TABLE rural_infrastructure (id INT, project_name TEXT, sector TEXT, country TEXT, completion_date DATE); INSERT INTO rural_infrastructure (id, project_name, sector, country, completion_date) VALUES (1, 'Water Supply Expansion', 'Infrastructure', 'Indonesia', '2008-05-15'), (2, 'Rural Electrification', 'Infrastructure', 'Indonesia', '2012-08-28'), (3, 'Transportation Improvement', 'Infrastructure', 'Indonesia', '2009-12-31');","completion":"DELETE FROM rural_infrastructure WHERE country = 'Indonesia' AND completion_date < '2010-01-01';"}
{"prompt":"How many accidents have been recorded for SpaceX and Blue Origin rocket launches? with given SQL schema CREATE TABLE Accidents (id INT, launch_provider VARCHAR(255), year INT, description TEXT); INSERT INTO Accidents (id, launch_provider, year, description) VALUES (1, 'SpaceX', 2015, 'Falcon 9 explosion'), (2, 'Blue Origin', 2011, 'Propulsion system failure'), (3, 'SpaceX', 2016, 'Falcon 9 explosion');","completion":"SELECT launch_provider, COUNT(*) FROM Accidents GROUP BY launch_provider;"}
{"prompt":"What is the maximum quantity of seafood sold in a single transaction? with given SQL schema CREATE TABLE sales (id INT, location VARCHAR(20), quantity INT, price DECIMAL(5,2)); INSERT INTO sales (id, location, quantity, price) VALUES (1, 'Northeast', 50, 12.99), (2, 'Midwest', 75, 19.99), (3, 'West', 120, 14.49);","completion":"SELECT MAX(quantity) FROM sales;"}
# valid.jsonl
{"prompt":"What is the total number of tickets sold for all football games? with given SQL schema CREATE TABLE tickets (ticket_id INT, game_id INT, region VARCHAR(50), quantity INT); INSERT INTO tickets (ticket_id, game_id, region, quantity) VALUES (1, 1, 'Midwest', 500); INSERT INTO tickets (ticket_id, game_id, region, quantity) VALUES (2, 2, 'Northeast', 700); CREATE TABLE games (game_id INT, sport VARCHAR(50)); INSERT INTO games (game_id, sport) VALUES (1, 'Football'); INSERT INTO games (game_id, sport) VALUES (2, 'Basketball');","completion":"SELECT SUM(quantity) FROM tickets INNER JOIN games ON tickets.game_id = games.game_id WHERE sport = 'Football';"}
{"prompt":"What is the total revenue for the soccer team from ticket sales in London and Paris? with given SQL schema CREATE TABLE tickets (ticket_id INT, game_id INT, quantity INT, price DECIMAL(5,2)); INSERT INTO tickets VALUES (1, 1, 50, 25.99); INSERT INTO tickets VALUES (2, 2, 30, 19.99); CREATE TABLE games (game_id INT, team VARCHAR(20), location VARCHAR(20), price DECIMAL(5,2)); INSERT INTO games VALUES (1, 'Arsenal', 'London', 50.00); INSERT INTO games VALUES (2, 'PSG', 'Paris', 40.00);","completion":"SELECT SUM(tickets.quantity * games.price) FROM tickets INNER JOIN games ON tickets.game_id = games.game_id WHERE games.location IN ('London', 'Paris');"}
{"prompt":"Identify the number of security incidents that occurred in 'Europe' in the last month. with given SQL schema CREATE TABLE incidents (incident_id INT PRIMARY KEY, incident_date DATE, incident_location VARCHAR(50)); INSERT INTO incidents (incident_id, incident_date, incident_location) VALUES (1, '2022-01-01', 'HQ'), (2, '2022-02-15', 'Branch01'), (3, '2022-03-30', 'Asia'), (4, '2022-04-15', 'Europe'), (5, '2022-04-20', 'Europe');","completion":"SELECT COUNT(*) FROM incidents WHERE incident_location = 'Europe' AND incident_date >= DATE_SUB(CURRENT_DATE, INTERVAL 1 MONTH);"}
{"prompt":"Identify the top 5 threat intelligence sources with the highest number of reported incidents in the last year, according to our Incident Tracking database. with given SQL schema CREATE TABLE IncidentTracking (id INT, source VARCHAR(50), incident_count INT, timestamp DATETIME); INSERT INTO IncidentTracking (id, source, incident_count, timestamp) VALUES (1, 'TechFirmA', 200, '2021-01-01 10:00:00'), (2, 'TechFirmB', 150, '2021-01-01 10:00:00');","completion":"SELECT source, SUM(incident_count) as total_incidents FROM IncidentTracking WHERE timestamp >= DATE_SUB(NOW(), INTERVAL 1 YEAR) GROUP BY source ORDER BY total_incidents DESC LIMIT 5;"}
4.6 Fine-Tune/Train LLM
The next step is to fine-tune the Mistral-7B LLM with MLX using the dataset I prepared earlier. First, I downloaded the mistralai/Mistral-7B-Instruct-v0.2 LLM from Hugging Face using huggingface-cli. I then trained it on the prepared dataset with LoRA. As described earlier, LoRA learns small low-rank updates while the original model parameters stay frozen, enabling efficient, targeted adaptation without extensive retraining.
The training process took approximately 36 minutes on an M2 Mac with 64 GB of RAM and a 30-core GPU to train the LLM and generate the necessary adapters.
# download llm
❯❯ huggingface-cli download mistralai/Mistral-7B-Instruct-v0.2
/Users/lambda.eranga/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2/snapshots/1296dc8fd9b21e6424c9c305c06db9ae60c03ace
# model is downloaded into ~/.cache/huggingface/hub/
❯❯ ls ~/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2
blobs refs snapshots
# list all downloaded models from huggingface
❯❯ huggingface-cli scan-cache
REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH
-------------------------------------- --------- ------------ -------- ------------- ------------- ---- -------------------------------------------------------------------------------------------
NousResearch/Meta-Llama-3-8B model 16.1G 14 2 days ago 2 days ago main /Users/lambda.eranga/.cache/huggingface/hub/models--NousResearch--Meta-Llama-3-8B
gpt2 model 2.9M 5 3 months ago 3 months ago main /Users/lambda.eranga/.cache/huggingface/hub/models--gpt2
mistralai/Mistral-7B-Instruct-v0.2 model 29.5G 17 5 hours ago 5 hours ago main /Users/lambda.eranga/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2
sentence-transformers/all-MiniLM-L6-v2 model 91.6M 11 3 months ago 3 months ago main /Users/lambda.eranga/.cache/huggingface/hub/models--sentence-transformers--all-MiniLM-L6-v2
# fine-tune the llm
# --model - original model downloaded from hugging face
# --data data - data directory path with train.jsonl and valid.jsonl
# --train - run training
# --batch-size 4 - batch size
# --lora-layers 16 - number of (final) transformer layers to apply lora to
#                    (newer mlx-lm releases call this flag --num-layers)
# --iters 1000 - training iterations
❯❯ python -m mlx_lm.lora \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--data data \
--train \
--batch-size 4 \
--lora-layers 16 \
--iters 1000
# the training output follows
# when training starts, the initial validation loss is 1.939 and the training loss (at iteration 10) is 1.908
# once training finishes, the validation loss is 0.548 and the training loss is 0.534
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models wont be available and only tokenizers, configuration and file/data utilities can be used.
Loading pretrained model
Fetching 11 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 96.71it/s]
Loading datasets
Training
Trainable parameters: 0.024% (1.704M/7241.732M)
Starting training..., iters: 1000
Iter 1: Val loss 1.939, Val took 47.185s
Iter 10: Train loss 1.908, Learning Rate 1.000e-05, It/sec 0.276, Tokens/sec 212.091, Trained Tokens 7688, Peak mem 21.234 GB
Iter 20: Train loss 1.162, Learning Rate 1.000e-05, It/sec 0.330, Tokens/sec 275.972, Trained Tokens 16040, Peak mem 21.381 GB
Iter 30: Train loss 0.925, Learning Rate 1.000e-05, It/sec 0.360, Tokens/sec 278.166, Trained Tokens 23769, Peak mem 21.381 GB
Iter 40: Train loss 0.756, Learning Rate 1.000e-05, It/sec 0.289, Tokens/sec 258.912, Trained Tokens 32717, Peak mem 24.291 GB
---
Iter 960: Train loss 0.510, Learning Rate 1.000e-05, It/sec 0.360, Tokens/sec 283.649, Trained Tokens 717727, Peak mem 24.332 GB
Iter 970: Train loss 0.598, Learning Rate 1.000e-05, It/sec 0.398, Tokens/sec 276.395, Trained Tokens 724663, Peak mem 24.332 GB
Iter 980: Train loss 0.612, Learning Rate 1.000e-05, It/sec 0.419, Tokens/sec 280.406, Trained Tokens 731359, Peak mem 24.332 GB
Iter 990: Train loss 0.605, Learning Rate 1.000e-05, It/sec 0.371, Tokens/sec 292.855, Trained Tokens 739260, Peak mem 24.332 GB
Iter 1000: Val loss 0.548, Val took 36.479s
Iter 1000: Train loss 0.534, Learning Rate 1.000e-05, It/sec 2.469, Tokens/sec 1886.081, Trained Tokens 746899, Peak mem 24.332 GB
Iter 1000: Saved adapter weights to adapters/adapters.safetensors and adapters/0001000_adapters.safetensors.
Saved final adapter weights to adapters/adapters.safetensors.
# gpu usage while training
❯❯ sudo powermetrics --samplers gpu_power -i500 -n1
Machine model: Mac14,6
OS version: 23F79
Boot arguments:
Boot time: Wed Jun 19 20:50:45 2024
*** Sampled system activity (Thu Jun 20 16:25:57 2024 -0400) (503.31ms elapsed) ***
**** GPU usage ****
GPU HW active frequency: 1398 MHz
GPU HW active residency: 100.00% (444 MHz: 0% 612 MHz: 0% 808 MHz: 0% 968 MHz: 0% 1110 MHz: 0% 1236 MHz: 0% 1338 MHz: 0% 1398 MHz: 100%)
GPU SW requested state: (P1 : 0% P2 : 0% P3 : 0% P4 : 0% P5 : 0% P6 : 0% P7 : 0% P8 : 100%)
GPU SW state: (SW_P1 : 0% SW_P2 : 0% SW_P3 : 0% SW_P4 : 0% SW_P5 : 0% SW_P6 : 0% SW_P7 : 0% SW_P8 : 0%)
GPU idle residency: 0.00%
GPU Power: 45630 mW
# at the end of training, the LoRA adapters (a checkpoint every 100 iterations plus the final adapters.safetensors) are saved in the adapters folder
❯❯ ls adapters
0000100_adapters.safetensors
0000300_adapters.safetensors
0000500_adapters.safetensors
0000700_adapters.safetensors
0000900_adapters.safetensors
adapter_config.json
0000200_adapters.safetensors
0000400_adapters.safetensors
0000600_adapters.safetensors
0000800_adapters.safetensors
0001000_adapters.safetensors
adapters.safetensors
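The log line `Trainable parameters: 0.024% (1.704M/7241.732M)` can be reproduced by hand. The sketch below is a worked calculation. It takes the model dimensions from the GGUF conversion log later in this post: embedding length 4096, 32 attention heads and 8 key-value heads, so the v_proj output is 1024 wide. It assumes LoRA rank 8 applied only to the q_proj and v_proj matrices of the 16 adapted layers. I believe these were mlx-lm's defaults at the time, and the exact match with the logged count supports that, but treat it as an assumption. The helper name `lora_params` is hypothetical.

```python
def lora_params(d_in, d_out, rank):
    """Trainable values LoRA adds to one d_in -> d_out linear layer: A (rank x d_in) + B (d_out x rank)."""
    return rank * d_in + d_out * rank

# Mistral-7B dimensions, taken from the GGUF conversion log
hidden = 4096               # embedding length
head_dim = hidden // 32     # 32 attention heads -> 128 per head
kv_dim = 8 * head_dim       # 8 key-value heads -> 1024, the output size of v_proj

# ASSUMPTION: rank 8, LoRA on q_proj and v_proj only, in the 16 layers set by --lora-layers 16
rank = 8
per_layer = lora_params(hidden, hidden, rank) + lora_params(hidden, kv_dim, rank)
total = per_layer * 16

total_model = 7241.732e6    # total parameter count reported in the training log
print(per_layer, total, f"{100 * total / total_model:.3f}%")  # 106496 1703936 0.024%
```

Roughly 1.7 million trainable values against 7.2 billion frozen ones is what keeps peak memory near 24 GB in the log above.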
4.7. Evaluate Fine-Tuned LLM
The LLM has now been trained, and the LoRA adapters have been created. We can use these adapters with the original LLM to test the fine-tuned LLM. First, I evaluated the LLM on the test data with MLX using the --test argument. Then I posed the same question to both the original LLM and the fine-tuned LLM. This comparison shows how the provided dataset has optimized the fine-tuned LLM for the Text-to-SQL use case. The fine-tuning can be improved further by adjusting prompts, datasets, and training parameters. I plan to discuss preparing and generating datasets for LLM fine-tuning in depth in an upcoming post. Stay tuned for more detailed insights.
# test the llm with the test data
# --model - original model downloaded from hugging face
# --adapter-path adapters - location of the lora adapters
# --data data - data directory path with test.jsonl
❯❯ python -m mlx_lm.lora \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--adapter-path adapters \
--data data \
--test
# testing output
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models wont be available and only tokenizers, configuration and file/data utilities can be used.
Fetching 11 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 127804.28it/s]
==========
Prompt: <s>[INST] List all transactions and customers from the 'Africa' region. [/INST]
SELECT * FROM Transactions WHERE region = 'Africa' UNION SELECT * FROM Customers WHERE region = 'Africa';
==========
Prompt: 63.182 tokens-per-sec
Generation: 21.562 tokens-per-sec
# first, ask the original llm the question using mlx
# --model - original model downloaded from hugging face
# --max-tokens 500 - maximum number of tokens to generate
# --prompt - prompt to the llm
❯❯ python -m mlx_lm.generate \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--max-tokens 500 \
--prompt "List all transactions and customers from the 'Africa' region."
# it provides a generic answer, as below
Prompt: <s>[INST] List all transactions and customers from the 'Africa' region. [/INST]
I'm an AI language model and don't have the ability to directly access or list specific data from a database or system. I can only provide you with an example of how you might write a SQL query to retrieve the transactions and customers from the 'Africa' region based on the assumption that you have a database with a table named 'transactions' and 'customers' and both tables have a 'region' column.
```sql
-- Query to get all transactions from Africa
SELECT *
FROM transactions
WHERE region = 'Africa';
-- Query to get all customers from Africa
SELECT *
FROM customers
WHERE region = 'Africa';
```
These queries will return all transactions and customers that have the 'region' set to 'Africa'. Adjust the table and column names to match your specific database schema.
==========
Prompt: 67.650 tokens-per-sec
Generation: 23.078 tokens-per-sec
# the same question asked of the fine-tuned llm using the adapters
# --model - original model downloaded from hugging face
# --max-tokens 500 - maximum number of tokens to generate
# --adapter-path adapters - location of the lora adapters
# --prompt - prompt to the llm
❯❯ python -m mlx_lm.generate \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--max-tokens 500 \
--adapter-path adapters \
--prompt "List all transactions and customers from the 'Africa' region."
# it provides specific answer with generated sql
Prompt: <s>[INST] List all transactions and customers from the 'Africa' region. [/INST]
SELECT * FROM Transactions WHERE region = 'Africa' UNION SELECT * FROM Customers WHERE region = 'Africa';
==========
Prompt: 64.955 tokens-per-sec
Generation: 21.667 tokens-per-sec
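The side-by-side comparison above is qualitative. A common, more systematic way to score Text-to-SQL output is execution accuracy. You run the generated query and the reference query against the same database and compare their results. Below is a minimal sketch using Python's built-in sqlite3 module; the function name `execution_match` is hypothetical. It assumes the sql_context from each prompt is SQLite-compatible, which does not hold for every sample: some use MySQL functions such as DATE_SUB, and those are reported as `None` rather than scored.

```python
import sqlite3

def execution_match(schema_sql, predicted_sql, gold_sql):
    """Compare predicted and reference SQL by executing both on a fresh in-memory SQLite DB.

    Returns True/False for matching/non-matching result rows (order-insensitive),
    or None if the schema or either query fails to execute.
    """
    def run(query):
        # A new database per query, so a DELETE/UPDATE cannot affect the other run.
        conn = sqlite3.connect(":memory:")
        try:
            conn.executescript(schema_sql)  # CREATE TABLE / INSERT statements from the prompt
            return sorted(conn.execute(query).fetchall(), key=repr)
        finally:
            conn.close()

    try:
        return run(predicted_sql) == run(gold_sql)
    except (sqlite3.Error, sqlite3.Warning):
        # Unsupported syntax, or multiple statements in one query string
        return None
```

For example, with the sales schema from the test.jsonl sample, `SELECT quantity FROM sales ORDER BY quantity DESC LIMIT 1;` counts as a match for `SELECT MAX(quantity) FROM sales;` even though the text differs. That tolerance to different but equivalent phrasings is the main reason to prefer execution matching over string comparison.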
4.8. Build New Model with Fusing Adapters
After completing the fine-tuning process, I can merge the adjustments learned by the new model with the existing model weights, a process known as fusing. Technically, this involves updating the weights and parameters of the pre-trained/base model to incorporate the improvements from the fine-tuned model. Essentially, I can proceed to fuse the LoRA adapters back into the base model.
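Fusing works because the LoRA update is linear. Adding the scaled product of B and A to W gives a single dense matrix that produces the same outputs as "base weight plus separate adapter". That is also why a fused model keeps the base model's shape and needs no adapter at inference time. Below is a small NumPy sketch of this equivalence. The sizes, the scale of 2.0 and the random stand-ins for trained A and B are hypothetical; this is not how `mlx_lm.fuse` is implemented internally.

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, rank, scale = 256, 256, 8, 2.0   # hypothetical sizes and LoRA scale

W = rng.normal(size=(d_out, d_in))   # frozen base weight
A = rng.normal(size=(rank, d_in))    # "trained" LoRA matrices (random stand-ins here)
B = rng.normal(size=(d_out, rank))

# Fusing folds the low-rank update into one dense matrix with the same shape as W.
W_fused = W + scale * (B @ A)

x = rng.normal(size=d_in)
adapter_out = W @ x + scale * (B @ (A @ x))   # base model + separate adapter
fused_out = W_fused @ x                       # fused model, no adapter needed
print(np.allclose(adapter_out, fused_out))    # True
```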
# --model - original model downloaded from hugging face
# --adapter-path adapters - location of the lora adapters
# --save-path models/effectz-sql - new model path
# --de-quantize - use this flag if you want to convert the model to GGUF format later
❯❯ python -m mlx_lm.fuse \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--adapter-path adapters \
--save-path models/effectz-sql \
--de-quantize
# output
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models wont be available and only tokenizers, configuration and file/data utilities can be used.
Loading pretrained model
Fetching 11 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 129599.28it/s]
De-quantizing model
# the new model is generated in the models directory
❯❯ tree models
models
└── effectz-sql
├── config.json
├── model-00001-of-00003.safetensors
├── model-00002-of-00003.safetensors
├── model-00003-of-00003.safetensors
├── model.safetensors.index.json
├── special_tokens_map.json
├── tokenizer.json
└── tokenizer_config.json
# now I can ask questions directly to the new model
# --model models/effectz-sql - new model path
# --max-tokens 500 - maximum number of tokens to generate
# --prompt - prompt to the llm
❯❯ python -m mlx_lm.generate \
--model models/effectz-sql \
--max-tokens 500 \
--prompt "List all transactions and customers from the 'Africa' region."
# output
==========
Prompt: <s>[INST] List all transactions and customers from the 'Africa' region. [/INST]
SELECT * FROM transactions WHERE region = 'Africa'; SELECT * FROM customers WHERE region = 'Africa';
==========
Prompt: 65.532 tokens-per-sec
Generation: 23.193 tokens-per-sec
4.9. Create GGUF Model
I want to run this newly created model with Ollama, a lightweight and flexible framework designed for deploying LLMs locally on personal computers. To run this merged model in Ollama, I need to convert it into a GGUF (Georgi Gerganov Unified Format) file. GGUF is the model file format of llama.cpp, and Ollama runs models stored in this format. To convert the model into GGUF, I have used another tool called llama.cpp, an open-source C/C++ library that performs inference on various LLMs. Below is the method to convert the model into GGUF format and build the Ollama model.
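Once a .gguf file has been written, a cheap sanity check is to read its fixed-size header before loading it anywhere. The sketch below assumes GGUF version 2 or later. That layout is the 4-byte magic `GGUF`, then a little-endian uint32 version, a uint64 tensor count and a uint64 metadata key-value count. This is my reading of the GGUF specification in the ggml repository, so double-check it there. The helper name `read_gguf_header` is hypothetical.

```python
import struct

def read_gguf_header(path):
    """Read the GGUF header: magic, version, tensor count and metadata key-value count.

    Assumes GGUF v2+ layout: b"GGUF", uint32 version, uint64 tensor_count,
    uint64 metadata_kv_count, all little-endian.
    """
    with open(path, "rb") as f:
        header = f.read(24)
    if len(header) < 24 or header[:4] != b"GGUF":
        raise ValueError(f"{path} does not look like a GGUF file")
    version, n_tensors, n_kv = struct.unpack("<IQQ", header[4:24])
    return {"version": version, "tensor_count": n_tensors, "metadata_kv_count": n_kv}
```

Pointing it at `models/effectz-sql.gguf` after the conversion below should report the format version and the number of exported tensors. A `ValueError` usually means a truncated or wrong file.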
# clone llama.cpp into the same location as the mlxo repo
❯❯ git clone https://github.com/ggerganov/llama.cpp.git
# directory structure with llama.cpp and mlxo side by side
❯❯ ls
llama.cpp
mlxo
# install the required llama.cpp packages in a virtual environment
❯❯ cd llama.cpp
❯❯ python -m venv .venv
❯❯ source .venv/bin/activate
❯❯ pip install -r requirements.txt
# llama.cpp contains a script, `convert-hf-to-gguf.py`, to convert a hugging face model to gguf
# (newer llama.cpp versions name it `convert_hf_to_gguf.py`)
❯❯ ls convert-hf-to-gguf.py
convert-hf-to-gguf.py
# convert the newly generated model (in mlxo/models/effectz-sql) to gguf
# --outfile ../mlxo/models/effectz-sql.gguf - output gguf model file path
# --outtype q8_0 - 8-bit quantization, which shrinks the model and can speed up inference
❯❯ python convert-hf-to-gguf.py ../mlxo/models/effectz-sql \
--outfile ../mlxo/models/effectz-sql.gguf \
--outtype q8_0
# output
INFO:hf-to-gguf:Loading model: effectz-sql
INFO:gguf.gguf_writer:gguf: This GGUF file is for Little Endian only
INFO:hf-to-gguf:Set model parameters
INFO:hf-to-gguf:gguf: context length = 32768
INFO:hf-to-gguf:gguf: embedding length = 4096
INFO:hf-to-gguf:gguf: feed forward length = 14336
INFO:hf-to-gguf:gguf: head count = 32
INFO:hf-to-gguf:gguf: key-value head count = 8
INFO:hf-to-gguf:gguf: rope theta = 1000000.0
INFO:hf-to-gguf:gguf: rms norm epsilon = 1e-05
INFO:hf-to-gguf:gguf: file type = 7
INFO:hf-to-gguf:Set model tokenizer
INFO:gguf.vocab:Setting special token type bos to 1
INFO:gguf.vocab:Setting special token type eos to 2
INFO:gguf.vocab:Setting special token type unk to 0
INFO:gguf.vocab:Setting add_bos_token to True
INFO:gguf.vocab:Setting add_eos_token to False
INFO:gguf.vocab:Setting chat_template to {{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}
INFO:hf-to-gguf:Exporting model to '../mlxo/models/effectz-sql.gguf'
INFO:hf-to-gguf:gguf: loading model weight map from 'model.safetensors.index.json'
INFO:hf-to-gguf:gguf: loading model part 'model-00001-of-00003.safetensors'
INFO:hf-to-gguf:token_embd.weight, torch.bfloat16 --> Q8_0, shape = {4096, 32000}
---
INFO:hf-to-gguf:blk.31.attn_v.weight, torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:output_norm.weight, torch.bfloat16 --> F32, shape = {4096}
Writing: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7.70G/7.70G [01:17<00:00, 99.6Mbyte/s]
INFO:hf-to-gguf:Model successfully exported to '../mlxo/models/effectz-sql.gguf'
# the new GGUF model is written to mlxo/models
❯❯ cd ../mlxo
❯❯ ls models/effectz-sql.gguf
models/effectz-sql.gguf
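The `--outtype q8_0` flag and the 7.70 GB file size can be checked with a little arithmetic. In ggml's Q8_0 format, weights are split into blocks of 32; each block stores one fp16 scale `d = max|x| / 127` plus 32 int8 values `round(x / d)`, i.e. 34 bytes per 32 weights (8.5 bits per weight). The sketch below implements that block quantizer and estimates the file size from the hyperparameters printed in the log (vocab 32000 from the `token_embd` shape, 32 layers from `blk.31`, 8 KV heads giving the `{4096, 1024}` `attn_v` shape). It assumes the standard Mistral layout (SwiGLU MLP with gate/up/down projections, separate `token_embd` and `output` matrices, F32 norm vectors) and ignores metadata, tokenizer data and alignment padding, so treat the estimate as approximate. All function names are hypothetical.

```python
import math
import struct
from typing import List, Tuple

QK8_0 = 32                   # values per Q8_0 block
BYTES_PER_BLOCK = 2 + QK8_0  # one fp16 scale + 32 int8 quants = 34 bytes


def _to_fp16(x: float) -> float:
    """Round a float to IEEE half precision, as the stored block scale is."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def _round_half_away(v: float) -> int:
    """C-style roundf: ties go away from zero (Python's round() is banker's rounding)."""
    return int(math.copysign(math.floor(abs(v) + 0.5), v))


def quantize_q8_0_block(xs: List[float]) -> Tuple[float, List[int]]:
    """Quantize 32 floats to (fp16 scale, int8 values) following the Q8_0 scheme."""
    if len(xs) != QK8_0:
        raise ValueError(f"Q8_0 blocks hold exactly {QK8_0} values")
    amax = max(abs(x) for x in xs)
    d = amax / 127.0
    inv_d = 1.0 / d if d else 0.0
    qs = [_round_half_away(x * inv_d) for x in xs]
    return _to_fp16(d), qs


def dequantize_q8_0_block(d: float, qs: List[int]) -> List[float]:
    """Reconstruct approximate floats from a Q8_0 block."""
    return [d * q for q in qs]


def mistral_param_counts(vocab: int = 32000, n_layer: int = 32, d_model: int = 4096,
                         d_ff: int = 14336, n_head: int = 32,
                         n_kv_head: int = 8) -> Tuple[int, int]:
    """Return (values in 2-D weight matrices, values in 1-D norm vectors)."""
    kv_dim = n_kv_head * (d_model // n_head)              # 8 * 128 = 1024
    attn = 2 * d_model * d_model + 2 * d_model * kv_dim  # q, o and k, v projections
    mlp = 3 * d_model * d_ff                             # gate, up, down (SwiGLU)
    matrices = n_layer * (attn + mlp) + 2 * vocab * d_model  # + token_embd and output
    norms = n_layer * 2 * d_model + d_model              # attn/ffn norms + output_norm
    return matrices, norms


def estimate_q8_0_bytes(matrices: int, norms: int) -> int:
    """Matrices become Q8_0 (34 bytes per 32 values); norm vectors stay F32 (4 bytes)."""
    return matrices // QK8_0 * BYTES_PER_BLOCK + norms * 4
```

With the logged hyperparameters this comes to about 7.24 billion parameters and roughly 7.70 GB, in line with the `7.70G` progress bar above; the same weights in bf16 would take about 14.5 GB at 2 bytes per parameter.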
4.10. Build and Run Ollama Model
Now I can create an Ollama Modelfile and build an Ollama model from the GGUF file effectz-sql.gguf. The Ollama Modelfile is a configuration file that defines and manages models on the Ollama platform. Below is how to create the Modelfile and build the new Ollama model.
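The Modelfile used here needs only a `FROM` line. If you prefer to generate it from a script, for example to append optional `PARAMETER` lines, a minimal sketch might look like the one below. `render_modelfile` is a hypothetical helper, and the `temperature` and `stop` values in the usage note are illustrative assumptions, not settings from the original setup.

```python
import json
from typing import Iterable, Optional, Tuple, Union

ParamValue = Union[str, int, float]


def render_modelfile(gguf_path: str,
                     parameters: Optional[Iterable[Tuple[str, ParamValue]]] = None) -> str:
    """Render a minimal Ollama Modelfile: a FROM line plus optional PARAMETER lines."""
    lines = [f"FROM {gguf_path}"]
    for name, value in parameters or ():
        # Quote string values (e.g. stop sequences); numbers are written as-is.
        rendered = json.dumps(value) if isinstance(value, str) else str(value)
        lines.append(f"PARAMETER {name} {rendered}")
    return "\n".join(lines) + "\n"
```

`render_modelfile("./effectz-sql.gguf")` reproduces the one-line Modelfile used here. Passing something like `[("temperature", 0), ("stop", "<|im_end|>")]` would add deterministic sampling and a stop sequence, which might keep the stray `<|im_end|>` marker out of replies; whether that suits your model is an assumption to verify.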
# create a file named `Modelfile` in the models directory with the following content
❯❯ cat models/Modelfile
FROM ./effectz-sql.gguf
# create ollama model
❯❯ ollama create effectz-sql -f models/Modelfile
transferring model data
using existing layer sha256:24ba3b41f3b846bd56142b713b12df8da3e7ab1c5ee9ae3f5afc76d87c69d796
using autodetected template mistral-instruct
creating new layer sha256:fd230ae885e75bf5240291df2bfb0090af58f4bdf3689deafcef415582915a33
writing manifest
success
# list ollama models
# effectz-sql:latest is the newly created model
❯❯ ollama ls
NAME                  ID              SIZE      MODIFIED
effectz-sql:latest    736275f4faa4    7.7 GB    17 seconds ago
mistral:latest        2ae6f6dd7a3d    4.1 GB    3 days ago
llama3:latest         a6990ed6be41    4.7 GB    7 weeks ago
llama3:8b             a6990ed6be41    4.7 GB    7 weeks ago
llama2:latest         78e26419b446    3.8 GB    2 months ago
llama2:13b            d475bf4c50bc    7.4 GB    2 months ago
# run the model with Ollama and ask a question
# the model translates the natural-language prompt into SQL
❯❯ ollama run effectz-sql
>>> List all transactions and customers from the 'Africa' region.
SELECT * FROM transactions WHERE customer_region = 'Africa';<|im_end|>
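Besides the interactive `ollama run` session, the same model can be queried from code through Ollama's local REST API. The sketch below posts to the `/api/generate` endpoint at Ollama's default address `http://localhost:11434` with `stream` set to false, so the reply arrives as a single JSON object whose `response` field holds the generated text. It uses only the Python standard library and assumes `ollama serve` (or the Ollama app) is running locally; `build_generate_request` and `generate_sql` are hypothetical helper names.

```python
import json
import urllib.request

OLLAMA_URL = "http://localhost:11434/api/generate"  # Ollama's default local endpoint


def build_generate_request(model: str, prompt: str) -> bytes:
    """JSON body for /api/generate; stream=False asks for one response object."""
    return json.dumps({"model": model, "prompt": prompt, "stream": False}).encode("utf-8")


def generate_sql(prompt: str, model: str = "effectz-sql",
                 url: str = OLLAMA_URL, timeout: float = 120.0) -> str:
    """Send the prompt to a locally running Ollama server and return the generated text."""
    req = urllib.request.Request(
        url,
        data=build_generate_request(model, prompt),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        body = json.load(resp)
    return body["response"].strip()
```

For example, `generate_sql("List all transactions and customers from the 'Africa' region.")` sends the same question as the interactive session above.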
References
- https://www.datacamp.com/tutorial/fine-tuning-large-language-models
- https://www.lakera.ai/blog/llm-fine-tuning-guide
- https://medium.com/rahasak/fine-tune-llms-on-your-pc-with-qlora-apple-mlx-c2aedf1f607d
- https://medium.com/@elijahwongww/how-to-finetune-llama-3-model-on-macbook-4cb184e6d52e
- https://heidloff.net/article/fine-tuning-llm-locally-apple-silicon-m3/
- https://heidloff.net/article/apple-mlx-fine-tuning/
- https://medium.com/@anchen.li/fine-tune-llama3-with-function-calling-via-mlx-lm-5ebbee41558f
- https://iwasnothing.medium.com/llm-fine-tuning-with-macbook-pro-982fbea50b3d
- https://medium.com/@mustangs007/mlx-building-fine-tuning-llm-model-on-apple-m3-using-custom-dataset-9209813fd38e
- https://mer.vin/2024/02/mlx-mistral-lora-fine-tuning/
- https://apeatling.com/articles/simple-guide-to-local-llm-fine-tuning-on-a-mac-with-mlx/
- https://dassum.medium.com/fine-tune-large-language-model-llm-on-a-custom-dataset-with-qlora-fb60abdeba07