Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
yanghaoqi committed Mar 3, 2025
1 parent 500318d commit a5d14f1
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 21 deletions.
14 changes: 9 additions & 5 deletions src/Nncase.Importer/HuggingFace/HuggingFaceImporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ public sealed partial class HuggingFaceImporter : BaseImporter

private List<Var>? _inputs;
private List<Var>? _outputs;
private Dictionary<string, Var> _dynVarMap = new();
private Dictionary<string, int> _fixVarMap = new();
private Dictionary<string, Var> _dynVarMap ;
private Dictionary<string, int> _fixVarMap ;

private Dictionary<Var, Expr[]> _varMap ;

public HuggingFaceImporter(string huggingFaceDir, CompileSession compileSession)
: base(compileSession)
Expand All @@ -38,15 +40,17 @@ public HuggingFaceImporter(string huggingFaceDir, CompileSession compileSession)

protected override (IEnumerable<Var> Inputs, Dictionary<Var, Expr[]> VarMap) CreateInputs()
{
throw new NotImplementedException();
// throw new NotImplementedException();
switch (_config!["architectures"]!)
{
case "Qwen2ForCausalLM":
Qwen2CreateInputs();
break;
return Qwen2CreateInputs();

default:
throw new NotImplementedException();
}

return (null,null);
}

protected override void ConvertOp()
Expand Down
55 changes: 40 additions & 15 deletions src/Nncase.Importer/HuggingFace/Qwen2.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,44 @@ namespace Nncase.Importer
{
public partial class HuggingFaceImporter
{
private void Qwen2CreateInputs()
protected (IEnumerable<Var> Inputs, Dictionary<Var, Expr[]> VarMap) Qwen2CreateInputs()
{
_inputs = new List<Var>();
_inputs.Add(new Var("input_ids",
new TensorType(DataTypes.Float32, new Shape(Dimension.Unknown, 1, (int)_config!["hidden_size"]))));

// attention mask shape dependent on the pre-process of the model.
_inputs.Add(new Var("attention_mask",
new TensorType(DataTypes.Float32, new Shape(1, 1, Dimension.Unknown, Dimension.Unknown))));
_inputs.Add(new Var("position_ids", new TensorType(DataTypes.Float32, new Shape(1, Dimension.Unknown))));
_inputs.Add(new Var("past_key_values",
new TensorType(DataTypes.Float32, new Shape(24, 2, 1, Dimension.Unknown, 2, 64))));
_dynVarMap = new Dictionary<string, Var>();
var varMap = new Dictionary<string, Var>();

var bucketOptions = CompileSession.CompileOptions.ShapeBucketOptions;
_fixVarMap = bucketOptions.FixVarMap;

// local test set
// _fixVarMap["sequence_length"] = 10;
// _fixVarMap["history_len"] = 0;

if (!_fixVarMap.ContainsKey("sequence_length"))
_dynVarMap["sequence_length"] = new Var("sequence_length", new TensorType(DataTypes.Int32, Shape.Scalar));
if(!_fixVarMap.ContainsKey("history_len"))
_dynVarMap["history_len"] = new Var("history_len", new TensorType(DataTypes.Int32, Shape.Scalar));

var inputIdsShapeExpr = new Expr[] { _dynVarMap["sequence_length"] , 1, (int)_config!["hidden_size"] };
var attentionMaskShapeExpr = new Expr[] { 1, 1, _dynVarMap["sequence_length"], _dynVarMap["sequence_length"] };
var positionIdsShapeExpr = new Expr[] { 1, _dynVarMap["sequence_length"] };
var pastKeyValueShapeExpr = new Expr[] { 24, 2, 1, _dynVarMap["history_len"], 2, 64 };


var inputIds = new Var("input_ids", new TensorType(DataTypes.Float32, new Shape(Dimension.Unknown, 1, (int)_config!["hidden_size"])));
var attentionMask = new Var("attention_mask", new TensorType(DataTypes.Float32, new Shape(1, 1, Dimension.Unknown, Dimension.Unknown)));
var positionIds = new Var("position_ids", new TensorType(DataTypes.Float32, new Shape(1, Dimension.Unknown)));
var pastKeyValue = new Var("past_key_values", new TensorType(DataTypes.Float32, new Shape(24, 2, 1, Dimension.Unknown, 2, 64)));

_inputs.Add(inputIds);
_inputs.Add(attentionMask);
_inputs.Add(positionIds);
_inputs.Add(pastKeyValue);
varMap[inputIds] = inputIdsShapeExpr;
varMap[attentionMask] = attentionMaskShapeExpr;
varMap[positionIds] = positionIdsShapeExpr;
varMap[pastKeyValue] = pastKeyValueShapeExpr;
return (_inputs, varMap);
}

private Tuple<Call, HuggingFaceUtils.DynamicCache> VisitQwen2ForCausalLM()
Expand Down Expand Up @@ -84,11 +110,7 @@ private void Qwen2CreateInputs()
)
*/

/*
* 1. Qwen2Model
*
* 1.1 embedding
*/

var input_ids = new Var();
var (lastHiddenStates, pastKeyValues, allSelfAttns, allHiddenStates) = Qwen2Model(input_ids,
inputEmbeds: null, new HuggingFaceUtils.DynamicCache(), cachePosition: null, positionIds: null,
Expand All @@ -109,6 +131,9 @@ private void Qwen2CreateInputs()
bool? outputHiddenStates = false
)
{
/*
* 1.1 embedding
*/
if (inputEmbeds == null)
{
var embedTokensWeight = _constTensors["model.embed_tokens.weight"];
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Tests/Importer/UnitTestImporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public async Task TestImportNcnn()
[Fact]
public async Task TestImportHuggingFace()
{
var file = "/home/curio/github/Qwen2.5-0.5B-Instruct/"; // TODO: need a relative path!
var file = "/mnt/Qwen2.5-0.5B-Instruct/"; // TODO: need a relative path!
var module = Importers.ImportHuggingFace(file, CompileSession);
await InferShapeAsync(module);
Assert.NotNull(module.Entry);
Expand Down
9 changes: 9 additions & 0 deletions src/Nncase.Tests/packages.lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,14 @@
"resolved": "1.3.0",
"contentHash": "WXnEcGwmXfa8gW9N2MlcaPNUzM3NLMwnAhacbtH554F8YcoXbIkTB+uGa1Aa+9gyb/9JZgYVHnmADgJUKP52nA=="
},
"NETStandard.Library": {
"type": "Transitive",
"resolved": "2.0.3",
"contentHash": "st47PosZSHrjECdjeIzZQbzivYBJFv6P2nv4cj2ypdI204DO+vZ7l5raGMiX4eXMJ53RfOIg+/s4DHVZ54Nu2A==",
"dependencies": {
"Microsoft.NETCore.Platforms": "1.1.0"
}
},
"StyleCop.Analyzers.Unstable": {
"type": "Transitive",
"resolved": "1.2.0.556",
Expand Down Expand Up @@ -926,6 +934,7 @@
"type": "Project",
"dependencies": {
"MethodBoundaryAspect.Fody": "[2.0.149, )",
"NETStandard.Library": "[2.0.3, )",
"Nncase.CodeGen": "[1.0.0, )",
"Nncase.Compiler": "[1.0.0, )",
"Nncase.Core": "[1.0.0, )",
Expand Down

0 comments on commit a5d14f1

Please sign in to comment.