快轉到主要內容
  1. 教學文章/

MLX 入門教學:在 Apple Silicon 上跑機器學習

·4 分鐘· loading · loading · ·
Python Mlx Apple-Silicon Machine-Learning Deep-Learning
每日拍拍
作者
每日拍拍
科學家 X 科技宅宅
目錄
Python 學習 - 本文屬於一個選集。
§ 21: 本文

一、前言
#

嗨,大家好!我是拍拍君 🍎

如果你用的是 MacBook Pro、Mac mini 或任何搭載 Apple Silicon(M1/M2/M3/M4)的 Mac,那你手上其實有一顆超強的 ML 加速器——統一記憶體架構(Unified Memory) 讓 CPU 和 GPU 共用同一塊記憶體,不需要像 NVIDIA GPU 那樣來回搬資料。

Apple 在 2023 年底開源了 MLX,一個專為 Apple Silicon 設計的機器學習框架。它的 API 長得很像 NumPy 和 PyTorch,學過任一個的話上手超快!

今天拍拍君帶大家從零開始認識 MLX,包含:

  • 安裝與環境設定
  • 核心陣列操作
  • 自動微分
  • 建立神經網路
  • 用 MLX 跑大型語言模型

準備好了嗎?讓我們開始吧!🚀

二、安裝
#

MLX 需要 macOS 13.5 以上 + Apple Silicon。用 uvpip 安裝都可以:

# 推薦用 uv(超快!不知道 uv 的話看拍拍君的 uv 教學 ✨)
uv pip install mlx

# 或用 pip
pip install mlx

如果你之後想跑大型語言模型,還需要安裝 mlx-lm

uv pip install mlx-lm

驗證安裝:

import mlx.core as mx

print(mx.__version__)
# 例如 0.22.0

# 確認 GPU 可用
a = mx.array([1, 2, 3])
print(a)
# array([1, 2, 3], dtype=int32)

💡 MLX 預設就會使用 GPU 加速,不需要額外設定 device!

三、核心概念:陣列與惰性求值
#

3.1 建立陣列
#

MLX 的 mx.array 就像 NumPy 的 np.array,操作方式幾乎一模一樣:

import mlx.core as mx

# 建立陣列
a = mx.array([1.0, 2.0, 3.0])
b = mx.array([[1, 2], [3, 4]])

# 常用建立方式
zeros = mx.zeros((3, 4))
ones = mx.ones((2, 3))
randn = mx.random.normal((3, 3))

print(f"shape: {randn.shape}, dtype: {randn.dtype}")
# shape: [3, 3], dtype: float32

3.2 惰性求值(Lazy Evaluation)
#

這是 MLX 跟 NumPy 最大的不同!MLX 用的是 惰性求值 ——運算不會馬上執行,而是等到你真正需要結果的時候才計算:

a = mx.ones((1000, 1000))
b = mx.ones((1000, 1000))

# 這裡只是建立計算圖,還沒真正算
c = a + b
d = c * 2

# 呼叫 eval() 才會真正執行
mx.eval(d)

# 或是 print 的時候也會觸發求值
print(d)

為什麼要這樣設計?因為 MLX 可以把多個操作 融合(fuse) 在一起,一次送到 GPU 執行,減少記憶體搬移次數,大幅提升效能!

3.3 常用操作
#

import mlx.core as mx

a = mx.array([[1, 2, 3], [4, 5, 6]])

# 形狀操作
print(a.reshape(3, 2))
print(a.T)  # 轉置

# 數學運算
print(mx.sum(a))           # 21
print(mx.mean(a.astype(mx.float32)))  # 3.5
print(mx.max(a, axis=1))   # [3, 6]

# 矩陣乘法
b = mx.ones((3, 2))
result = a.astype(mx.float32) @ b
print(result.shape)  # [2, 2]

# 索引與切片(跟 NumPy 一樣)
print(a[:, 1:])  # [[2, 3], [5, 6]]

四、自動微分(Autograd)
#

訓練模型最核心的功能就是自動微分。MLX 用的是 函數式自動微分,跟 JAX 的風格很像:

import mlx.core as mx

# 定義一個函數
def f(x):
    return mx.sum(x ** 2)

# 取得梯度函數
grad_f = mx.grad(f)

x = mx.array([1.0, 2.0, 3.0])
print(f"f(x) = {f(x)}")        # 14.0
print(f"grad = {grad_f(x)}")   # [2.0, 4.0, 6.0]

也可以對多個參數求梯度:

def loss_fn(w, b, x, y):
    pred = x @ w + b
    return mx.mean((pred - y) ** 2)

# 對第 0 和第 1 個參數(w 和 b)求梯度
grad_fn = mx.grad(loss_fn, argnums=[0, 1])

w = mx.random.normal((3, 1))
b = mx.zeros((1,))
x = mx.random.normal((10, 3))
y = mx.random.normal((10, 1))

dw, db = grad_fn(w, b, x, y)
print(f"dw shape: {dw.shape}, db shape: {db.shape}")

💡 注意:MLX 的 mx.grad純函數式 的,不像 PyTorch 需要 .backward().grad。這讓程式更容易推理和除錯。

五、建立神經網路
#

MLX 提供了 mlx.nn 模組,API 跟 PyTorch 很像:

5.1 定義模型
#

import mlx.core as mx
import mlx.nn as nn

