Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
curioyang committed Mar 4, 2025
1 parent 353429a commit 074b509
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 102 deletions.
14 changes: 7 additions & 7 deletions src/Nncase.Importer/HuggingFace/HuggingFaceUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ public static Tuple<List<double>, float> ComputeDefaultRopeParameters(Dictionary
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
*/
var baseRoPETheta = (float)config["rope_theta"];
var baseRoPETheta = (float)(double)config["rope_theta"];
var partialRotaryFactor =
1.0; // config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0

Expand All @@ -215,8 +215,8 @@ public static Tuple<List<double>, float> ComputeDefaultRopeParameters(Dictionary
}
else
{
int hiddenSize = (int)config["hidden_size"];
int numAttentionHeads = (int)config["num_attention_heads"];
int hiddenSize = (int)(long)config["hidden_size"];
int numAttentionHeads = (int)(long)config["num_attention_heads"];
headDim = hiddenSize / numAttentionHeads;
}

Expand All @@ -235,13 +235,13 @@ public static Tuple<List<double>, float> ComputeDefaultRopeParameters(Dictionary
public class DynamicCache
{
public int seenTokens = 0;
public List<object> keyCache;
public List<object> ValueCache;
public List<object>? keyCache;
public List<object>? ValueCache;

public int GetSeqLength(int layerCount = 0)
{
bool isEmptyLayer = keyCache.Count == 0 || keyCache.Count <= layerCount || (int)keyCache[layerCount] == 0;
var layer = (Call)keyCache[(Index)layerCount!];
bool isEmptyLayer = keyCache?.Count == 0 || keyCache?.Count <= layerCount || (int)keyCache?[layerCount] == 0;
var layer = (Call)keyCache?[(Index)layerCount!];
return isEmptyLayer ? 0 : layer.CheckedShape[-2].FixedValue;
}

Expand Down
215 changes: 120 additions & 95 deletions src/Nncase.Importer/HuggingFace/Qwen2.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public partial class HuggingFaceImporter
{
protected (IEnumerable<Var> Inputs, Dictionary<Var, Expr[]> VarMap) Qwen2CreateInputs()
{
long hiddenSize = (long)_config!["hidden_size"];
_inputs = new List<Var>();
_dynVarMap = new Dictionary<string, Var>();
var varMap = new Dictionary<Var, Expr[]>();
Expand All @@ -45,15 +46,16 @@ public partial class HuggingFaceImporter
if (!_fixVarMap.ContainsKey("history_len"))
_dynVarMap["history_len"] = new Var("history_len", new TensorType(DataTypes.Int32, Shape.Scalar));

var inputIdsShapeExpr = new Expr[] { (Expr)_dynVarMap["sequence_length"], (Expr)1, (Expr)_config!["hidden_size"] };
var inputIdsShapeExpr = new Expr[] { _dynVarMap["sequence_length"], 1, hiddenSize };
var attentionMaskShapeExpr =
new Expr[] { 1, 1, _dynVarMap["sequence_length"], _dynVarMap["sequence_length"] };
var positionIdsShapeExpr = new Expr[] { 1, _dynVarMap["sequence_length"] };
var pastKeyValueShapeExpr = new Expr[] { 24, 2, 1, _dynVarMap["history_len"], 2, 64 };


var inputIds = new Var("input_ids",
new TensorType(DataTypes.Float32, new Shape(Dimension.Unknown, 1, (int)_config!["hidden_size"])));
new TensorType(DataTypes.Int32, new Shape(Dimension.Unknown, 1, (int)hiddenSize)));

var attentionMask = new Var("attention_mask",
new TensorType(DataTypes.Float32, new Shape(1, 1, Dimension.Unknown, Dimension.Unknown)));
var positionIds = new Var("position_ids",
Expand All @@ -72,7 +74,8 @@ public partial class HuggingFaceImporter
return (_inputs, varMap);
}

private Tuple<Call, HuggingFaceUtils.DynamicCache> VisitQwen2ForCausalLM()
// private Tuple<Call, HuggingFaceUtils.DynamicCache> VisitQwen2ForCausalLM()
private Tuple<Call, Call> VisitQwen2ForCausalLM()
{
if (_constTensors == null)
{
Expand Down Expand Up @@ -117,71 +120,83 @@ public partial class HuggingFaceImporter
*/


var input_ids = new Var();
var (lastHiddenStates, pastKeyValues, allSelfAttns, allHiddenStates) = Qwen2Model(input_ids,
inputEmbeds: null, new HuggingFaceUtils.DynamicCache(), cachePosition: null, positionIds: null,
useCache: false, outputAttentions: false, outputHiddenStates: false);
var input_ids = _inputs[0];
var attention_mask = _inputs[1];
var position_ids = _inputs[2];
var pastKeyValues = _inputs[3];
// var (lastHiddenStates, pastKeyValues, allSelfAttns, allHiddenStates) = Qwen2Model(input_ids,
// inputEmbeds: null, new HuggingFaceUtils.DynamicCache(), cachePosition: null, positionIds: null,
// useCache: false, outputAttentions: false, outputHiddenStates: false);
var (lastHiddenStates, allSelfAttnsKV) = Qwen2Model(input_ids,
attention_mask, position_ids, pastKeyValues);
var lmHead = F.Math.MatMul(lastHiddenStates,
F.Tensors.Transpose(_constTensors["model.embed_tokens.weight"], new Dimension[] { 1, 0 }));
return Tuple.Create(lmHead, pastKeyValues);
var attentionKVCache = Concat(allSelfAttnsKV.ToArray(), 0);
return Tuple.Create(lmHead, attentionKVCache);
}

private Tuple<Call, HuggingFaceUtils.DynamicCache, List<Call>, List<Call>> Qwen2Model(
Expr input_ids,
Call? inputEmbeds,
HuggingFaceUtils.DynamicCache? pastKeyValues,
Expr? cachePosition,
Call? positionIds,
bool? useCache = false,
bool? outputAttentions = false,
bool? outputHiddenStates = false
// private Tuple<Call, HuggingFaceUtils.DynamicCache, List<Call>, List<Call>> Qwen2Model(
// Expr input_ids,
// Call? inputEmbeds,
// HuggingFaceUtils.DynamicCache? pastKeyValues,
// Expr? cachePosition,
// Call? positionIds,
// bool? useCache = false,
// bool? outputAttentions = false,
// bool? outputHiddenStates = false
// )
private Tuple<Call, List<Call>> Qwen2Model(
Var input_ids,
Var attentionMask,
Var positionIds,
Var pastKeyValues
)
{
/*
* 1.1 embedding
*/
if (inputEmbeds == null)
// if (inputEmbeds == null)
// {
var embedTokensWeight = _constTensors["model.embed_tokens.weight"];
if (_config!.Keys.Contains("pad_token_id"))
{
var embedTokensWeight = _constTensors["model.embed_tokens.weight"];
if (_config!.Keys.Contains("pad_token_id"))
// embedTokensWeight[(int)_config["pad_token_ids"]] = new float[embedTokensWeight.Shape[-1].FixedValue];
for (int i = 0; i < embedTokensWeight.Shape[-1].FixedValue; i++)
{
// embedTokensWeight[(int)_config["pad_token_ids"]] = new float[embedTokensWeight.Shape[-1].FixedValue];
for (int i = 0; i < embedTokensWeight.Shape[-1].FixedValue; i++)
{
embedTokensWeight[(int)_config["pad_token_id"], (int)i] = 0;
}
embedTokensWeight[(int)_config["pad_token_id"], (int)i] = 0;
}

inputEmbeds = Gather(embedTokensWeight, 0, input_ids);
}

if (useCache == true && pastKeyValues == null)
{
pastKeyValues = new HuggingFaceUtils.DynamicCache();
}
var inputEmbeds = Gather(embedTokensWeight, 0, input_ids);
// }

if (cachePosition == null)
{
if (pastKeyValues != null)
{
var pastSeenTokens = pastKeyValues.GetSeqLength();
int sequenceLength =
inputEmbeds.CheckedShape[1].FixedValue; // 假设 inputEmbeds 的第二个维度是 sequenceLength
var cachePositionList = Enumerable.Range(pastSeenTokens, pastSeenTokens + sequenceLength).ToArray();
cachePosition = Tensor.FromArray(cachePositionList);
}
}
// if (useCache == true && pastKeyValues == null)
// {
// pastKeyValues = new HuggingFaceUtils.DynamicCache();
// }

if (positionIds == null)
{
positionIds = Unsqueeze(cachePosition, 0);
}
// if (cachePosition == null)
// {
// if (pastKeyValues != null)
// {
// var pastSeenTokens = pastKeyValues.GetSeqLength();
// int sequenceLength =
// inputEmbeds.CheckedShape[1].FixedValue; // 假设 inputEmbeds 的第二个维度是 sequenceLength
// var cachePositionList = Enumerable.Range(pastSeenTokens, pastSeenTokens + sequenceLength).ToArray();
// cachePosition = Tensor.FromArray(cachePositionList);
// }
// }
//
// if (positionIds == null)
// {
// positionIds = Unsqueeze(cachePosition, 0);
// }

// TODO : _update_causal_mask
// causal_mask = self._update_causal_mask(
// attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
// )
Call? causalMask = null;
// Call? causalMask = null;

var hiddenStates = inputEmbeds;
var positionEmbeddings = this.RotaryEmbedding(hiddenStates, positionIds);
Expand All @@ -194,43 +209,48 @@ public partial class HuggingFaceImporter
*
*/
var decodeLayer = new List<Tuple<Call, Call>>();
for (int i = 0; i < (int)_config["num_hidden_layers"]; i++)
for (int i = 0; i < (int)(long)_config["num_hidden_layers"]; i++)
{
if (outputAttentions == true)
{
allHiddenStates.Add(hiddenStates);
}
// if (outputAttentions == true)
// {
allHiddenStates.Add(hiddenStates);
// }

var (hiddenStatesTmp, selfAttenWeights) = DecodeLayer(i, hiddenStates, causalMask, positionIds,
pastKeyValues, outputAttentions,
useCache, cachePosition, positionEmbeddings);
// var (hiddenStatesTmp, selfAttenWeights) = DecodeLayer(i, hiddenStates, causalMask, positionIds,
// pastKeyValues, outputAttentions,
// useCache, cachePosition, positionEmbeddings);
var (hiddenStatesTmp, selfAttenKV) =
DecodeLayer(i, hiddenStates, attentionMask, pastKeyValues, positionEmbeddings);

hiddenStates = hiddenStatesTmp;

if (outputAttentions == true)
{
allSelfAttns.Add(selfAttenWeights);
}
// if (outputAttentions == true)
// {
allSelfAttns.Add(selfAttenKV);
// }
}

var lastHiddenStates = Qwen2LayerNorm(hiddenStates, "model.norm.weight");
if (outputAttentions == true)
{
allHiddenStates.Add(lastHiddenStates);
}
// if (outputAttentions == true)
// {
// allHiddenStates.Add(lastHiddenStates);
// }

return Tuple.Create(lastHiddenStates, pastKeyValues, allHiddenStates, allSelfAttns);
// return Tuple.Create(lastHiddenStates, pastKeyValues, allHiddenStates, allSelfAttns);
return Tuple.Create(lastHiddenStates, allSelfAttns);
}

private Tuple<Call, Call> DecodeLayer(int count, Call hiddenStates, Call? attentionMask, Call positionIds,
HuggingFaceUtils.DynamicCache pastKeyValues, bool? outputAttentions, bool? useCache, Expr cachePosition,
// private Tuple<Call, Call> DecodeLayer(int count, Call hiddenStates, Call? attentionMask, Call positionIds,
// HuggingFaceUtils.DynamicCache pastKeyValues, bool? outputAttentions, bool? useCache, Expr cachePosition,
// Tuple<Call, Call> positionEmbeddings)
private Tuple<Call, Call> DecodeLayer(int count, Call hiddenStates, Var? attentionMask, Var pastKeyValues,
Tuple<Call, Call> positionEmbeddings)
{
var residual = hiddenStates;
hiddenStates = Qwen2LayerNorm(hiddenStates, $"model.layers.{count}.input_layernorm.weight");

// self attention
var (hiddenStatesTmp, selfAttenWeights) =
var (hiddenStatesTmp, selfAttenKV) =
Qwen2SelfAtten(count, hiddenStates, attentionMask, positionEmbeddings);
hiddenStates = hiddenStatesTmp;
hiddenStates = residual + hiddenStates;
Expand All @@ -242,12 +262,12 @@ private Tuple<Call, Call> DecodeLayer(int count, Call hiddenStates, Call? attent
hiddenStates = residual + hiddenStates;

var output = hiddenStates;
if (outputAttentions == true && selfAttenWeights is not null)
{
return Tuple.Create<Call, Call>(output, selfAttenWeights);
}
// if (outputAttentions == true && selfAttenKV is not null)
// {
return Tuple.Create<Call, Call>(output, selfAttenKV);
// }

return Tuple.Create<Call, Call>(output, null);
// return Tuple.Create<Call, Call>(output, null);
}

private Call Qwen2Mlp(int count, Call hiddenStates)
Expand All @@ -272,21 +292,22 @@ private Call Qwen2LayerNorm(Call hiddenStates, string layerName)

// Qwen2Attention : SelfAtten
// llama config find in : https://www.restack.io/p/transformer-models-answer-llama-config-json-cat-ai
private Tuple<Call, Call> Qwen2SelfAtten(int count, Call hiddenStates, Call attentionMask,
private Tuple<Call, Call> Qwen2SelfAtten(int count, Call hiddenStates, Var attentionMask,
Tuple<Call, Call> positionEmbeddings)
{
int hidden_dim = (int)_config!["hidden_size"] / (int)_config["num_attention_heads"];
var hidden_dim = (int)(long)_config!["hidden_size"] / (int)(long)_config["num_attention_heads"];
if (_config!.Keys.Contains("head_dim"))
{
hidden_dim = (int)_config["head_dim"];
hidden_dim = (int)(long)_config["head_dim"];
}

// bak: 1/21 16:42 dongliang: for 循环提取dims,然后拼起来,以防动态var失效.
var hidden_shape = hiddenStates.CheckedShape.ToList();
hidden_shape.RemoveAt(hidden_shape.Count - 1);
hidden_shape.Add(new Dimension(-1));
var hidden_shape = ShapeOf(hiddenStates);
// hidden_shape.RemoveAt(hidden_shape.Count - 1);
// hidden_shape.Add(-1);
var inputShape = hidden_shape; // inputShape is hiddenStates.shape[:-1].Add(-1)
hidden_shape.Add(new Dimension(hidden_dim));
// hidden_shape.Add(new Dimension(hidden_dim));
// hidden_shape = Concat(new IR.Tuple(hidden_shape, Tensor.FromScalar(hidden_dim)), 0);

var qProjW = _constTensors![$"model.layers.{count}.self_attn.q_proj.weight"];
var qProjB = _constTensors![$"model.layers.{count}.self_attn.q_proj.bias"];
Expand Down Expand Up @@ -318,15 +339,16 @@ private Tuple<Call, Call> Qwen2SelfAtten(int count, Call hiddenStates, Call atte
attentionMask, /*TODO: 这里可能不需要,如果使用输入*/ 0.0f, false);
hiddenStates = hiddenStatesTmp;

inputShape.Add(-1);
hiddenStates = IR.F.Tensors.Reshape(hiddenStates, new Shape(inputShape));
// inputShape.Add(-1);
inputShape = Concat(new IR.Tuple(inputShape, Tensor.FromScalar(-1)), 0);
hiddenStates = IR.F.Tensors.Reshape(hiddenStates, inputShape);
var oProjW = _constTensors![$"model.layers.{count}.self_attn.o_proj.weight"];
hiddenStates = F.Math.MatMul(hiddenStates, oProjW);

return Tuple.Create(hiddenStates, selfAttenWeight);
}

private Tuple<Call, Call> SdpaAttention(Call queryStates, Call keyStates, Call valueStates, Call? attentionMask,
private Tuple<Call, Call> SdpaAttention(Call queryStates, Call keyStates, Call valueStates, Expr? attentionMask,
float? scaling, bool? isCausal)
{
/*
Expand Down Expand Up @@ -408,22 +430,25 @@ private Tuple<Call, Call> ApplyRotaryPosEmb(Call q, Call k, Call cos, Call sin)
return Tuple.Create(qEmbed, kEmbed);
}

private Tuple<Call, Call> RotaryEmbedding(Call x, Call positionIds)
private Tuple<Call, Call> RotaryEmbedding(Expr x, Expr positionIds)
{
// rope type not in config, so it is default. :_compute_default_rope_parameters
// if "dynamic" in self.rope_type:
// self._dynamic_frequency_update(position_ids, device=x.device)

var (inv_freq, attentionScaling) = RoPEInit("default");
var a = x.CheckedShape[0];
var invFreqExpanded = Broadcast(
Unsqueeze(
Unsqueeze(Tensor.FromArray(inv_freq.ToArray()), 0), -1),
new Dimension[] { x.CheckedShape[0], inv_freq.Count, 1 });
var positionIdsExpanded = Unsqueeze(positionIds, 1);

var freqs = F.Tensors.Transpose(F.Math.MatMul(invFreqExpanded, positionIdsExpanded),
new Dimension[] { 0, 2, 1 });
// var a = x.CheckedShape[0];

var invFreqExpanded =
Tensor.FromArray(inv_freq
.ToArray()); //Unsqueeze(Unsqueeze(Tensor.FromArray(inv_freq.ToArray()), new[] { 0 }),new[] { -1 });
// var invFreqExpanded = Broadcast(
// inv_freq_tensor,
// new Dimension[] { x.CheckedShape[0], inv_freq.Count, 1 });
var positionIdsExpanded = Unsqueeze(Reshape(positionIds, new[] { -1, 1 }), 1);

var freqs = Binary(BinaryOp.Mul, invFreqExpanded, positionIdsExpanded);
//F.Tensors.Transpose(F.Math.MatMul(invFreqExpanded, positionIdsExpanded),new Dimension[] { 0, 2, 1 });
var emb = F.Tensors.Concat(new IR.Tuple(freqs, freqs), -1);
var cos = F.Math.Unary(UnaryOp.Cos, emb);
var sin = F.Math.Unary(UnaryOp.Sin, emb);
Expand All @@ -447,11 +472,11 @@ private Tuple<List<double>, float> RoPEInit(string type)
// x1 = x[..., : x.shape[-1] // 2]
// x2 = x[..., x.shape[-1] // 2 :]
// return torch.cat((-x2, x1), dim=-1)
private Call RotateHalf(Call x)
private Call RotateHalf(Expr x)
{
var x1 = Slice(x, new[] { 0, 0, 0, 0 }, new[] { -1, -1, -1, x.CheckedShape[-1] / 2 }, new[] { 1, 1, 1, 1 },
var x1 = Slice(x, new Dimension[] { 0, 0, 0, 0 }, new Dimension[] { x.CheckedShape[0], x.CheckedShape[1], x.CheckedShape[2], x.CheckedShape[-1] / 2 }, new[] { 1, 1, 1, 1 },
new[] { 1, 1, 1, 1 });
var x2 = Slice(x, new[] { 0, 0, 0, x.CheckedShape[-1] / 2 }, new[] { -1, -1, -1, -1 }, new[] { 1, 1, 1, 1 },
var x2 = Slice(x, new Dimension[] { 0, 0, 0, x.CheckedShape[-1] / 2 }, new Dimension[] { x.CheckedShape[0], x.CheckedShape[1], x.CheckedShape[2], x.CheckedShape[3] }, new[] { 1, 1, 1, 1 },
new[] { 1, 1, 1, 1 });
return Concat(new[] { Binary(BinaryOp.Mul, x2, -1), x1 }, -1);
}
Expand Down

0 comments on commit 074b509

Please sign in to comment.