diff --git a/cv/README.md b/cv/README.md index d8edae0..3dfdf60 100644 --- a/cv/README.md +++ b/cv/README.md @@ -9,6 +9,7 @@ This directory contains ready-to-use Computer Vision application notebooks built | 1 | [ResNet](./resnet/) | Includes notebooks for ResNet finetuning on tasks such as chinese herbal classification | | 2 | [U-Net](./unet/) | Includes notebooks for U-Net training on tasks such as segmentation | | 3 | [SAM](./sam/) | Includes notebooks for using SAM to inference | +| 4 | [OCR](./ocr/) | Includes notebooks for OCR inference on tasks such as DeepSeek-OCR demo | ## Contributing New CV Applications diff --git a/cv/ocr/inference_deepSeekorc_demo.ipynb b/cv/ocr/inference_deepSeekorc_demo.ipynb new file mode 100644 index 0000000..a87a39f --- /dev/null +++ b/cv/ocr/inference_deepSeekorc_demo.ipynb @@ -0,0 +1,1051 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# DeepSeek-OCR MindSpore DEMO\n", + "\n", + "基于 **MindSpore 2.7.0 + MindNLP 0.5.1** 的文本识别与结构化解析演示。\n", + "\n", + "## 环境要求\n", + "\n", + "| 组件 | 版本 |\n", + "|------|------|\n", + "| Python | 3.9 |\n", + "| MindSpore | 2.7.0 |\n", + "| MindNLP | 0.5.1 |\n", + "| transformers | 4.57.3 |\n", + "| Gradio | 6.1.0 |\n", + "| 硬件 | Ascend NPU 910B (65536MB HBM) |\n", + "| CANN | 8.2.RC2 |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 0: 安装依赖\n", + "# 如已完成环境准备,可跳过本单元\n", + "# Ascend 环境请确保已先安装与 CANN 8.2.RC2 匹配的 MindSpore 2.7.0\n", + "!pip install mindspore==2.7.0\n", + "!pip install mindnlp==0.5.1\n", + "!pip install diffusers==0.35.2\n", + "!pip install gradio\n", + "!pip install einops\n", + "!pip install torchvision" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 1: 环境检查\n", + "import mindspore as ms\n", + "ms.set_context(device_target=\"Ascend\", device_id=0)\n", + "print(f\"MindSpore version: {ms.__version__}\")\n", + "\n", + "import mindnlp\n", + "print(f\"MindNLP available\")\n", + "\n", + "import transformers\n", + "print(f\"transformers version: {transformers.__version__}\")\n", + "\n", + "import gradio as gr\n", + "print(f\"Gradio version: {gr.__version__}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 模型加载\n", + "\n", + "使用 MindNLP 的 transformers 兼容接口加载 DeepSeek-OCR 模型。\n", + "\n", + "**关键参数说明**:\n", + "- `_attn_implementation='eager'`: Ascend NPU 上兼容性最佳的注意力实现\n", + "- `trust_remote_code=True`: 加载模型自定义代码\n", + "- `use_safetensors=True`: 使用安全张量格式" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 2: 模型加载\n", + "import types\n", + "import mindnlp\n", + "import mindtorch\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from transformers import AutoModel, AutoTokenizer\n", + "\n", + "model_name = 'lvyufeng/DeepSeek-OCR'\n", + "\n", + "print(\"加载 tokenizer...\")\n", + "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n", + "\n", + "print(\"加载模型 (float32)...\")\n", + "model = AutoModel.from_pretrained(\n", + " model_name,\n", + " _attn_implementation='eager',\n", + " trust_remote_code=True,\n", + " use_safetensors=True,\n", + " device_map='auto'\n", + ")\n", + "model = model.eval()\n", + "\n", + "print(\"合并 MoE 权重...\")\n", + "model.combine_moe()\n", + "\n", + "# NPU 不支持 scatter_add 用 one_hot 替代\n", + "def _patched_forward_for_moe(self, hidden_states):\n", + " batch_size, sequence_length, hidden_dim = hidden_states.shape\n", + " selected_experts, routing_weights = self.gate(hidden_states)\n", + " n_experts = self.config.n_routed_experts\n", + " routing_weights = routing_weights.to(hidden_states.dtype)\n", + " one_hot = F.one_hot(selected_experts, n_experts).to(routing_weights.dtype)\n", + " router_scores = (one_hot * routing_weights.unsqueeze(-1)).sum(dim=1)\n", + " hidden_states = hidden_states.view(-1, hidden_dim)\n", + " if self.config.n_shared_experts is not None:\n", + " shared_expert_output = self.shared_experts(hidden_states)\n", + " hidden_w1 = torch.matmul(hidden_states, self.w1)\n", + " hidden_w3 = torch.matmul(hidden_states, self.w3)\n", + " hidden_states = self.act(hidden_w1) * hidden_w3\n", + " hidden_states = torch.bmm(hidden_states, self.w2) * torch.transpose(router_scores, 0, 1).unsqueeze(-1)\n", + " final_hidden_states = hidden_states.sum(dim=0, dtype=hidden_states.dtype)\n", + " if self.config.n_shared_experts is not None:\n", + " hidden_states = final_hidden_states + shared_expert_output\n", + " return hidden_states.view(batch_size, sequence_length, hidden_dim)\n", + "\n", + "for layer in model.model.layers:\n", + " if hasattr(layer.mlp, 'w1'):\n", + " layer.mlp.forward = types.MethodType(_patched_forward_for_moe, layer.mlp)\n", + "\n", + "print(\"模型加载完成!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 单张图片推理示例(非流式)\n", + "\n", + "使用 `model.infer()` 方法进行标准推理,支持多种分辨率模式:\n", + "\n", + "| 模式 | base_size | image_size | crop_mode | 适用场景 |\n", + "|------|-----------|------------|-----------|----------|\n", + "| Tiny | 512 | 512 | False | 快速预览 |\n", + "| Small | 640 | 640 | False | 一般文档 |\n", + "| Base | 1024 | 1024 | False | 高质量 |\n", + "| Large | 1280 | 1280 | False | 超高分辨率 |\n", + "| **Gundam** | **1024** | **640** | **True** | **推荐:精度速度最佳平衡** |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 3: 单张图片推理(非流式)\n", + "import time\n", + "import os\n", + "\n", + "# 准备测试图片\n", + "# 如果没有测试图片,可以从 HuggingFace 下载\n", + "image_file = 'image_ocr.jpg'\n", + "if not os.path.exists(image_file):\n", + " import urllib.request\n", + " url = 'https://hf-mirror.com/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg'\n", + " print(f\"下载测试图片: {url}\")\n", + " urllib.request.urlretrieve(url, image_file)\n", + " print(\"下载完成\")\n", + "\n", + "prompt = \"\\nFree OCR. \"\n", + "output_path = './output'\n", + "os.makedirs(output_path, exist_ok=True)\n", + "\n", + "print(\"开始推理 (Gundam 模式)...\")\n", + "t0 = time.time()\n", + "\n", + "with mindtorch.no_grad():\n", + " res = model.infer(\n", + " tokenizer,\n", + " prompt=prompt,\n", + " image_file=image_file,\n", + " output_path=output_path,\n", + " base_size=1024,\n", + " image_size=640,\n", + " crop_mode=True,\n", + " save_results=True,\n", + " test_compress=True,\n", + " )\n", + "\n", + "elapsed = time.time() - t0\n", + "print(f\"\\n推理完成,总耗时: {elapsed:.2f}s\")\n", + "\n", + "# 显示结果\n", + "if os.path.exists(f'{output_path}/result.mmd'):\n", + " with open(f'{output_path}/result.mmd', 'r') as f:\n", + " print(\"\\n识别结果:\")\n", + " print(f.read())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 流式生成 + 时间统计\n", + "\n", + "使用 `TextIteratorStreamer` 实现流式 token 输出,可以在生成过程中实时查看结果。\n", + "\n", + "**核心思路**:\n", + "1. 从 `model.infer()` 中抽取图像预处理逻辑为独立函数\n", + "2. 用 `TextIteratorStreamer` 替换原始 `NoEOSTextStreamer`\n", + "3. 在独立线程中运行 `model.generate()`\n", + "4. 主线程通过 streamer 迭代获取 token 并统计时间" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 4: 流式生成 + 时间统计\n", + "import math\n", + "import importlib\n", + "from threading import Thread\n", + "from PIL import Image, ImageOps\n", + "from transformers import TextIteratorStreamer\n", + "\n", + "# 导入模型辅助函数\n", + "_mod = importlib.import_module(type(model).__module__)\n", + "format_messages = _mod.format_messages\n", + "load_pil_images = _mod.load_pil_images\n", + "text_encode = _mod.text_encode\n", + "BasicImageTransform = _mod.BasicImageTransform\n", + "dynamic_preprocess = _mod.dynamic_preprocess\n", + "\n", + "IMAGE_TOKEN = ''\n", + "IMAGE_TOKEN_ID = 128815\n", + "PATCH_SIZE = 16\n", + "DOWNSAMPLE_RATIO = 4\n", + "BOS_ID = 0\n", + "\n", + "\n", + "def prepare_inputs(prompt_text, image_file, base_size, image_size, crop_mode):\n", + " \"\"\"从 model.infer() 中抽取的图像预处理逻辑。\"\"\"\n", + " conversation = [\n", + " {\"role\": \"<|User|>\", \"content\": prompt_text, \"images\": [image_file]},\n", + " {\"role\": \"<|Assistant|>\", \"content\": \"\"},\n", + " ]\n", + " formatted_prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='')\n", + " images = load_pil_images(conversation)\n", + "\n", + " image_transform = BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)\n", + " text_splits = formatted_prompt.split(IMAGE_TOKEN)\n", + "\n", + " images_list, images_crop_list, images_seq_mask = [], [], []\n", + " tokenized_str = []\n", + " images_spatial_crop = []\n", + "\n", + " for text_sep, image in zip(text_splits, images):\n", + " tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)\n", + " tokenized_str += tokenized_sep\n", + " images_seq_mask += [False] * len(tokenized_sep)\n", + "\n", + " if crop_mode:\n", + " if image.size[0] <= 640 and image.size[1] <= 640:\n", + " crop_ratio = [1, 1]\n", + " else:\n", + " images_crop_raw, crop_ratio = dynamic_preprocess(image)\n", + "\n", + " global_view = ImageOps.pad(image, (base_size, base_size),\n", + " color=tuple(int(x * 255) for x in image_transform.mean))\n", + " images_list.append(image_transform(global_view).to(model.dtype))\n", + " width_crop_num, height_crop_num = crop_ratio\n", + " images_spatial_crop.append([width_crop_num, height_crop_num])\n", + "\n", + " if width_crop_num > 1 or height_crop_num > 1:\n", + " for i in range(len(images_crop_raw)):\n", + " images_crop_list.append(image_transform(images_crop_raw[i]).to(model.dtype))\n", + "\n", + " num_queries = math.ceil((image_size // PATCH_SIZE) / DOWNSAMPLE_RATIO)\n", + " num_queries_base = math.ceil((base_size // PATCH_SIZE) / DOWNSAMPLE_RATIO)\n", + "\n", + " tokenized_image = ([IMAGE_TOKEN_ID] * num_queries_base + [IMAGE_TOKEN_ID]) * num_queries_base\n", + " tokenized_image += [IMAGE_TOKEN_ID]\n", + " if width_crop_num > 1 or height_crop_num > 1:\n", + " tokenized_image += ([IMAGE_TOKEN_ID] * (num_queries * width_crop_num) + [IMAGE_TOKEN_ID]) * (\n", + " num_queries * height_crop_num)\n", + " tokenized_str += tokenized_image\n", + " images_seq_mask += [True] * len(tokenized_image)\n", + " else:\n", + " if image_size <= 640:\n", + " image = image.resize((image_size, image_size))\n", + " global_view = ImageOps.pad(image, (image_size, image_size),\n", + " color=tuple(int(x * 255) for x in image_transform.mean))\n", + " images_list.append(image_transform(global_view).to(model.dtype))\n", + " images_spatial_crop.append([1, 1])\n", + "\n", + " num_queries = math.ceil((image_size // PATCH_SIZE) / DOWNSAMPLE_RATIO)\n", + " tokenized_image = ([IMAGE_TOKEN_ID] * num_queries + [IMAGE_TOKEN_ID]) * num_queries\n", + " tokenized_image += [IMAGE_TOKEN_ID]\n", + " tokenized_str += tokenized_image\n", + " images_seq_mask += [True] * len(tokenized_image)\n", + "\n", + " tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)\n", + " tokenized_str += tokenized_sep\n", + " images_seq_mask += [False] * len(tokenized_sep)\n", + " tokenized_str = [BOS_ID] + tokenized_str\n", + " images_seq_mask = [False] + images_seq_mask\n", + "\n", + " input_ids = torch.LongTensor(tokenized_str)\n", + " images_seq_mask_t = torch.tensor(images_seq_mask, dtype=torch.bool)\n", + "\n", + " if len(images_list) == 0:\n", + " images_ori = torch.zeros((1, 3, image_size, image_size))\n", + " images_spatial_crop_t = torch.zeros((1, 2), dtype=torch.long)\n", + " images_crop = torch.zeros((1, 3, base_size, base_size))\n", + " else:\n", + " images_ori = torch.stack(images_list, dim=0)\n", + " images_spatial_crop_t = torch.tensor(images_spatial_crop, dtype=torch.long)\n", + " images_crop = torch.stack(images_crop_list, dim=0) if images_crop_list else torch.zeros((1, 3, base_size, base_size))\n", + "\n", + " return {\n", + " 'input_ids': input_ids.unsqueeze(0).cuda(),\n", + " 'images': [(images_crop.cuda(), images_ori.cuda())],\n", + " 'images_seq_mask': images_seq_mask_t.unsqueeze(0).cuda(),\n", + " 'images_spatial_crop': images_spatial_crop_t,\n", + " }\n", + "\n", + "\n", + "# 流式推理\n", + "prompt_text = \"\\nFree OCR. \"\n", + "\n", + "model.disable_torch_init()\n", + "inputs = prepare_inputs(prompt_text, image_file, base_size=1024, image_size=640, crop_mode=True)\n", + "\n", + "streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)\n", + "\n", + "generate_kwargs = dict(\n", + " input_ids=inputs['input_ids'],\n", + " images=inputs['images'],\n", + " images_seq_mask=inputs['images_seq_mask'],\n", + " images_spatial_crop=inputs['images_spatial_crop'],\n", + " temperature=0.0,\n", + " eos_token_id=tokenizer.eos_token_id,\n", + " streamer=streamer,\n", + " max_new_tokens=8192,\n", + " no_repeat_ngram_size=20,\n", + " use_cache=True,\n", + ")\n", + "\n", + "\n", + "def run_generate():\n", + " with torch.no_grad():\n", + " model.generate(**generate_kwargs)\n", + "\n", + "\n", + "thread = Thread(target=run_generate)\n", + "t_start = time.time()\n", + "thread.start()\n", + "\n", + "first_token_time = None\n", + "token_count = 0\n", + "full_text = \"\"\n", + "STOP_STR = '<|end▁of▁sentence|>'\n", + "\n", + "print(\"流式生成中...\")\n", + "print(\"=\" * 50)\n", + "for new_text in streamer:\n", + " if first_token_time is None:\n", + " first_token_time = time.time() - t_start\n", + " token_count += 1\n", + " full_text += new_text\n", + " # 实时输出\n", + " clean = new_text.replace(STOP_STR, '')\n", + " if clean:\n", + " print(clean, end='', flush=True)\n", + "\n", + "thread.join()\n", + "total_time = time.time() - t_start\n", + "\n", + "print(\"\\n\" + \"=\" * 50)\n", + "print(f\"\\n性能统计:\")\n", + "print(f\" 首 Token 延迟 (TTFT): {first_token_time:.3f}s\")\n", + "print(f\" 总 Token 数: {token_count}\")\n", + "print(f\" 总耗时: {total_time:.2f}s\")\n", + "print(f\" 生成速度: {token_count / total_time:.2f} tokens/s\")\n", + "if token_count > 1 and first_token_time:\n", + " decode_time = total_time - first_token_time\n", + " print(f\" 解码速度 (不含首 token): {(token_count - 1) / decode_time:.2f} tokens/s\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 性能优化方案说明\n", + "\n", + "### 已实施的优化\n", + "\n", + "| # | 优化项 | 方法 | 效果 |\n", + "|---|--------|------|------|\n", + "| 1 | **MoE 权重合并** | `model.combine_moe()` | 将分散的专家权重合并为矩阵运算,减少内存访问次数,加速前向传播 |\n", + "| 2 | **scatter_add NPU 适配** | `F.one_hot` + 矩阵乘法 | 替换 NPU 不支持的 `scatter_add_ext` 算子,保证 MoE 合并后推理正确性 |\n", + "| 3 | **KV Cache** | `use_cache=True` | 缓存已计算的 Key/Value,避免自回归生成时重复计算所有位置的注意力 |\n", + "| 4 | **N-gram 去重** | `no_repeat_ngram_size=20` | 防止模型生成重复文本,提升有效 token 效率 |\n", + "| 5 | **Eager Attention** | `_attn_implementation='eager'` | 在 Ascend NPU 上比 Flash Attention 兼容性更好,避免算子不支持的问题 |\n", + "\n", + "### 优化前后实测数据对比(Ascend 910B, Gundam 模式, 256 tokens)\n", + "\n", + "| 配置 | TTFT | 生成速度 | 解码速度 | 加速比 |\n", + "|------|------|----------|----------|--------|\n", + "| **全部优化** (combine_moe + KV Cache) | 9.757s | **7.95 tok/s** | **11.34 tok/s** | 基线 |\n", + "| 关闭 MoE 合并 (无 combine_moe) | 10.805s | 1.68 tok/s | 2.29 tok/s | **4.95x 慢** |\n", + "\n", + "> **结论**: `combine_moe()` 是最关键的优化项,使解码速度提升约 **5 倍**(2.29 → 11.34 tok/s)。\n", + "\n", + "### 不同分辨率模式实测对比(256 tokens, 全部优化)\n", + "\n", + "| 模式 | TTFT | 总耗时 | 生成速度 | 解码速度 | 适用场景 |\n", + "|------|------|--------|----------|----------|----------|\n", + "| Tiny (512) | **0.214s** | **23.36s** | **11.00 tok/s** | 11.06 tok/s | 快速预览、低分辨率 |\n", + "| Small (640) | 0.257s | 23.89s | 10.76 tok/s | 10.83 tok/s | 一般文档 |\n", + "| **Gundam (推荐)** | 9.757s | 32.33s | 7.95 tok/s | **11.34 tok/s** | 高精度 OCR |\n", + "\n", + "> **说明**: Gundam 模式 TTFT 较高是因为 crop 模式需要处理多个图像切片(全局视图+局部切片),但解码速度与其他模式持平。Tiny 模式 TTFT 极低(0.2s),适合对延迟敏感的场景。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Cell 5: 不同分辨率模式对比\n", + "import time\n", + "\n", + "results = {}\n", + "modes = {\n", + " 'Tiny': {'base_size': 512, 'image_size': 512, 'crop_mode': False},\n", + " 'Small': {'base_size': 640, 'image_size': 640, 'crop_mode': False},\n", + " 'Gundam': {'base_size': 1024, 'image_size': 640, 'crop_mode': True},\n", + "}\n", + "\n", + "for mode_name, params in modes.items():\n", + " print(f\"\\n{'='*50}\")\n", + " print(f\"测试模式: {mode_name} (base={params['base_size']}, img={params['image_size']}, crop={params['crop_mode']})\")\n", + " print(f\"{'='*50}\")\n", + "\n", + " inputs = prepare_inputs(prompt_text, image_file, **params)\n", + " streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)\n", + "\n", + " gen_kwargs = dict(\n", + " input_ids=inputs['input_ids'],\n", + " images=inputs['images'],\n", + " images_seq_mask=inputs['images_seq_mask'],\n", + " images_spatial_crop=inputs['images_spatial_crop'],\n", + " temperature=0.0,\n", + " eos_token_id=tokenizer.eos_token_id,\n", + " streamer=streamer,\n", + " max_new_tokens=4096,\n", + " no_repeat_ngram_size=20,\n", + " use_cache=True,\n", + " )\n", + "\n", + " def _gen():\n", + " with torch.no_grad():\n", + " model.generate(**gen_kwargs)\n", + "\n", + " thread = Thread(target=_gen)\n", + " t0 = time.time()\n", + " thread.start()\n", + "\n", + " ttft = None\n", + " n_tokens = 0\n", + " for text in streamer:\n", + " if ttft is None:\n", + " ttft = time.time() - t0\n", + " n_tokens += 1\n", + " thread.join()\n", + " total = time.time() - t0\n", + "\n", + " results[mode_name] = {'ttft': ttft, 'tokens': n_tokens, 'total': total, 'tps': n_tokens / total}\n", + " print(f\" TTFT: {ttft:.3f}s | Tokens: {n_tokens} | Total: {total:.2f}s | Speed: {n_tokens/total:.2f} tok/s\")\n", + "\n", + "print(f\"\\n{'='*60}\")\n", + "print(\"对比汇总:\")\n", + "print(f\"{'模式':<10} {'TTFT':>8} {'Tokens':>8} {'总耗时':>8} {'速度':>12}\")\n", + "print(f\"{'-'*50}\")\n", + "for name, r in results.items():\n", + " print(f\"{name:<10} {r['ttft']:>7.3f}s {r['tokens']:>8} {r['total']:>7.2f}s {r['tps']:>8.2f} tok/s\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Gradio 交互 DEMO\n", + "\n", + "启动完整的 Gradio Web 界面,支持:\n", + "- 图片上传\n", + "- 多种分辨率模式选择\n", + "- 多种任务类型(Free OCR / Markdown / 图表解析 / 文本定位)\n", + "- 流式文本输出\n", + "- 实时性能统计\n", + "\n", + "魔乐社区链接:(待更新)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "DeepSeek-OCR MindSpore DEMO\n", + "基于 MindSpore 2.7.0 + MindNLP 0.5.1 的文本识别与结构化解析交互式 DEMO\n", + "支持流式生成、token 时间统计和性能优化\n", + "\"\"\"\n", + "\n", + "# !pip install mindspore==2.7.0\n", + "# !pip install mindnlp==0.5.1\n", + "# !pip install diffusers==0.35.2\n", + "# !pip install gradio\n", + "# !pip install einops\n", + "# !pip install torchvision\n", + "\n", + "import importlib\n", + "import math\n", + "import os\n", + "import time\n", + "import types\n", + "import tempfile\n", + "from threading import Thread\n", + "from typing import Optional\n", + "\n", + "from PIL import Image, ImageOps\n", + "\n", + "try:\n", + " import gradio as gr # pylint: disable=import-error\n", + "except ImportError: # pragma: no cover - optional UI dependency\n", + " gr = None\n", + "\n", + "import mindspore as ms\n", + "import torch\n", + "import torch.nn.functional as F\n", + "\n", + "from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer\n", + "\n", + "# mindnlp / mindtorch patch transformers and torch behavior for MindSpore at import time.\n", + "importlib.import_module(\"mindnlp\")\n", + "importlib.import_module(\"mindtorch\")\n", + "\n", + "ms.set_context(device_target=\"Ascend\", device_id=0)\n", + "\n", + "# ============================================================\n", + "# 全局配置\n", + "# ============================================================\n", + "MODEL_NAME = \"lvyufeng/DeepSeek-OCR\"\n", + "IMAGE_TOKEN = \"\"\n", + "IMAGE_TOKEN_ID = 128815\n", + "PATCH_SIZE = 16\n", + "DOWNSAMPLE_RATIO = 4\n", + "BOS_ID = 0\n", + "STOP_STR = \"<|end▁of▁sentence|>\"\n", + "\n", + "# 分辨率预设\n", + "RESOLUTION_PRESETS = {\n", + " \"Tiny (512, 快速)\": {\"base_size\": 512, \"image_size\": 512, \"crop_mode\": False},\n", + " \"Small (640)\": {\"base_size\": 640, \"image_size\": 640, \"crop_mode\": False},\n", + " \"Base (1024)\": {\"base_size\": 1024, \"image_size\": 1024, \"crop_mode\": False},\n", + " \"Large (1280)\": {\"base_size\": 1280, \"image_size\": 1280, \"crop_mode\": False},\n", + " \"Gundam (推荐)\": {\"base_size\": 1024, \"image_size\": 640, \"crop_mode\": True},\n", + "}\n", + "\n", + "# 任务类型\n", + "TASK_PROMPTS = {\n", + " \"Free OCR\": \"\\nFree OCR. \",\n", + " \"转换为 Markdown\": \"\\n<|grounding|>Convert the document to markdown. \",\n", + " \"解析图表\": \"\\nParse the figure. \",\n", + " \"文本定位\": \"\\n<|grounding|>Find \\\"{ref_text}\\\". \",\n", + "}\n", + "\n", + "# ============================================================\n", + "# 模型加载(从模型文件中导入辅助函数)\n", + "# ============================================================\n", + "print(\"=\" * 60)\n", + "print(\"正在加载 DeepSeek-OCR 模型...\")\n", + "print(f\"模型: {MODEL_NAME}\")\n", + "print(\"=\" * 60)\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n", + "model = AutoModel.from_pretrained(\n", + " MODEL_NAME,\n", + " _attn_implementation=\"eager\",\n", + " trust_remote_code=True,\n", + " use_safetensors=True,\n", + " device_map=\"auto\",\n", + ")\n", + "model = model.eval()\n", + "\n", + "print(\"正在合并 MoE 权重 (combine_moe)...\")\n", + "model.combine_moe()\n", + "\n", + "# 修复 NPU 不支持 scatter_add 的问题:用 one_hot + 矩阵乘法替代\n", + "def _patched_forward_for_moe(self, hidden_states):\n", + " batch_size, sequence_length, hidden_dim = hidden_states.shape\n", + " selected_experts, routing_weights = self.gate(hidden_states)\n", + " n_experts = self.config.n_routed_experts\n", + " routing_weights = routing_weights.to(hidden_states.dtype)\n", + " # 用 one_hot 替代 scatter_add\n", + " one_hot = F.one_hot(selected_experts, n_experts).to(routing_weights.dtype)\n", + " router_scores = (one_hot * routing_weights.unsqueeze(-1)).sum(dim=1)\n", + " hidden_states = hidden_states.view(-1, hidden_dim)\n", + " if self.config.n_shared_experts is not None:\n", + " shared_expert_output = self.shared_experts(hidden_states)\n", + " hidden_w1 = torch.matmul(hidden_states, self.w1)\n", + " hidden_w3 = torch.matmul(hidden_states, self.w3)\n", + " hidden_states = self.act(hidden_w1) * hidden_w3\n", + " hidden_states = torch.bmm(hidden_states, self.w2) * torch.transpose(router_scores, 0, 1).unsqueeze(-1)\n", + " final_hidden_states = hidden_states.sum(dim=0, dtype=hidden_states.dtype)\n", + " if self.config.n_shared_experts is not None:\n", + " hidden_states = final_hidden_states + shared_expert_output\n", + " return hidden_states.view(batch_size, sequence_length, hidden_dim)\n", + "\n", + "# 对所有 MoE 层应用修复后的 forward\n", + "for layer in model.model.layers:\n", + " if hasattr(layer.mlp, 'w1'): # combine_moe 已处理的层\n", + " layer.mlp.forward = types.MethodType(_patched_forward_for_moe, layer.mlp)\n", + "\n", + "print(\"模型加载完成!\")\n", + "print(\"=\" * 60)\n", + "\n", + "# 从模型的 trust_remote_code 模块中获取辅助函数\n", + "# 这些函数通过 trust_remote_code=True 加载后可在模块中找到\n", + "_model_module = type(model).__module__\n", + "\n", + "_mod = importlib.import_module(_model_module)\n", + "format_messages = _mod.format_messages\n", + "load_pil_images = _mod.load_pil_images\n", + "text_encode = _mod.text_encode\n", + "BasicImageTransform = _mod.BasicImageTransform\n", + "dynamic_preprocess = _mod.dynamic_preprocess\n", + "re_match = _mod.re_match\n", + "process_image_with_refs = _mod.process_image_with_refs\n", + "\n", + "\n", + "# ============================================================\n", + "# 图像预处理(从 model.infer() 方法中抽取)\n", + "# ============================================================\n", + "def prepare_inputs(prompt_text: str, image_file: str, base_size: int, image_size: int, crop_mode: bool):\n", + " \"\"\"\n", + " 从 model.infer() 方法 (modeling_deepseekocr.py:732-937) 中抽取的图像预处理逻辑。\n", + " 构建 conversation -> format_messages -> 图像 token 化 -> 返回模型输入张量。\n", + " \"\"\"\n", + " # 1. 构建 conversation\n", + " conversation = [\n", + " {\n", + " \"role\": \"<|User|>\",\n", + " \"content\": prompt_text,\n", + " \"images\": [image_file],\n", + " },\n", + " {\"role\": \"<|Assistant|>\", \"content\": \"\"},\n", + " ]\n", + "\n", + " # 2. format_messages 转换 prompt\n", + " formatted_prompt = format_messages(conversations=conversation, sft_format=\"plain\", system_prompt=\"\")\n", + "\n", + " # 3. 加载图片\n", + " images = load_pil_images(conversation)\n", + " image_draw = images[0].copy()\n", + "\n", + " # 4. 图像 token 化\n", + " image_transform = BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)\n", + "\n", + " text_splits = formatted_prompt.split(IMAGE_TOKEN)\n", + "\n", + " images_list, images_crop_list, images_seq_mask = [], [], []\n", + " tokenized_str = []\n", + " images_spatial_crop = []\n", + "\n", + " for text_sep, image in zip(text_splits, images):\n", + " tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)\n", + " tokenized_str += tokenized_sep\n", + " images_seq_mask += [False] * len(tokenized_sep)\n", + "\n", + " if crop_mode:\n", + " if image.size[0] <= 640 and image.size[1] <= 640:\n", + " crop_ratio = [1, 1]\n", + " else:\n", + " images_crop_raw, crop_ratio = dynamic_preprocess(image)\n", + "\n", + " # 全局视图\n", + " global_view = ImageOps.pad(\n", + " image, (base_size, base_size),\n", + " color=tuple(int(x * 255) for x in image_transform.mean),\n", + " )\n", + " images_list.append(image_transform(global_view).to(model.dtype))\n", + "\n", + " width_crop_num, height_crop_num = crop_ratio\n", + " images_spatial_crop.append([width_crop_num, height_crop_num])\n", + "\n", + " if width_crop_num > 1 or height_crop_num > 1:\n", + " for i in range(len(images_crop_raw)):\n", + " images_crop_list.append(image_transform(images_crop_raw[i]).to(model.dtype))\n", + "\n", + " num_queries = math.ceil((image_size // PATCH_SIZE) / DOWNSAMPLE_RATIO)\n", + " num_queries_base = math.ceil((base_size // PATCH_SIZE) / DOWNSAMPLE_RATIO)\n", + "\n", + " # 图像 token 序列\n", + " tokenized_image = ([IMAGE_TOKEN_ID] * num_queries_base + [IMAGE_TOKEN_ID]) * num_queries_base\n", + " tokenized_image += [IMAGE_TOKEN_ID]\n", + " if width_crop_num > 1 or height_crop_num > 1:\n", + " tokenized_image += (\n", + " [IMAGE_TOKEN_ID] * (num_queries * width_crop_num) + [IMAGE_TOKEN_ID]\n", + " ) * (num_queries * height_crop_num)\n", + " tokenized_str += tokenized_image\n", + " images_seq_mask += [True] * len(tokenized_image)\n", + " else:\n", + " if image_size <= 640:\n", + " image = image.resize((image_size, image_size))\n", + " global_view = ImageOps.pad(\n", + " image, (image_size, image_size),\n", + " color=tuple(int(x * 255) for x in image_transform.mean),\n", + " )\n", + " images_list.append(image_transform(global_view).to(model.dtype))\n", + "\n", + " width_crop_num, height_crop_num = 1, 1\n", + " images_spatial_crop.append([width_crop_num, height_crop_num])\n", + "\n", + " num_queries = math.ceil((image_size // PATCH_SIZE) / DOWNSAMPLE_RATIO)\n", + "\n", + " tokenized_image = ([IMAGE_TOKEN_ID] * num_queries + [IMAGE_TOKEN_ID]) * num_queries\n", + " tokenized_image += [IMAGE_TOKEN_ID]\n", + " tokenized_str += tokenized_image\n", + " images_seq_mask += [True] * len(tokenized_image)\n", + "\n", + " # 最后一段文本\n", + " tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)\n", + " tokenized_str += tokenized_sep\n", + " images_seq_mask += [False] * len(tokenized_sep)\n", + "\n", + " # 添加 BOS token\n", + " tokenized_str = [BOS_ID] + tokenized_str\n", + " images_seq_mask = [False] + images_seq_mask\n", + "\n", + " # 转为张量\n", + " input_ids = torch.LongTensor(tokenized_str)\n", + " images_seq_mask_t = torch.tensor(images_seq_mask, dtype=torch.bool)\n", + "\n", + " if len(images_list) == 0:\n", + " images_ori = torch.zeros((1, 3, image_size, image_size))\n", + " images_spatial_crop_t = torch.zeros((1, 2), dtype=torch.long)\n", + " images_crop = torch.zeros((1, 3, base_size, base_size))\n", + " else:\n", + " images_ori = torch.stack(images_list, dim=0)\n", + " images_spatial_crop_t = torch.tensor(images_spatial_crop, dtype=torch.long)\n", + " if images_crop_list:\n", + " images_crop = torch.stack(images_crop_list, dim=0)\n", + " else:\n", + " images_crop = torch.zeros((1, 3, base_size, base_size))\n", + "\n", + " return {\n", + " \"input_ids\": input_ids.unsqueeze(0).cuda(),\n", + " \"images\": [(images_crop.cuda(), images_ori.cuda())],\n", + " \"images_seq_mask\": images_seq_mask_t.unsqueeze(0).cuda(),\n", + " \"images_spatial_crop\": images_spatial_crop_t,\n", + " \"image_draw\": image_draw,\n", + " }\n", + "\n", + "\n", + "# ============================================================\n", + "# 后处理:标注图生成\n", + "# ============================================================\n", + "def postprocess_output(raw_text: str, image_draw: Image.Image):\n", + " \"\"\"处理模型输出,生成带标注的图像。\"\"\"\n", + " if raw_text.endswith(STOP_STR):\n", + " raw_text = raw_text[: -len(STOP_STR)]\n", + " raw_text = raw_text.strip()\n", + "\n", + " matches_ref, matches_images, matches_other = re_match(raw_text)\n", + "\n", + " annotated_image = None\n", + " if matches_ref:\n", + " with tempfile.TemporaryDirectory() as tmp_dir:\n", + " os.makedirs(os.path.join(tmp_dir, \"images\"), exist_ok=True)\n", + " annotated_image = process_image_with_refs(image_draw, matches_ref, tmp_dir)\n", + "\n", + " # 无标注时返回原图\n", + " if annotated_image is None:\n", + " annotated_image = image_draw\n", + "\n", + " # 清理特殊标记,保留可读文本\n", + " # matches_ref 是元组列表: [(full_match, ref_text, det_coords), ...]\n", + " display_text = raw_text\n", + " for full_match, ref_text, det_coords in matches_ref:\n", + " if ref_text == \"image\":\n", + " display_text = display_text.replace(full_match, \"[图片区域]\")\n", + " else:\n", + " # 仅去除定位标签,保留引用文本内容\n", + " display_text = display_text.replace(full_match, ref_text)\n", + " display_text = display_text.replace(\"\\\\coloneqq\", \":=\").replace(\"\\\\eqqcolon\", \"=:\")\n", + "\n", + " return display_text, annotated_image\n", + "\n", + "\n", + "# ============================================================\n", + "# 流式推理 + 时间统计\n", + "# ============================================================\n", + "def format_metrics(ttft: Optional[float], token_count: int, t_start: float) -> str:\n", + " \"\"\"格式化性能指标。\"\"\"\n", + " elapsed = time.time() - t_start\n", + " lines = []\n", + " lines.append(f\"**首 Token 延迟 (TTFT)**: {ttft:.3f}s\" if ttft else \"**首 Token 延迟 (TTFT)**: 等待中...\")\n", + " lines.append(f\"**已生成 Token 数**: {token_count}\")\n", + " lines.append(f\"**总耗时**: {elapsed:.2f}s\")\n", + " if token_count > 0 and elapsed > 0:\n", + " tokens_per_sec = token_count / elapsed\n", + " lines.append(f\"**生成速度**: {tokens_per_sec:.2f} tokens/s\")\n", + " if token_count > 1 and ttft:\n", + " decode_time = elapsed - ttft\n", + " decode_speed = (token_count - 1) / decode_time if decode_time > 0 else 0\n", + " lines.append(f\"**解码速度** (不含首 token): {decode_speed:.2f} tokens/s\")\n", + " return \"\\n\\n\".join(lines)\n", + "\n", + "\n", + "def stream_ocr(image, resolution, task_type, ref_text):\n", + " \"\"\"\n", + " 流式 OCR 推理函数。\n", + " 使用 TextIteratorStreamer 实现流式 token 输出。\n", + " \"\"\"\n", + " if image is None:\n", + " yield \"请上传图片\", None, \"请先上传一张图片\"\n", + " return\n", + "\n", + " # 获取分辨率参数\n", + " preset = RESOLUTION_PRESETS[resolution]\n", + " base_size = preset[\"base_size\"]\n", + " image_size = preset[\"image_size\"]\n", + " crop_mode = preset[\"crop_mode\"]\n", + "\n", + " # 构建 prompt\n", + " prompt_template = TASK_PROMPTS[task_type]\n", + " if \"{ref_text}\" in prompt_template:\n", + " if not ref_text or not ref_text.strip():\n", + " yield \"请输入要定位的文本\", None, \"「文本定位」模式需要输入引用文本\"\n", + " return\n", + " prompt_text = prompt_template.format(ref_text=ref_text.strip())\n", + " else:\n", + " prompt_text = prompt_template\n", + "\n", + " # 保存临时图片文件供模型使用\n", + " with tempfile.NamedTemporaryFile(suffix=\".png\", delete=False) as tmp:\n", + " tmp_path = tmp.name\n", + " Image.fromarray(image).save(tmp_path)\n", + "\n", + " try:\n", + " # 1. 准备输入\n", + " model.disable_torch_init()\n", + " inputs = prepare_inputs(prompt_text, tmp_path, base_size, image_size, crop_mode)\n", + " image_draw = inputs.pop(\"image_draw\")\n", + "\n", + " # 2. 创建 streamer\n", + " streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)\n", + "\n", + " # 3. 后台线程运行 generate\n", + " generate_kwargs = dict(\n", + " input_ids=inputs[\"input_ids\"],\n", + " images=inputs[\"images\"],\n", + " images_seq_mask=inputs[\"images_seq_mask\"],\n", + " images_spatial_crop=inputs[\"images_spatial_crop\"],\n", + " temperature=0.0,\n", + " eos_token_id=tokenizer.eos_token_id,\n", + " streamer=streamer,\n", + " max_new_tokens=8192,\n", + " no_repeat_ngram_size=20,\n", + " use_cache=True,\n", + " )\n", + "\n", + " thread = Thread(target=_generate_with_no_grad, kwargs=generate_kwargs)\n", + "\n", + " # 4. 流式输出 + 时间统计\n", + " t_start = time.time()\n", + " thread.start()\n", + " first_token_time = None\n", + " token_count = 0\n", + " full_text = \"\"\n", + "\n", + " for new_text in streamer:\n", + " if first_token_time is None:\n", + " first_token_time = time.time() - t_start\n", + " token_count += 1\n", + " full_text += new_text\n", + " # 流式 yield:显示文本、暂无标注图、实时指标\n", + " display = full_text.replace(STOP_STR, \"\").strip()\n", + " yield display, None, format_metrics(first_token_time, token_count, t_start)\n", + "\n", + " thread.join()\n", + "\n", + " # 5. 最终后处理\n", + " display_text, annotated_image = postprocess_output(full_text, image_draw)\n", + " final_metrics = format_metrics(first_token_time, token_count, t_start)\n", + " yield display_text, annotated_image, final_metrics\n", + "\n", + " finally:\n", + " os.unlink(tmp_path)\n", + "\n", + "\n", + "def _generate_with_no_grad(**kwargs):\n", + " \"\"\"在 no_grad 上下文中运行 model.generate。\"\"\"\n", + " with torch.no_grad():\n", + " model.generate(**kwargs)\n", + "\n", + "\n", + "# ============================================================\n", + "# Gradio UI\n", + "# ============================================================\n", + "def _require_gradio():\n", + " \"\"\"返回 gradio 模块;未安装时给出明确错误。\"\"\"\n", + " if gr is None:\n", + " raise RuntimeError(\"gradio is required to launch the UI. Install gradio first.\")\n", + " return gr\n", + "\n", + "\n", + "def toggle_ref_text(task_type):\n", + " \"\"\"根据任务类型切换引用文本输入框可见性。\"\"\"\n", + " return _require_gradio().update(visible=(task_type == \"文本定位\"))\n", + "\n", + "\n", + "DESCRIPTION = \"\"\"\n", + "# DeepSeek-OCR MindSpore DEMO\n", + "\n", + "基于 **MindSpore 2.7.0 + MindNLP 0.5.1** 的文本识别与结构化解析交互式演示。\n", + "\n", + "**模型**: DeepSeek-OCR | **硬件**: Ascend NPU 910B | **优化**: MoE 权重合并 + KV Cache\n", + "\n", + "### 性能优化说明\n", + "| 优化项 | 说明 |\n", + "|--------|------|\n", + "| `combine_moe()` | 合并 MoE 专家权重,减少内存访问开销 |\n", + "| `scatter_add` 适配 | 用 `one_hot` + 矩阵乘法替代 NPU 不支持的 `scatter_add` |\n", + "| `use_cache=True` | 启用 KV Cache,避免重复计算注意力 |\n", + "| `no_repeat_ngram_size=20` | 控制重复生成,提升有效 token 效率 |\n", + "| `eager` attention | Ascend NPU 上兼容性最佳的注意力实现 |\n", + "| `float32` 精度 | 保证 OCR 输出质量(float16 存在精度退化)|\n", + "\n", + "### 优化前后对比(Gundam 模式,Ascend 910B,256 tokens)\n", + "| 配置 | TTFT | 生成速度 | 解码速度 | 加速比 |\n", + "|------|------|----------|----------|--------|\n", + "| **全部优化** | 9.757s | 7.95 tok/s | **11.34 tok/s** | **基线** |\n", + "| 关闭 MoE 合并 | 10.805s | 1.68 tok/s | 2.29 tok/s | **4.95x 慢** |\n", + "\n", + "### 不同分辨率模式对比(256 tokens)\n", + "| 模式 | TTFT | 生成速度 | 解码速度 | 适用场景 |\n", + "|------|------|----------|----------|----------|\n", + "| Tiny (512) | **0.214s** | **11.00 tok/s** | 11.06 tok/s | 快速预览 |\n", + "| Small (640) | 0.257s | 10.76 tok/s | 10.83 tok/s | 一般文档 |\n", + "| **Gundam (推荐)** | 9.757s | 7.95 tok/s | 11.34 tok/s | **精度最佳** |\n", + "\"\"\"\n", + "\n", + "demo = None\n", + "if gr is not None:\n", + " with gr.Blocks(title=\"DeepSeek-OCR MindSpore DEMO\") as demo:\n", + " gr.Markdown(DESCRIPTION)\n", + "\n", + " with gr.Row():\n", + " # 左侧:输入区\n", + " with gr.Column(scale=1):\n", + " input_image = gr.Image(label=\"上传图片\", type=\"numpy\", height=400)\n", + "\n", + " resolution = gr.Dropdown(\n", + " choices=list(RESOLUTION_PRESETS.keys()),\n", + " value=\"Gundam (推荐)\",\n", + " label=\"分辨率模式\",\n", + " info=\"Gundam 模式在精度和速度之间取得最佳平衡\",\n", + " )\n", + "\n", + " task_type = gr.Dropdown(\n", + " choices=list(TASK_PROMPTS.keys()),\n", + " value=\"Free OCR\",\n", + " label=\"任务类型\",\n", + " )\n", + "\n", + " ref_text_input = gr.Textbox(\n", + " label=\"引用文本(仅「文本定位」模式)\",\n", + " placeholder=\"输入要定位的文本...\",\n", + " visible=False,\n", + " )\n", + "\n", + " run_btn = gr.Button(\"开始识别\", variant=\"primary\", size=\"lg\")\n", + "\n", + " # 右侧:输出区\n", + " with gr.Column(scale=1):\n", + " output_text = gr.Textbox(\n", + " label=\"OCR 识别结果\",\n", + " lines=15,\n", + " max_lines=30,\n", + " )\n", + " output_image = gr.Image(label=\"标注结果图\", height=300)\n", + " metrics_display = gr.Markdown(label=\"性能统计\", value=\"等待推理...\")\n", + "\n", + " # 事件绑定\n", + " task_type.change(fn=toggle_ref_text, inputs=task_type, outputs=ref_text_input)\n", + "\n", + " run_btn.click(\n", + " fn=stream_ocr,\n", + " inputs=[input_image, resolution, task_type, ref_text_input],\n", + " outputs=[output_text, output_image, metrics_display],\n", + " )\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " if demo is None:\n", + " _require_gradio()\n", + " demo.queue()\n", + " demo.launch(\n", + " server_name=\"0.0.0.0\",\n", + " server_port=7860,\n", + " share=False,\n", + " )\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 相关链接\n", + "\n", + "- **魔乐社区 (ModelScope)**: [DeepSeek-OCR 模型页](https://modelers.cn/)\n", + "- **HuggingFace 模型**: [lvyufeng/DeepSeek-OCR](https://huggingface.co/lvyufeng/DeepSeek-OCR)\n", + "- **MindNLP 项目**: [GitHub - mindnlp](https://github.com/mindspore-lab/mindnlp)\n", + "- **MindSpore 官网**: [mindspore.cn](https://www.mindspore.cn/)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}