From d0cbae92fbd9148c2129db6ebce700e536049a83 Mon Sep 17 00:00:00 2001 From: liumintao2025 <1820220397@qq.com> Date: Fri, 30 Jan 2026 13:37:24 +0800 Subject: [PATCH] add audio classification application based on wav2vec2 --- audio/README.md | 1 + .../wav2vec2_audio_classification.ipynb | 558 ++++++++++++++++++ 2 files changed, 559 insertions(+) create mode 100644 audio/wav2vec2/wav2vec2_audio_classification.ipynb diff --git a/audio/README.md b/audio/README.md index 56f628f..6b3465c 100644 --- a/audio/README.md +++ b/audio/README.md @@ -7,6 +7,7 @@ This directory contains ready-to-use Audio application notebooks built with Mind | No. | Model | Description | | :-- | :---- | :-------------------------------- | | 1 | [WaveNet](./wavenet/) | Includes notebooks for WaveNet training on tasks such as audio synthesis | +| 2 | [Wav2Vec2](./wav2vec2/) | Includes notebooks for Wav2Vec2 fine-tuning on tasks such as audio intent classification (MInDS-14) | ## Contributing New Audio Applications diff --git a/audio/wav2vec2/wav2vec2_audio_classification.ipynb b/audio/wav2vec2/wav2vec2_audio_classification.ipynb new file mode 100644 index 0000000..eccd17f --- /dev/null +++ b/audio/wav2vec2/wav2vec2_audio_classification.ipynb @@ -0,0 +1,558 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#基于 MindNLP 的 Wav2Vec2 音频意图识别\n", + "\n", + "## 项目简介\n", + "本项目基于 **MindSpore 2.7.0** 和 **MindNLP** 框架,使用预训练的 `facebook/wav2vec2-base` 模型在 **MInDS-14** 数据集上进行微调,实现英语语音意图分类任务。\n", + "\n", + "## 实验环境\n", + "* **硬件平台**: Huawei Ascend 910 (NPU)\n", + "* **运行模式**: PYNATIVE_MODE (动态图模式)\n", + "* **关键策略**: \n", + " 1. **定长输入**: 统一填充至 5秒 (80000采样点)。\n", + " 2. **手动解码**: 使用 `librosa` 替代默认解码器,提升底层兼容性。\n", + " 3. **梯度屏蔽**: 冻结非浮点型参数 (Int64),防止优化器报错。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. 环境初始化与依赖配置\n", + "配置 HuggingFace 镜像源以加速模型下载,并设置 MindSpore 运行上下文为 Ascend NPU。" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[WARNING] ME(2670341:281473801904160,MainProcess):2026-01-30-12:50:44.499.000 [mindspore/context.py:1412] For 'context.set_context', the parameter 'device_target' will be deprecated and removed in a future version. Please use the api mindspore.set_device() instead.\n" + ] + } + ], + "source": [ + "import os\n", + "os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n", + "import numpy as np\n", + "import librosa\n", + "import mindspore as ms\n", + "from datasets import load_dataset, Audio\n", + "from mindnlp.transformers import (\n", + " AutoFeatureExtractor, \n", + " AutoModelForAudioClassification, \n", + " TrainingArguments, \n", + " Trainer\n", + ")\n", + "\n", + "ms.set_context(mode=ms.PYNATIVE_MODE, device_target=\"Ascend\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. 数据集加载与划分\n", + "加载 MInDS-14 (en-US) 数据集,并按 **8:2** 的比例划分为训练集和验证集。同时构建标签与 ID 的映射字典。" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + ">>> Loading Dataset...\n", + "数据集加载完成,类别数量: 14\n" + ] + } + ], + "source": [ + "print(\">>> Loading Dataset...\")\n", + "try:\n", + " # 尝试加载 MInDS-14 数据集 (en-US)\n", + " minds = load_dataset(\"PolyAI/minds14\", name=\"en-US\", split=\"train\")\n", + "except:\n", + " # 网络波动时的重试机制\n", + " minds = load_dataset(\"PolyAI/minds14\", name=\"en-US\", split=\"train\")\n", + "\n", + "# 划分验证集 (20% 用于验证)\n", + "minds = minds.train_test_split(test_size=0.2)\n", + "\n", + "# 构建标签映射字典\n", + "labels = minds[\"train\"].features[\"intent_class\"].names\n", + "label2id = {label: str(i) for i, label in enumerate(labels)}\n", + "id2label = {str(i): label for i, label in enumerate(labels)}\n", + "\n", + "print(f\"数据集加载完成,类别数量: {len(labels)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. 数据预处理 (Feature Extraction)\n", + "为了适应 Wav2Vec2 模型输入要求,我们执行以下关键操作:\n", + "1. **手动读取**: 使用 `librosa` 读取音频文件或字节流,规避系统底层依赖问题。\n", + "2. **统一长度**: 将所有音频重采样至 16kHz,并填充/截断至 **80000** 采样点 (5秒)。" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/zdai/miniconda3/envs/QA/lib/python3.11/site-packages/transformers/configuration_utils.py:335: UserWarning: Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the `Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + ">>> Preprocessing Data...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map: 100%|██████████| 450/450 [00:04<00:00, 105.33 examples/s]\n", + "Map: 100%|██████████| 113/113 [00:00<00:00, 215.08 examples/s]\n" + ] + } + ], + "source": [ + "model_id = \"facebook/wav2vec2-base\"\n", + "feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)\n", + "\n", + "# 音频最大长度: 4-5秒 (64000-80000采样点)\n", + "# 经验值设为 80000 (5秒) 以覆盖完整指令\n", + "MAX_DURATION = 80000 \n", + "\n", + "def preprocess_function(examples):\n", + " \"\"\"\n", + " 读取音频 -> 重采样至16k -> 填充/截断至固定长度\n", + " \"\"\"\n", + " audio_arrays = []\n", + " for x in examples[\"audio\"]:\n", + " try:\n", + " # 使用 librosa 手动读取,规避 datasets 底层解码依赖问题\n", + " if x.get('bytes'):\n", + " import io\n", + " y, _ = librosa.load(io.BytesIO(x['bytes']), sr=16000)\n", + " else:\n", + " y, _ = librosa.load(x['path'], sr=16000)\n", + " \n", + " # 异常处理:空音频补全\n", + " if len(y) == 0: y = np.zeros(16000)\n", + " audio_arrays.append(y)\n", + " except:\n", + " audio_arrays.append(np.zeros(16000))\n", + " \n", + " # 使用 FeatureExtractor 进行标准化处理\n", + " inputs = feature_extractor(\n", + " audio_arrays, \n", + " sampling_rate=16000, \n", + " max_length=MAX_DURATION,\n", + " truncation=True, \n", + " padding=\"max_length\",\n", + " )\n", + " return inputs\n", + "\n", + "# 移除原始 audio 列,设置为 decode=False 以获取原始字节流\n", + "minds = minds.cast_column(\"audio\", Audio(decode=False))\n", + "\n", + "print(\">>> Preprocessing Data...\")\n", + "encoded_minds = minds.map(\n", + " preprocess_function, \n", + " remove_columns=[\"audio\", \"path\", \"transcription\", \"english_transcription\", \"lang_id\"], \n", + " batched=True\n", + ")\n", + "encoded_minds = encoded_minds.rename_column(\"intent_class\", \"labels\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4. 模型加载与参数兼容性补丁\n", + "加载预训练模型,并应用 **Int64 梯度屏蔽补丁**。这是为了解决 MindSpore 优化器在处理位置编码等整数型参数时可能出现的梯度计算错误。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + ">>> Loading Model: facebook/wav2vec2-base\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/zdai/miniconda3/envs/QA/lib/python3.11/site-packages/transformers/configuration_utils.py:335: UserWarning: Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the `Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[MS_ALLOC_CONF]Runtime config: enable_vmm:True vmm_align_size:2MB\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "print(f\">>> Loading Model: {model_id}\")\n", + "model = AutoModelForAudioClassification.from_pretrained(\n", + " model_id, \n", + " num_labels=len(labels),\n", + " label2id=label2id,\n", + " id2label=id2label,\n", + ")\n", + "\n", + "# 冻结 Int64/Int32 参数梯度\n", + "# 说明:Wav2Vec2 的位置编码参数为整数类型,MindSpore 优化器无法对整数求导。\n", + "# 必须显式将 requires_grad 设为 False,否则训练会报错。\n", + "for p in model.get_parameters():\n", + " if p.dtype in (ms.int32, ms.int64):\n", + " p.requires_grad = False" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5. 训练参数配置\n", + "配置 TrainingArguments,设置学习率、Batch Size 以及 Evaluation 策略。" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "args = TrainingArguments(\n", + " output_dir=\"wav2vec2_ckpt\",\n", + " eval_strategy=\"epoch\",\n", + " save_strategy=\"epoch\",\n", + " learning_rate=3e-5, # 经典微调学习率\n", + " per_device_train_batch_size=8, # 训练 Batch Size\n", + " per_device_eval_batch_size=8, # 验证 Batch Size\n", + " num_train_epochs=15, # 正常训练轮次\n", + " logging_steps=10,\n", + " load_best_model_at_end=True,\n", + " metric_for_best_model=\"accuracy\",\n", + " save_total_limit=1,\n", + " seed=42\n", + ")\n", + "\n", + "# 评估指标函数 (引入 evaluate 库)\n", + "import evaluate\n", + "accuracy = evaluate.load(\"accuracy\")\n", + "def compute_metrics(eval_pred):\n", + " predictions = np.argmax(eval_pred.predictions, axis=1)\n", + " return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6. 模型训练 (Training)\n", + "初始化 Trainer 并启动训练。注意此处不传递 `tokenizer` 参数,以避免触发额外的依赖检查。" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "🚀 开始标准训练...\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
| Epoch | \n", + "Training Loss | \n", + "Validation Loss | \n", + "Accuracy | \n", + "
|---|---|---|---|
| 1 | \n", + "2.649500 | \n", + "2.641374 | \n", + "0.079646 | \n", + "
| 2 | \n", + "2.639900 | \n", + "2.617154 | \n", + "0.106195 | \n", + "
| 3 | \n", + "2.483800 | \n", + "2.415874 | \n", + "0.300885 | \n", + "
| 4 | \n", + "2.117700 | \n", + "2.107937 | \n", + "0.566372 | \n", + "
| 5 | \n", + "1.806700 | \n", + "1.876625 | \n", + "0.646018 | \n", + "
| 6 | \n", + "1.418400 | \n", + "1.667981 | \n", + "0.663717 | \n", + "
| 7 | \n", + "1.333700 | \n", + "1.512553 | \n", + "0.672566 | \n", + "
| 8 | \n", + "1.117300 | \n", + "1.409446 | \n", + "0.699115 | \n", + "
| 9 | \n", + "0.909000 | \n", + "1.298606 | \n", + "0.752212 | \n", + "
| 10 | \n", + "0.832500 | \n", + "1.270922 | \n", + "0.716814 | \n", + "
| 11 | \n", + "0.669600 | \n", + "1.242440 | \n", + "0.716814 | \n", + "
| 12 | \n", + "0.628500 | \n", + "1.216295 | \n", + "0.699115 | \n", + "
| 13 | \n", + "0.584200 | \n", + "1.185699 | \n", + "0.734513 | \n", + "
| 14 | \n", + "0.517500 | \n", + "1.149428 | \n", + "0.743363 | \n", + "
| 15 | \n", + "0.508300 | \n", + "1.146037 | \n", + "0.743363 | \n", + "
"
+ ],
+ "text/plain": [
+ "