Nat Taylor: Blog, AI, Product Management & Tinkering

Test Drive: MLX Finetuning

Published on .

Today I’m test-driving fine-tuning with MLX. I’ve got the Halloween spirit, so my task is to teach a model to “spookify” things with creepy words. The process is:

  1. Generate some spooky JSON-lines data
  2. Install mlx_lm
  3. Train
  4. Infer

My code for generating the data is below. With the data in place, I fine-tuned with:

```shell
mlx_lm.lora --model 'mlx-community/Llama-3.2-1B-Instruct-4bit' --train --data '/Users/ntaylor/notebooks/ft1' --iters 600
```
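Before a long training run, it can help to confirm that the `--data` directory holds a `train.jsonl` and a `valid.jsonl` in which every line parses as JSON with a string `"text"` field, the record format the script below writes. The helper here is a hypothetical pre-flight check, not part of `mlx_lm`; it assumes only the `"text"` format used in this post (`mlx_lm` also accepts other record formats that this sketch does not check).

```python
import json
from pathlib import Path


def check_lora_data_dir(data_dir: str) -> dict:
    """Pre-flight check for a --data directory of {"text": ...} JSONL files.

    Hypothetical helper, not part of mlx_lm. Returns the number of records
    per file and raises ValueError on the first missing file or bad line.
    """
    counts = {}
    for name in ("train.jsonl", "valid.jsonl"):
        path = Path(data_dir) / name
        if not path.is_file():
            raise ValueError(f"missing {path}")
        n = 0
        with path.open(encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                try:
                    record = json.loads(line)  # a blank line also fails here
                except json.JSONDecodeError as e:
                    raise ValueError(f"{name}:{lineno} is not valid JSON") from e
                if not isinstance(record, dict) or not isinstance(record.get("text"), str):
                    raise ValueError(f"{name}:{lineno} has no string 'text' field")
                n += 1
        counts[name] = n
    return counts
```

Calling `check_lora_data_dir("ft3")` after running the data-generation script should report the train and validation record counts, or point at the first line that would trip up training.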

After fine-tuning, I can load the adapter and prompt the model like this:

```shell
mlx_lm.generate --adapter-path adapters/ --model 'mlx-community/Llama-3.2-1B-Instruct-4bit' --prompt 'Spookify: cereal A:'
```

It generates output like this:

Cereal is a creepy, shadowy, and mysterious substance that’s often associated with the dark, eerie, and foreboding atmosphere of a haunted mansion, with a eerie, BOOOOOO!
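The training records built by the script below put a newline before `A:` (`Spookify: {job}\nA: {completion}`), while the prompt above writes `Spookify: cereal A:` on one line. One way to keep the two layouts aligned is to derive both from a single template. The helpers and the `PROMPT_TEMPLATE` name below are hypothetical, a minimal sketch rather than anything `mlx_lm` provides; how much the exact whitespace matters for a small fine-tuned model is something to test rather than assume.

```python
# Hypothetical helpers: derive both the training text and the inference
# prompt from one template so their layouts cannot drift apart.
PROMPT_TEMPLATE = "Spookify: {subject}\nA:"


def training_record(subject: str, answer: str) -> dict:
    """Build a {"text": ...} record in the same layout as the data script."""
    return {"text": PROMPT_TEMPLATE.format(subject=subject) + " " + answer}


def inference_prompt(subject: str) -> str:
    """Build a prompt that ends at 'A:', so the model's continuation is the answer."""
    return PROMPT_TEMPLATE.format(subject=subject)


print(repr(inference_prompt("cereal")))  # 'Spookify: cereal\nA:'
```

From bash or zsh, ANSI-C quoting such as `--prompt $'Spookify: cereal\nA:'` passes a literal newline to the CLI.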

I used the smallest Llama 3.2 model (1B parameters, 4-bit quantized), and everything is astonishingly fast. Training on 900 examples runs at about 500 tokens/sec and uses only about 8GB of memory, and inference is just as fast. Of course, the responses are only as good as a 1B model can manage.
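For a rough sense of what `--iters 600` means for 900 examples: the number of passes over the data depends on the batch size. This back-of-the-envelope sketch takes the 600 iterations, 900 examples and ~500 tokens/sec from this post, but assumes a batch size of 4 (my understanding of the `mlx_lm.lora` default; check `mlx_lm.lora --help` for your version) and a hypothetical average of 100 tokens per example. It ignores padding and evaluation passes, so treat the result as an order-of-magnitude estimate.

```python
def lora_run_estimate(iters: int, batch_size: int, n_train: int,
                      avg_tokens: float, tokens_per_sec: float) -> dict:
    """Back-of-the-envelope numbers for a LoRA run.

    Assumes every iteration processes `batch_size` examples of `avg_tokens`
    tokens each; real throughput varies with sequence length and padding.
    """
    examples_seen = iters * batch_size
    total_tokens = examples_seen * avg_tokens
    return {
        "examples_seen": examples_seen,
        "epochs": examples_seen / n_train,
        "total_tokens": total_tokens,
        "minutes": total_tokens / tokens_per_sec / 60,
    }


# iters, n_train and tokens_per_sec are from the post; batch_size=4 and
# avg_tokens=100 are assumptions, not measurements.
est = lora_run_estimate(iters=600, batch_size=4, n_train=900,
                        avg_tokens=100, tokens_per_sec=500)
print(est)  # 2400 examples seen, about 2.67 epochs, about 8.0 minutes
```

Under these assumptions, 600 iterations is a little under three epochs, which is worth keeping in mind when deciding whether a 1B model is memorizing the spooky phrasing.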

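```python
import json
import os
from textwrap import dedent

from faker import Faker
from mlx_lm import generate, load

# Assumption: the post doesn't say which model generated the training data;
# the same 1B model used for fine-tuning is shown here as a stand-in.
model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")

fake = Faker()
completions = []
for _ in range(1000):
    job = fake.job()  # a random job title from Faker
    messages = [
        {"role": "user", "content": dedent(f"""\
            Spookily explain "{job}" in 1 sentence
            Compare it to ghouls, goblins, witches, spells, spiders, potions, skeletons, zombies or jackolanterns.
            Include "BOOOOOO" once in the middle!!!
            Use eerie adjectives like creepy, spooky or shadowy.""")},
    ]
    # Render the chat template to a prompt string, then generate a completion
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    completion = generate(model, tokenizer, prompt=prompt, verbose=False, max_tokens=256)
    # Store each example as a {"text": ...} record for mlx_lm.lora
    completions.append({"text": f"Spookify: {job}\nA: {completion}"})

# Pass this directory to mlx_lm.lora via --data
os.makedirs("ft3", exist_ok=True)
# First 900 examples for training, the remaining 100 for validation (no overlap)
with open("ft3/train.jsonl", "w") as f:
    for c in completions[:900]:
        f.write(json.dumps(c) + "\n")
with open("ft3/valid.jsonl", "w") as f:
    for c in completions[900:]:
        f.write(json.dumps(c) + "\n")
```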
