Fine-tune Gemma models in Keras using LoRA


Overview

Gemma is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models.

Large language models (LLMs) such as Gemma have been shown to be effective at a variety of NLP tasks. An LLM is first pre-trained on a large corpus of text in a self-supervised fashion. Pre-training helps the LLM learn general-purpose knowledge, such as statistical relationships between words. The LLM can then be fine-tuned with domain-specific data to perform downstream tasks (such as sentiment analysis).

LLMs are extremely large in size (parameters in the order of billions). Full fine-tuning (which updates all parameters in the model) is not required for most applications, because typical fine-tuning datasets are much smaller than the pre-training datasets.

Low Rank Adaptation (LoRA) is a fine-tuning technique that greatly reduces the number of trainable parameters for downstream tasks by freezing the weights of the model and inserting a smaller number of new weights into the model. This makes training with LoRA much faster and more memory-efficient, and produces smaller model weights (a few hundred MBs), all while maintaining the quality of the model outputs.
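In rough terms, LoRA freezes a pretrained weight matrix W and learns a low-rank update B·A, so the adapted layer computes y = W·x + B·(A·x) and only A and B are trained. Below is a minimal numpy sketch of the idea (illustrative only, not the Keras implementation):

import numpy as np

# A minimal sketch of the LoRA idea: train a low-rank update instead of W itself.
d_out, d_in, r = 512, 512, 4           # r is the LoRA rank
W = np.random.randn(d_out, d_in)       # frozen pretrained weight
A = np.random.randn(r, d_in) * 0.01    # trainable low-rank factor
B = np.zeros((d_out, r))               # trainable; zero-init so training starts from W

x = np.random.randn(d_in)
y = W @ x + B @ (A @ x)                # adapted forward pass

print(W.size, A.size + B.size)         # 262144 frozen vs. 4096 trainable parameters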

This tutorial walks you through using KerasNLP to perform LoRA fine-tuning on a Gemma 2B model with the Databricks Dolly 15k dataset. This dataset contains 15,000 high-quality human-generated prompt / response pairs specifically designed for fine-tuning LLMs.
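Each line of the dataset file is a JSON record with instruction, context, response, and category fields (the preprocessing code later in this tutorial relies on the first three). The record below is illustrative, not copied from the dataset:

{"instruction": "What is the capital of France?", "context": "", "response": "The capital of France is Paris.", "category": "open_qa"}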

Setup

Get access to Gemma

To complete this tutorial, you will first need to complete the setup instructions at Gemma setup. The Gemma setup instructions show you how to do the following:

  • Get access to Gemma on kaggle.com.
  • Select a Colab runtime with sufficient resources to run the Gemma 2B model.
  • Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.

Select the runtime

To complete this tutorial, you'll need a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:

  1. In the upper-right of the Colab window, select ▾ (Additional connection options).
  2. Select Change runtime type.
  3. Under Hardware accelerator, select T4 GPU.

Configure your API key

To use Gemma, you must provide your Kaggle username and a Kaggle API key.

To generate a Kaggle API key, go to the Account tab of your Kaggle user profile and select Create New Token. This triggers the download of a kaggle.json file containing your API credentials.
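The downloaded kaggle.json file has a simple structure; the values here are placeholders, not real credentials:

{"username": "your_kaggle_username", "key": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"}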

In Colab, select Secrets (🔑) in the left pane, and add your Kaggle username and Kaggle API key. Store your username under the name KAGGLE_USERNAME and your API key under the name KAGGLE_KEY.

Set environment variables

Set environment variables for KAGGLE_USERNAME and KAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Install dependencies

Install Keras, KerasNLP, and other dependencies.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. With Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.

For this tutorial, configure the backend to use JAX.

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

Import packages

Import Keras and KerasNLP.

import keras
import keras_nlp

Load the dataset

wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
--2024-07-31 01:56:39--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
Resolving huggingface.co (huggingface.co)... 18.164.174.23, 18.164.174.17, 18.164.174.55, ...
Connecting to huggingface.co (huggingface.co)|18.164.174.23|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/... [following]
HTTP request sent, awaiting response... 200 OK
Length: 13085339 (12M) [text/plain]
Saving to: ‘databricks-dolly-15k.jsonl’

databricks-dolly-15 100%[===================>]  12.48M  73.7MB/s    in 0.2s    

2024-07-31 01:56:40 (73.7 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]

Preprocess the data. This tutorial uses a subset of 1,000 training examples so the notebook runs faster. Consider using more training data for higher-quality fine-tuning.

import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]

Load the model

KerasNLP provides implementations of many popular model architectures. In this tutorial, you'll create a model using GemmaCausalLM, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on the previous tokens.

Create the model using the from_preset method:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()

The from_preset method instantiates the model from a preset architecture and weights. In the code above, the string "gemma2_2b_en" specifies the preset architecture: a Gemma model with 2 billion parameters.

Inference before fine-tuning

In this section, you will query the model with various prompts to see how it responds.

Europe trip prompt

Query the model for suggestions on what to do on a trip to Europe.

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
If you have any special needs, you should contact the embassy of the country that you are visiting.
You should contact the embassy of the country that I will be visiting.

What are my responsibilities when I go on a trip?

Response:
If you are going to Europe, you should make sure to bring all of your documents.
If you are going to Europe, make sure that you have all of your documents.

When do you travel abroad?

Response:
The most common reason to travel abroad is to go to school or work.
The most common reason to travel abroad is to work.

How can I get a visa to Europe?

Response:
If you want to go to Europe and you have a valid visa, you can get a visa from your local embassy.
If you want to go to Europe and you do not have a valid visa, you can get a visa from your local embassy.

When should I go to Europe?

Response:
You should go to Europe when the weather is nice.
You should go to Europe when the weather is bad.

How can I make a reservation for a trip?

The model responds with generic tips on how to plan a trip.

ELI5 photosynthesis prompt

Prompt the model to explain photosynthesis in terms simple enough for a 5-year-old child to understand.

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Plants need water, air, sunlight, and carbon dioxide. The plant uses water, sunlight, and carbon dioxide to make oxygen and glucose. The process is also known as photosynthesis.

Instruction:
What is the process of photosynthesis in a plant's cells? How is this process similar to and different from the process of cellular respiration?

Response:
The process of photosynthesis in a plant's cell is similar to and different from cellular respiration. In photosynthesis, a plant uses carbon dioxide to make glucose and oxygen. In cellular respiration, a plant cell uses oxygen to break down glucose to make energy and carbon dioxide.

Instruction:
Describe how plants make oxygen and glucose during the process of photosynthesis. Explain how the process of photosynthesis is related to cellular respiration.

Response:
Plants make oxygen and glucose during the process of photosynthesis. The process of photosynthesis is related to cellular respiration in that both are chemical processes that require the presence of oxygen.

Instruction:
How does photosynthesis occur in the cells of a plant? What is the purpose for each part of the cell?

Response:
Photosynthesis occurs in the cells of a plant. The purpose of

The model response contains words that might not be easy for a child to understand, such as chlorophyll.

LoRA fine-tuning

To get better responses from the model, fine-tune it with Low Rank Adaptation (LoRA) using the Databricks Dolly 15k dataset.

The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.

A higher rank allows more detailed changes but also adds more trainable parameters. A lower rank means less computational overhead but potentially less precise adaptation.

This tutorial uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, or 16). This is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see whether that further boosts performance.
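To make the rank / parameter-count trade-off concrete, here is a back-of-the-envelope sketch: for a single weight matrix of shape (d_out, d_in), LoRA adds factors of shape (d_out, r) and (r, d_in), i.e. r * (d_in + d_out) trainable parameters per adapted matrix (which matrices Keras actually adapts is an implementation detail; the dimensions below are illustrative):

# Illustrative: trainable parameters LoRA adds per adapted weight matrix.
def lora_params_per_matrix(d_in, d_out, rank):
    # B has shape (d_out, rank), A has shape (rank, d_in).
    return rank * (d_in + d_out)

for rank in (4, 8, 16):
    print(rank, lora_params_per_matrix(2048, 2048, rank))
# 4 16384
# 8 32768
# 16 65536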

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Note that enabling LoRA reduces the number of trainable parameters significantly, from 2.6 billion to 2.9 million.

# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
1000/1000 ━━━━━━━━━━━━━━━━━━━━ 923s 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251
<keras.src.callbacks.history.History at 0x799d04393c40>

A note on mixed precision fine-tuning on NVIDIA GPUs

Full precision is recommended for fine-tuning. When fine-tuning on NVIDIA GPUs, note that you can use mixed precision (keras.mixed_precision.set_global_policy('mixed_bfloat16')) to speed up training with minimal effect on training quality. Mixed precision fine-tuning does consume more memory, so it is useful only on larger GPUs.

For inference, half precision (keras.config.set_floatx("bfloat16")) will work and save memory, while mixed precision is not applicable.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')
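If you only plan to run inference, the half-precision setting mentioned above would be applied before loading the model, along these lines (a sketch based on the note above, not part of this tutorial's training flow):

# For inference only (not training), half precision saves memory:
# keras.config.set_floatx("bfloat16")
# gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")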

Inference after fine-tuning

After fine-tuning, responses follow the instruction provided in the prompt.

Europe trip prompt

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.

The model now recommends places to visit in Europe.

ELI5 photosynthesis prompt

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.

The model now explains photosynthesis in simpler terms.

Note that for demonstration purposes, this tutorial fine-tunes the model on a small subset of the dataset for just one epoch and with a low LoRA rank value. To get better responses from the fine-tuned model, you can experiment with the following (a rough sketch follows the list):

  1. Increasing the size of the fine-tuning dataset
  2. Training for more steps (epochs)
  3. Setting a higher LoRA rank
  4. Modifying the hyperparameter values, such as learning_rate and weight_decay
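As a rough sketch, those adjustments could look like the following if applied when setting up the fine-tuning run; the specific values are illustrative examples, not tuned recommendations:

# Illustrative adjustments -- the values are examples, not tuned recommendations.
data = data[:5000]                          # 1. use more training examples
gemma_lm.backbone.enable_lora(rank=8)       # 3. set a higher LoRA rank
optimizer = keras.optimizers.AdamW(
    learning_rate=1e-4,                     # 4. adjusted hyperparameters
    weight_decay=0.004,
)
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=3, batch_size=1)  # 2. train for more epochs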

Summary and next steps

This tutorial covered LoRA fine-tuning on a Gemma model using KerasNLP. Next, check out the following documentation: