Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
From 4fffb31ea6073e689089caa02e0a17cf4b6f73af Mon Sep 17 00:00:00 2001
From: Alice <[email protected]>
Date: Tue, 4 Nov 2025 22:41:54 +0800
Subject: [PATCH 1/3] =?UTF-8?q?=E5=A4=9A=E6=A8=A1=E6=80=81=E8=B0=83?=
=?UTF-8?q?=E4=BC=98=E6=8F=90=E4=BA=A4?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
.../janus_pro/janus/models/processing_vlm.py | 28 ++++++++++++++++---
.../models/qwen2_vl/modeling_qwen2_vl.py | 14 ++++++----
2 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/llm/inference/janus_pro/janus/models/processing_vlm.py b/llm/inference/janus_pro/janus/models/processing_vlm.py
index 7f881e39..c93bfd92 100644
--- a/llm/inference/janus_pro/janus/models/processing_vlm.py
+++ b/llm/inference/janus_pro/janus/models/processing_vlm.py
@@ -110,9 +110,13 @@ class VLChatProcessor(ProcessorMixin):
print(f"Add image tag = {image_tag} to the tokenizer")

self.image_tag = image_tag
+ self.image_tag_id = None
self.image_start_tag = image_start_tag
+ self.image_start_tag_id = None
self.image_end_tag = image_end_tag
+ self.image_end_tag_id = None
self.pad_tag = pad_tag
+ self.pad_tag_id = None

self.num_image_tokens = num_image_tokens
self.add_special_token = add_special_token
@@ -185,17 +189,29 @@ class VLChatProcessor(ProcessorMixin):

@property
def image_id(self):
- image_id = self.tokenizer.vocab.get(self.image_tag)
+ if self.image_tag_id is None:
+ image_id = self.tokenizer.vocab.get(self.image_tag)
+ self.image_tag_id = image_id
+ else:
+ image_id = self.image_tag_id
return image_id

@property
def image_start_id(self):
- image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
+ if self.image_start_tag_id is None:
+ image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
+ self.image_start_tag_id = image_start_id
+ else:
+ image_start_id = self.image_start_tag_id
return image_start_id

@property
def image_end_id(self):
- image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
+ if self.image_end_tag_id is None:
+ image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
+ self.image_end_tag_id = image_end_id
+ else:
+ image_end_id = self.image_end_tag_id
return image_end_id

@property
@@ -208,7 +224,11 @@ class VLChatProcessor(ProcessorMixin):

@property
def pad_id(self):
- pad_id = self.tokenizer.vocab.get(self.pad_tag)
+ if self.pad_tag_id is None:
+ pad_id = self.tokenizer.vocab.get(self.pad_tag)
+ self.pad_tag_id = pad_id
+ else:
+ pad_id = self.pad_tag_id
# pad_id = self.tokenizer.pad_token_id
# if pad_id is None:
# pad_id = self.tokenizer.eos_token_id
diff --git a/mindnlp/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/mindnlp/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index d059dcbe..ffb100cf 100644
--- a/mindnlp/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/mindnlp/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -265,6 +265,7 @@ class PatchEmbed(nn.Module):
self.embed_dim = embed_dim

kernel_size = (temporal_patch_size, patch_size, patch_size)
+ self.kernel_size = kernel_size
self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)

def forward(self, hidden_states: mindspore.Tensor) -> mindspore.Tensor:
@@ -272,7 +273,10 @@ class PatchEmbed(nn.Module):
hidden_states = hidden_states.view(
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
)
- hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
+
+ hidden_states = mindspore.mint.nn.functional.conv3d(hidden_states.to(dtype=target_dtype), self.proj.weight,
+ stride=self.kernel_size).view(-1, self.embed_dim)
+
return hidden_states


@@ -330,7 +334,7 @@ class VisionAttention(nn.Module):
v = v.swapaxes(0, 1)
attn_weights = ops.matmul(q, k.swapaxes(1, 2)) / math.sqrt(self.head_dim)
attn_weights = attn_weights + attention_mask
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=mindspore.float32).to(q.dtype)
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_output = ops.matmul(attn_weights, v)
attn_output = attn_output.swapaxes(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
@@ -559,7 +563,7 @@ class Qwen2VLAttention(nn.Module):
attn_weights = attn_weights + causal_mask

# upcast attention to fp32
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=mindspore.float32).to(query_states.dtype)
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = ops.matmul(attn_weights, value_states)