class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def __call__(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = MLP(input_dim=4, hidden_dim=64, output_dim=3)
print(model)

5.2 訓練迴圈
#

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# 模型與優化器
model = MLP(input_dim=4, hidden_dim=64, output_dim=3)
optimizer = optim.Adam(learning_rate=1e-3)

# 損失函數
def loss_fn(model, x, y):
    logits = model(x)
    return mx.mean(nn.losses.cross_entropy(logits, y))

# 用 value_and_grad 同時算 loss 和梯度
loss_and_grad = nn.value_and_grad(model, loss_fn)

# 模擬一些資料(拍拍君用隨機資料示範)
x_train = mx.random.normal((100, 4))
y_train = mx.random.randint(0, 3, (100,))

# 訓練迴圈
for epoch in range(50):
    loss, grads = loss_and_grad(model, x_train, y_train)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

輸出大概會像這樣:

Epoch 10, Loss: 0.9823
Epoch 20, Loss: 0.8451
Epoch 30, Loss: 0.7102
Epoch 40, Loss: 0.5834
Epoch 50, Loss: 0.4721

5.3 儲存與載入模型
#

# 儲存
model.save_weights("model.safetensors")

# 載入
model = MLP(input_dim=4, hidden_dim=64, output_dim=3)
model.load_weights("model.safetensors")

MLX 使用 safetensors 格式,安全又高效!

六、用 MLX 跑大型語言模型
#

MLX 社群最受歡迎的應用之一就是 在 Mac 上跑 LLMmlx-lm 套件讓這件事變得超簡單:

6.1 下載並執行模型
#

# 用 CLI 直接聊天
mlx_lm.generate \
    --model mlx-community/Llama-3.2-3B-Instruct-4bit \
    --prompt "用一句話解釋什麼是機器學習"

6.2 用 Python API
#

from mlx_lm import load, generate

# 載入量化模型(4bit 很省記憶體)
model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

# 生成文字
prompt = "用一句話解釋什麼是 Python"
messages = [{"role": "user", "content": prompt}]
formatted = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

response = generate(
    model,
    tokenizer,
    prompt=formatted,
    max_tokens=200,
    temp=0.7,
)
print(response)

6.3 量化模型
#

如果你想把 Hugging Face 上的模型轉成 MLX 格式:

# 量化成 4bit
mlx_lm.convert \
    --hf-path meta-llama/Llama-3.2-3B-Instruct \
    --mlx-path ./llama-3.2-3b-4bit \
    --quantize \
    --q-bits 4

量化後的模型不但更小,在 Apple Silicon 上跑起來也更快!

6.4 啟動本地 API Server
#

mlx_lm.server \
    --model mlx-community/Llama-3.2-3B-Instruct-4bit \
    --port 8080

這會啟動一個 OpenAI 相容的 API,你可以用任何支援 OpenAI API 的工具來連接。

七、MLX vs PyTorch vs JAX
#

特性 MLX PyTorch JAX
設計目標 Apple Silicon 最佳化 通用 GPU Google TPU/GPU
求值模式 惰性 (Lazy) 即時 (Eager) 惰性 + JIT
自動微分 函數式 OOP (.backward) 函數式
統一記憶體 ✅ 原生支援 ❌ 需搬移 ❌ 需搬移
生態系 成長中 最豐富 學術為主
學習曲線 低(像 NumPy)

💡 拍拍君建議:如果你主要在 Mac 上開發,MLX 是最佳選擇。需要跨平台或使用大量第三方模型,PyTorch 仍然是首選。不知道 PyTorch 的話,可以看看拍拍君的 PyTorch 入門教學

結語
#

恭喜你看完 MLX 入門教學!🎉

讓我們快速回顧一下今天學到的重點:

  • MLX 是專為 Apple Silicon 打造的 ML 框架
  • 惰性求值 讓 MLX 能自動最佳化計算
  • 統一記憶體 省去 CPU ↔ GPU 資料搬移的開銷
  • API 類似 NumPy + PyTorch,上手門檻低
  • mlx-lm 可以輕鬆在 Mac 上跑大型語言模型

Apple Silicon 的 ML 生態系正在快速成長,越來越多模型會針對 MLX 最佳化。如果你有 Mac,現在正是開始玩 MLX 的好時機!

有什麼問題歡迎留言,拍拍君下次見!👋

延伸閱讀
#

Python 學習 - 本文屬於一個選集。
§ 21: 本文

相關文章

PyTorch 神經網路入門:從零開始建立你的第一個模型
·5 分鐘· loading · loading
Python Pytorch Neural-Network Deep-Learning Machine-Learning
Docker for Python:讓你的程式在任何地方都能跑
·6 分鐘· loading · loading
Python Docker Container Devops 部署
Streamlit:用 Python 快速打造互動式資料應用
·8 分鐘· loading · loading
Python Streamlit Data-Visualization Web-App Dashboard
Python Logging:別再 print 了,用正經的方式記錄日誌吧
·6 分鐘· loading · loading
Python Logging Debug 標準庫
Pre-commit Hooks:讓壞 Code 連 Commit 的機會都沒有
·4 分鐘· loading · loading
Python Pre-Commit Git Linter Code-Quality Ruff Mypy
Polars:比 Pandas 快 10 倍的 DataFrame 新選擇
·6 分鐘· loading · loading
Python Polars Dataframe 資料分析 Rust