2 changes: 1 addition & 1 deletion nlp/README.md
@@ -6,7 +6,7 @@ This directory contains ready-to-use Natural Language Processing application not

| No. | Model | Description |
| :-- | :---- | :------------------------------ |
-| 1 | / | This section is empty for now — feel free to contribute your first application! |
+| 1 | [bert](./bert/train_bert_classification.ipynb) | BERT training and inference application based on MindSpore NLP. |

## Contributing New NLP Applications

Binary file added nlp/bert/images/bert.png
374 changes: 374 additions & 0 deletions nlp/bert/train_bert_classification.ipynb
@@ -0,0 +1,374 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "e7c63436",
"metadata": {},
"source": [
"# BERT Sentiment Classification with MindNLP (SST-2)\n",
"\n",
"This experiment uses the MindNLP toolkit to fine-tune a pretrained BERT model on the SST-2 (Stanford Sentiment Treebank) task from the GLUE benchmark, classifying movie reviews as positive or negative.\n"
]
},
{
"cell_type": "markdown",
"id": "c2a8819a",
"metadata": {},
"source": [
"## How BERT Works\n",
"\n",
"### (1) Overview\n",
"BERT (Bidirectional Encoder Representations from Transformers) is a pretrained language model introduced by Google in 2018. A milestone for natural language processing (NLP), it set new state-of-the-art results on 11 NLP tasks at the time.\n",
"\n",
"BERT's core architecture is the **encoder** stack of the Transformer. Unlike traditional unidirectional language models (such as GPT, which reads left to right) or shallowly bidirectional models (such as Bi-LSTM), BERT is pretrained with a **Masked Language Model (MLM)** objective, allowing it to learn deep semantic representations from both directions of the context at once.\n",
"\n",
"![BERT model architecture](./images/bert.png)\n",
"\n",
"### (2) Key Properties\n",
"- **Bidirectionality**: when encoding each token, BERT sees the tokens both before and after it, giving a more accurate reading of the context.\n",
"- **Transformer encoder**: stacked self-attention layers parallelize well and capture long-range dependencies.\n",
"- **Pretraining tasks**:\n",
" 1. **MLM (Masked Language Model)**: randomly mask some input tokens and train the model to predict them.\n",
" 2. **NSP (Next Sentence Prediction)**: predict whether two sentences are consecutive, teaching the model inter-sentence relationships.\n",
"\n",
"### (3) Model Used in This Experiment\n",
"This experiment uses the **`bert-base-uncased`** model:\n",
"- **Base**: the model size (12 Transformer blocks, hidden size 768, 12 attention heads, about 110M parameters).\n",
"- **Uncased**: all text is lowercased before pretraining, i.e. the model is case-insensitive.\n",
"\n",
"### (4) BERT for Text Classification\n",
"The SST-2 sentiment task uses BERT's special `[CLS]` token:\n",
"1. BERT prepends a `[CLS]` token to every input sequence.\n",
"2. After the 12 Transformer layers, the output vector at the `[CLS]` position is taken as the representation of the whole sentence.\n",
"3. A simple fully connected layer (the classifier) on top of the `[CLS]` vector maps the 768 dimensions down to 2 (Positive/Negative), completing the classification task."
]
},
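{
"cell_type": "markdown",
"id": "cls-head-sketch",
"metadata": {},
"source": [
"The classifier head described in (4) can be sketched in a few lines. This is an illustrative sketch only, not the exact layer MindNLP builds internally: the random `cls_vector` is a stand-in for BERT's real `[CLS]` output.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cls-head-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from mindspore import Tensor, nn\n",
"\n",
"# Hypothetical stand-in for the 768-dim [CLS] vector produced by BERT\n",
"cls_vector = Tensor(np.random.randn(1, 768).astype(np.float32))\n",
"\n",
"# A single fully connected layer maps 768 dims to 2 logits (Negative/Positive)\n",
"classifier = nn.Dense(768, 2)\n",
"logits = classifier(cls_vector)\n",
"print(logits.shape)\n"
]
},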
{
"cell_type": "code",
"execution_count": 1,
"id": "7e3cc3ad",
"metadata": {},
"outputs": [],
"source": [
"# !pip install mindnlp mindspore evaluate tqdm"
]
},
{
"cell_type": "markdown",
"id": "741fe16b",
"metadata": {},
"source": [
"## 1. Environment\n",
"- MindSpore version: 2.7.0\n",
"- MindNLP version: 0.5.1\n",
"- Hardware: Ascend/GPU"
]
},
{
"cell_type": "markdown",
"id": "63fc2ed9",
"metadata": {},
"source": [
"## 2. Imports and Environment Setup\n",
"\n",
"Configure the MindSpore execution mode. For compatibility across hardware and with the HuggingFace-style Trainer, `PYNATIVE_MODE` (dynamic graph mode) is recommended."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d81d09ce",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import warnings\n",
"import numpy as np\n",
"import mindspore\n",
"\n",
"# Suppress verbose low-level logs\n",
"os.environ['GLOG_v'] = '3'\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"# Run in dynamic graph (PyNative) mode, as recommended above\n",
"mindspore.set_context(mode=mindspore.PYNATIVE_MODE)\n",
"\n",
"print(f\"MindSpore Version: {mindspore.__version__}\")"
]
},
{
"cell_type": "markdown",
"id": "18aa9b55",
"metadata": {},
"source": [
"## 3. Loading the Dataset\n",
"\n",
"Download and load the SST-2 dataset from the GLUE benchmark with MindNLP's `load_dataset` API. SST-2 is a binary sentiment analysis dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d448ecf8",
"metadata": {},
"outputs": [],
"source": [
"from mindnlp.dataset import load_dataset\n",
"\n",
"TASK = \"sst2\"\n",
"MODEL_CHECKPOINT = \"bert-base-uncased\"\n",
"\n",
"print(f\"Loading the GLUE/{TASK} dataset...\")\n",
"dataset_dict = load_dataset(\"glue\", TASK)\n",
"print(f\"Training samples: {dataset_dict['train'].get_dataset_size()}\")\n",
"print(f\"Validation samples: {dataset_dict['validation'].get_dataset_size()}\")"
]
},
{
"cell_type": "markdown",
"id": "d4a79795",
"metadata": {},
"source": [
"## 4. Data Preprocessing\n",
"\n",
"The text must be tokenized and converted into tensors the model can accept.\n",
"\n",
"**Note**: to keep the `Trainer` compatible, the streaming dataset is materialized into an in-memory list, and numpy-typed values are converted explicitly."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8f0a543",
"metadata": {},
"outputs": [],
"source": [
"from mindnlp.transformers import AutoTokenizer\n",
"from tqdm import tqdm\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)\n",
"\n",
"def process_dataset_to_list(dataset, tokenizer, max_seq_len=128):\n",
" \"\"\"\n",
" Tokenize a MindSpore dataset and materialize it as an in-memory list for the Trainer\n",
" \"\"\"\n",
" # Tokenization logic\n",
" def tokenize(text):\n",
" # Type-safety check: ensure the input is a string\n",
" if isinstance(text, np.ndarray):\n",
" text = str(text.item())\n",
" if isinstance(text, bytes):\n",
" text = text.decode('utf-8')\n",
" \n",
" tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)\n",
" return tokenized['input_ids'], tokenized['attention_mask'], tokenized['token_type_ids']\n",
"\n",
" # Identify the text column\n",
" col_names = dataset.column_names\n",
" text_col = 'sentence' if 'sentence' in col_names else col_names[0]\n",
" \n",
" # Tokenize via a map operation\n",
" dataset = dataset.map(\n",
" operations=tokenize, \n",
" input_columns=text_col, \n",
" output_columns=['input_ids', 'attention_mask', 'token_type_ids']\n",
" )\n",
" \n",
" # Cast labels to int32\n",
" if 'label' in col_names:\n",
" dataset = dataset.map(operations=lambda x: x.astype(\"int32\"), input_columns=\"label\")\n",
"\n",
" # Materialize as a Python list\n",
" data_list = []\n",
" iterator = dataset.create_dict_iterator(output_numpy=True)\n",
" \n",
" print(\"Converting data format...\")\n",
" for item in tqdm(iterator):\n",
" # Rename 'label' to 'labels' to match the HF Trainer convention\n",
" if 'label' in item:\n",
" item['labels'] = item.pop('label')\n",
" data_list.append(item)\n",
" \n",
" return data_list\n",
"\n",
"# Run the preprocessing\n",
"print(\"Processing the training set...\")\n",
"train_dataset = process_dataset_to_list(dataset_dict['train'], tokenizer)\n",
"print(\"Processing the validation set...\")\n",
"eval_dataset = process_dataset_to_list(dataset_dict['validation'], tokenizer)"
]
},
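{
"cell_type": "markdown",
"id": "sample-inspection",
"metadata": {},
"source": [
"A brief look at one processed sample (added here for illustration) confirms the fields the `Trainer` expects: `input_ids`, `attention_mask`, `token_type_ids`, and `labels`, with sequences padded to `max_seq_len=128`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sample-inspection-code",
"metadata": {},
"outputs": [],
"source": [
"# Inspect the first preprocessed training sample\n",
"sample = train_dataset[0]\n",
"print(sorted(sample.keys()))\n",
"print(\"input_ids length:\", len(sample['input_ids']))\n"
]
},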
{
"cell_type": "markdown",
"id": "a1b7c1cc",
"metadata": {},
"source": [
"## 5. Model and Evaluation Metric\n",
"\n",
"Load the pretrained BERT model and define the evaluation metric (accuracy)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3fd1d17e",
"metadata": {},
"outputs": [],
"source": [
"from mindnlp.transformers import AutoModelForSequenceClassification\n",
"import evaluate\n",
"\n",
"# Load the BERT classification model (num_labels=2)\n",
"model = AutoModelForSequenceClassification.from_pretrained(MODEL_CHECKPOINT, num_labels=2)\n",
"\n",
"# Evaluation function\n",
"def compute_metrics(eval_pred):\n",
" try:\n",
" metric = evaluate.load(\"glue\", TASK)\n",
" except Exception:\n",
" # Offline fallback\n",
" def simple_acc(predictions, references):\n",
" return {\"accuracy\": (predictions == references).mean()}\n",
" metric = type(\"Metric\", (), {\"compute\": lambda self, predictions, references: simple_acc(predictions, references)})()\n",
" \n",
" logits, labels = eval_pred\n",
" predictions = np.argmax(logits, axis=-1)\n",
" return metric.compute(predictions=predictions, references=labels)"
]
},
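{
"cell_type": "markdown",
"id": "metric-sanity-check",
"metadata": {},
"source": [
"As a quick sanity check (not part of the original workflow), `compute_metrics` can be exercised on dummy logits: two samples whose argmax matches the labels should report an accuracy of 1.0.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "metric-sanity-check-code",
"metadata": {},
"outputs": [],
"source": [
"# Dummy logits: sample 0 favors class 0, sample 1 favors class 1\n",
"dummy_logits = np.array([[2.0, 0.5], [0.1, 1.2]])\n",
"dummy_labels = np.array([0, 1])\n",
"print(compute_metrics((dummy_logits, dummy_labels)))\n"
]
},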
{
"cell_type": "markdown",
"id": "2b7d480d",
"metadata": {},
"source": [
"## 6. Training\n",
"\n",
"Configure the hyperparameters with `TrainingArguments` and launch training with `Trainer`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8f21f34",
"metadata": {},
"outputs": [],
"source": [
"from mindnlp.transformers import Trainer, TrainingArguments\n",
"\n",
"BATCH_SIZE = 32\n",
"\n",
"training_args = TrainingArguments(\n",
" output_dir=f\"output_{TASK}\",\n",
" eval_strategy=\"epoch\", # evaluate at the end of each epoch\n",
" save_strategy=\"epoch\", # save a checkpoint at the end of each epoch\n",
" logging_steps=10, # logging frequency\n",
" learning_rate=2e-5, # learning rate\n",
" per_device_train_batch_size=BATCH_SIZE,\n",
" per_device_eval_batch_size=BATCH_SIZE,\n",
" num_train_epochs=3, # number of training epochs\n",
" save_total_limit=1, # keep only the most recent checkpoint\n",
" load_best_model_at_end=True, # load the best model when training ends\n",
" metric_for_best_model=\"accuracy\"\n",
")\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=train_dataset, \n",
" eval_dataset=eval_dataset, \n",
" compute_metrics=compute_metrics,\n",
")\n",
"\n",
"print(\"Starting training...\")\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"id": "13fb730e",
"metadata": {},
"source": [
"## 7. Saving the Model\n",
"\n",
"Save the fine-tuned weights and tokenizer locally for later inference."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0668bcae",
"metadata": {},
"outputs": [],
"source": [
"save_path = \"./bert_sst2_finetuned\"\n",
"trainer.save_model(save_path)\n",
"tokenizer.save_pretrained(save_path)\n",
"print(f\"Model saved to: {save_path}\")"
]
},
{
"cell_type": "markdown",
"id": "e056e006",
"metadata": {},
"source": [
"## 8. Inference\n",
"\n",
"Load the saved model with the high-level `pipeline` API and run sentiment predictions on new text."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a8adee43",
"metadata": {},
"outputs": [],
"source": [
"from mindnlp.transformers import pipeline\n",
"\n",
"# Load the saved model for inference\n",
"classifier = pipeline(\"sentiment-analysis\", model=save_path, tokenizer=save_path, top_k=None)\n",
"\n",
"# Test cases\n",
"test_sentences = [\n",
" \"I absolutely love this movie, it's fantastic!\", \n",
" \"The plot was boring and the acting was terrible.\",\n",
" \"It was okay, not great but not bad.\"\n",
"]\n",
"\n",
"print(\"-\" * 30)\n",
"print(\"Inference results:\")\n",
"print(\"-\" * 30)\n",
"\n",
"for text in test_sentences:\n",
" results = classifier(text)\n",
" # Parse the results\n",
" scores = results[0]\n",
" # Pick the highest-scoring label\n",
" best_result = max(scores, key=lambda x: x['score'])\n",
" label = \"Positive\" if best_result['label'] == 'LABEL_1' else \"Negative\"\n",
" \n",
" print(f\"Text: {text}\")\n",
" print(f\"Prediction: {label} (confidence: {best_result['score']:.4f})\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "mindspore",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.25"
}
},
"nbformat": 4,
"nbformat_minor": 5
}