@@ -637,7 +641,7 @@ class Qwen2VLDecoderLayer(nn.Module):

residual = hidden_states

- hidden_states = self.input_layernorm(hidden_states)
+ hidden_states, _ = mindspore.ops.rms_norm(hidden_states, self.input_layernorm.weight, self.input_layernorm.variance_epsilon)

# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
@@ -654,7 +658,7 @@ class Qwen2VLDecoderLayer(nn.Module):

# Fully Connected
residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states, _ = mindspore.ops.rms_norm(hidden_states, self.post_attention_layernorm.weight, self.post_attention_layernorm.variance_epsilon)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

--
2.47.1.windows.2

Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
From a9b4c9b85f237fcae52ea396cf7be606cdf16410 Mon Sep 17 00:00:00 2001
From: Alice <[email protected]>
Date: Thu, 6 Nov 2025 16:39:55 +0800
Subject: [PATCH 2/3] =?UTF-8?q?=E7=BB=A7=E7=BB=AD=E4=BC=98=E5=8C=96?=
=?UTF-8?q?=EF=BC=8C=E5=8A=A0=E5=85=A5rmsnorm,jit?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
.../janus_pro/janus/models/clip_encoder.py | 6 +++--
.../janus_pro/janus/models/modeling_vlm.py | 27 ++++++++-----------
.../janus_pro/janus/models/siglip_vit.py | 11 +++++---
.../janus_pro/janus/models/timm_layers.py | 2 +-
.../models/llama/modeling_llama.py | 6 ++---
5 files changed, 27 insertions(+), 25 deletions(-)

diff --git a/llm/inference/janus_pro/janus/models/clip_encoder.py b/llm/inference/janus_pro/janus/models/clip_encoder.py
index a0620cfe..acb6ea3f 100644
--- a/llm/inference/janus_pro/janus/models/clip_encoder.py
+++ b/llm/inference/janus_pro/janus/models/clip_encoder.py
@@ -56,6 +56,7 @@ class CLIPVisionTower(nn.Module):
self.vision_tower, self.forward_kwargs = self.build_vision_tower(
vision_tower_params
)
+ self.vision_tower.jit()

if pixel_mean is not None and pixel_std is not None:
image_norm = Normalize(
@@ -112,10 +113,11 @@ class CLIPVisionTower(nn.Module):
Returns:
image_features (torch.Tensor): [b, n_patch, d]
"""
-
+
if self.image_norm is not None:
images = self.image_norm(images)
-
+
image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
+
image_features = self.feature_select(image_forward_outs)
return image_features
diff --git a/llm/inference/janus_pro/janus/models/modeling_vlm.py b/llm/inference/janus_pro/janus/models/modeling_vlm.py
index 7178c398..3d2d2d74 100644
--- a/llm/inference/janus_pro/janus/models/modeling_vlm.py
+++ b/llm/inference/janus_pro/janus/models/modeling_vlm.py
@@ -241,12 +241,16 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
Returns:
input_embeds (torch.Tensor): [b, T, D]
"""
-
+
+
bs, n = pixel_values.shape[0:2]
# "b n c h w -> (b n) c h w"
images = ops.reshape(
pixel_values, (bs * n, pixel_values.shape[2], pixel_values.shape[3], pixel_values.shape[4]))
- images_embeds = self.aligner(self.vision_model(images))
+
+ vr = self.vision_model(images)
+
+ images_embeds = self.aligner(vr)

# "(b n) t d -> b (n t) d"
images_embeds = ops.reshape(
@@ -259,33 +263,24 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
condition = input_ids < 0
input_ids = (1-condition) * input_ids + condition * \
0 # ignore the image embeddings
+
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+

# replace with the image embeddings
# 627 576
# inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
- print("inputs_embeds:", inputs_embeds.shape)
- print("images_embeds[images_emb_mask].dtype", images_embeds[images_emb_mask].dtype)
- print("inputs_embeds.dtype", inputs_embeds.dtype)
+
+
padding_size = images_seq_mask.shape[1] - images_emb_mask.shape[1]
padding = Tensor(np.full((images_seq_mask.shape[0], padding_size), False), dtype=images_emb_mask.dtype)
padded_images_emb_mask = ops.concat((images_emb_mask, padding), dim=1)
- print("padded_images_emb_mask.shape:",padded_images_emb_mask.shape)
- print("images_embeds.shape:",images_embeds.shape)
- print("images_seq_mask.shape:",images_seq_mask.shape)
first_true = images_seq_mask.nonzero().squeeze()[0][1] # 42
last_true = images_seq_mask.nonzero().squeeze()[-1][1] # 42
- print("first_true:",first_true)
- print("last_true:",last_true)
left = inputs_embeds[:,:first_true]
- print(left.shape)
right = inputs_embeds[:, last_true+1:]
- print(right.shape)
inputs_embeds = ops.cat((left, images_embeds, right),1)
- print("inputs_embeds.shape:",inputs_embeds.shape)
- print("inputs_embeds.dtype:",inputs_embeds.dtype)
-
-
+

# inputs_embeds = images_embeds[padded_images_emb_mask] * images_seq_mask + inputs_embeds * (1 - images_seq_mask)
return inputs_embeds
diff --git a/llm/inference/janus_pro/janus/models/siglip_vit.py b/llm/inference/janus_pro/janus/models/siglip_vit.py
index 56a6f299..d896eeb2 100644
--- a/llm/inference/janus_pro/janus/models/siglip_vit.py
+++ b/llm/inference/janus_pro/janus/models/siglip_vit.py
@@ -580,7 +580,11 @@ class VisionTransformer(nn.Module):
if return_prefix_tokens:
return tuple(zip(outputs, prefix_tokens))
return tuple(outputs)
-
+
+ @mindspore.jit(backend='GE')
+ def run_blocks_jit(self, x: mindspore.Tensor) -> mindspore.Tensor:
+ return self.blocks(x)
+
def forward_features(self, x: mindspore.Tensor) -> mindspore.Tensor:
x = self.patch_embed(x)
x = self._pos_embed(x)
@@ -590,10 +594,11 @@ class VisionTransformer(nn.Module):
# x = checkpoint_seq(self.blocks, x)
# else:
# x = self.blocks(x)
- x = self.blocks(x)
+ x = self.run_blocks_jit(x)
x = self.norm(x)
return x
-
+
+ @mindspore.jit(backend='GE')
def forward_head(self, x: mindspore.Tensor, pre_logits: bool = False) -> mindspore.Tensor:
if self.attn_pool is not None:
x = self.attn_pool(x)
diff --git a/llm/inference/janus_pro/janus/models/timm_layers.py b/llm/inference/janus_pro/janus/models/timm_layers.py
index 8960d256..61e68c54 100644
--- a/llm/inference/janus_pro/janus/models/timm_layers.py
+++ b/llm/inference/janus_pro/janus/models/timm_layers.py
@@ -46,7 +46,7 @@ class Mlp(nn.Module):

def forward(self, x):
x = self.fc1(x)
- x = self.act(x)
+ x = mindspore.ops.gelu(x)
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x)
diff --git a/mindnlp/transformers/models/llama/modeling_llama.py b/mindnlp/transformers/models/llama/modeling_llama.py
index 9c5cb555..c8c55492 100644
--- a/mindnlp/transformers/models/llama/modeling_llama.py
+++ b/mindnlp/transformers/models/llama/modeling_llama.py
@@ -429,7 +429,7 @@ class LlamaAttention(nn.Module):
attn_weights = attn_weights + causal_mask

# upcast attention to fp32
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=mindspore.float32).to(query_states.dtype)
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = ops.matmul(attn_weights, value_states)

@@ -508,7 +508,7 @@ class LlamaDecoderLayer(nn.Module):
"""
residual = hidden_states

- hidden_states = self.input_layernorm(hidden_states)
+ hidden_states, _ = mindspore.ops.rms_norm(hidden_states, self.input_layernorm.weight, self.input_layernorm.variance_epsilon)

# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
@@ -526,7 +526,7 @@ class LlamaDecoderLayer(nn.Module):

# Fully Connected
residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states, _ = mindspore.ops.rms_norm(hidden_states, self.post_attention_layernorm.weight, self.post_attention_layernorm.variance_epsilon)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

--
2.47.1.windows.2

Loading