diff --git a/src/Cli/Program.cs b/src/Cli/Program.cs index d5bcd85..1b5eb41 100644 --- a/src/Cli/Program.cs +++ b/src/Cli/Program.cs @@ -27,7 +27,7 @@ namespace Pytocs.Cli { public class Program { - private const string usage = + private const string usage = @"Usage: pytocs [options] @@ -37,6 +37,7 @@ pytocs [options] all its subdirectories [default: .] -p, --post-process PPLIST Post process the output using one or more post-processor(s), separated by commas. + -o, --output DIRECTORY Output folder path. -q, --quiet Run with reduced output. "; public static void Main(string[] argv) @@ -64,6 +65,7 @@ public static void Main(string[] argv) var startDir = (string) oStartDir; if (startDir == "." || startDir == "./" || startDir == ".\\") startDir = Directory.GetCurrentDirectory(); + startDir = Path.GetFullPath(startDir); typeAnalysis.Analyze(startDir); typeAnalysis.Finish(); var types = new TypeReferenceTranslator(typeAnalysis.BuildTypeDictionary()); @@ -72,6 +74,10 @@ public static void Main(string[] argv) //{ // Console.WriteLine("{0}: {1} {2}", de.Key, de.Key.Start, de.Value); //} + var outputDir = options.ContainsKey("--output") ? (string) options["--output"] : startDir; + if (outputDir == "." || outputDir == "./" || outputDir == ".\\") + outputDir = Directory.GetCurrentDirectory(); + outputDir = Path.GetFullPath(outputDir); var walker = new DirectoryWalker(fs, startDir, "*.py"); walker.Enumerate(state => @@ -91,13 +97,16 @@ public static void Main(string[] argv) logger.Error("Unable to load {0}.", path); continue; } + string outputPath = Path.ChangeExtension(path, ".py.cs").Replace(startDir, outputDir); + Directory.CreateDirectory(Path.GetDirectoryName(outputPath)!); + xlator.TranslateModuleStatements( module.Body.Statements, types, - Path.ChangeExtension(path, ".py.cs")); + outputPath); } }); - } + } else { if (!options.TryGetValue("", out var oFiles) || @@ -134,13 +143,14 @@ public static void Main(string[] argv) } } - private static IDictionary ParseOptions(string[] args) + private static IDictionary ParseOptions(string[] args) { var result = new Dictionary(); var files = new List(); - for (int i = 0; i < args.Length; ++i) + int i = 0; + while (i < args.Length) { - var arg = args[i]; + var arg = args[i++]; if (!arg.StartsWith('-')) { files = args.Skip(i).ToList(); @@ -159,17 +169,27 @@ private static IDictionary ParseOptions(string[] args) case "-r": case "--recursive": var dirname = "."; - if (i < args.Length - 1) + if (i < args.Length) { - if (!args[i + 1].StartsWith('-')) + if (!args[i].StartsWith('-')) { - ++i; - dirname = args[i]; + dirname = args[i++]; } - break; } result["--recursive"] = dirname; break; + case "-o": + case "--output": + var dirname2 = "."; + if (i < args.Length) + { + if (!args[i].StartsWith('-')) + { + dirname2 = args[i++]; + } + } + result["--output"] = dirname2; + break; } } result[""] = files; diff --git a/src/Extensions/TorchCs/ClassInfo.cs b/src/Extensions/TorchCs/ClassInfo.cs new file mode 100644 index 0000000..c76a703 --- /dev/null +++ b/src/Extensions/TorchCs/ClassInfo.cs @@ -0,0 +1,757 @@ +#region License +// Copyright 2023 ToolGood +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#endregion +using System.Text.RegularExpressions; + +namespace TorchCs +{ + public class ClassFile + { + public string FileName { get; set; } + public string Code { get; set; } + public List ClassInfos { get; set; } + public bool HasChange { get; set; } + public bool LastChange { get; set; } + + public static List LoadFiles(string folder) + { + var files = new List(); + var files2 = Directory.GetFiles(folder, "*.py.cs", SearchOption.AllDirectories); + foreach (var file in files2) { + var text = File.ReadAllText(file); + ClassFile classFile = new ClassFile(); + classFile.FileName = file; + classFile.Code = text; + classFile.ClassInfos = ClassInfo.AnalysisCode(text); + classFile.HasChange = true; + classFile.LastChange = true; + foreach (var item in classFile.ClassInfos) { + item.File = classFile; + } + files.Add(classFile); + } + return files; + } + public Dictionary MatchClassInfo(string code, List classInfos) + { + Dictionary result = new Dictionary(); + var match = Regex.Match(code, @"namespace ([a-zA-Z_][a-zA-Z0-9._]*) "); + if (match.Success) { + var ns = match.Groups[1].Value.Split('.'); + + var ms = Regex.Matches(code, @"using ([a-zA-Z_][a-zA-Z0-9_]*) = ([a-zA-Z_][a-zA-Z0-9_.]*);"); + foreach (Match m in ms) { + var key = m.Groups[1].Value; + var name = m.Groups[2].Value; + var classInfo = classInfos.FirstOrDefault(q => q.FullClassName == name); + if (classInfo != null) { + if (classInfo.File.LastChange) { + result[key] = classInfo; + } + continue; + } + var sp = name.Split("."); + for (int i = 1; i < ns.Length; i++) { + var names = new string[sp.Length + i]; + for (int j = 0; j < i; j++) { + names[j] = ns[j]; + } + for (int j = 0; j < sp.Length; j++) { + names[j + i] = sp[j]; + } + name = string.Join(".", names); + classInfo = classInfos.FirstOrDefault(q => q.FullClassName == name); + if (classInfo != null) { + if (classInfo.File.LastChange) { + result[key] = classInfo; + } + break; + } + } + } + } + return result; + } + + public Dictionary MatchClassInfo(string code, List files) + { + Dictionary result = new Dictionary(); + var match = Regex.Match(code, @"namespace ([a-zA-Z_][a-zA-Z0-9._]*) "); + if (match.Success) { + var ns = match.Groups[1].Value.Split('.'); + + var classInfos = new List(); + foreach (var file in files) { classInfos.AddRange(file.ClassInfos); } + + var ms = Regex.Matches(code, @"using ([a-zA-Z_][a-zA-Z0-9_]*) = ([a-zA-Z_][a-zA-Z0-9_.]*);"); + foreach (Match m in ms) { + var key = m.Groups[1].Value; + var name = m.Groups[2].Value; + var classInfo = classInfos.FirstOrDefault(q => q.FullClassName == name); + if (classInfo != null) { + result[key] = classInfo; + continue; + } + List names = new List(); + names.AddRange(ns); + names.Add(""); + names.Add(""); + var sp = name.Split("."); + for (int i = 0; i < sp.Length; i++) { + names[names.Count - sp.Length + i] = sp[i]; + } + name = string.Join(".", names); + classInfo = classInfos.FirstOrDefault(q => q.FullClassName == name); + if (classInfo != null) { + result[key] = classInfo; + } + } + } + return result; + } + + + } + + + public class ClassInfo + { + private const string classRegex = @"public class ([a-zA-Z_][a-zA-Z0-9_]*)([\s\S]*?)\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; + private const string classRegex2 = @"public class {name}([\s\S]*?)\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; + + internal ClassFile File { get; set; } + public string FullClassName { get; set; } + public string ClassName { get; set; } + public bool HasForwardMethod { get; set; } + public ClassConstructor Constructor { get; set; } + public List Fields { get; set; } + public List Methods { get; set; } + + public static List AnalysisCode(string code) + { + List classInfos = new List(); + var match = Regex.Match(code, @"namespace ([a-zA-Z_][a-zA-Z0-9._]*) "); + var match2 = Regex.Match(code, @"public static class ([a-zA-Z_][a-zA-Z0-9._]*) "); + var prefix = match.Groups[1].Value + "." + match2.Groups[1].Value; + + var ms = Regex.Matches(code, classRegex); + foreach (Match m in ms) { + ClassInfo classInfo = new ClassInfo(); + + classInfo.FullClassName = prefix + "." + m.Groups[1].Value; + classInfo.ClassName = m.Groups[1].Value; + var bodyCode = m.Groups[3].Value; + classInfo.Constructor = ClassConstructor.AnalysisCode(bodyCode, classInfo.ClassName); + classInfo.Fields = ClassField.AnalysisCode(bodyCode); + classInfo.Methods = ClassMethod.AnalysisCode(bodyCode); + classInfo.HasForwardMethod = classInfo.Methods.Any(q => q.MethodName == "forward"); + + foreach (var item in classInfo.Methods) { + item.ClassInfo = classInfo; + } + classInfos.Add(classInfo); + } + var fclass = classInfos.Where(q => q.HasForwardMethod).Select(q => q.ClassName).ToList(); + foreach (var info in classInfos) { + foreach (var item in info.Fields) { + if (fclass.Contains(item.NewType ?? item.Type)) { + item.HasForwardMethod = true; + } + } + } + return classInfos; + } + public string AddNewField(string code) + { + if (Fields.Any(q => q.IsNewField)) { + code = Regex.Replace(code, classRegex2.Replace("{name}", ClassName), new MatchEvaluator(m => { + var bodyCode = m.Groups[2].Value; + var baseClass = m.Groups[1].Value; + foreach (var field in Fields) { + bodyCode = field.AddNewField(bodyCode); + } + return $"public class {ClassName}{baseClass}{{{bodyCode}}}"; + })); + } + return code; + } + + public string ReplaceCodes(string code) + { + code = Regex.Replace(code, classRegex2.Replace("{name}", ClassName), new MatchEvaluator(m => { + var bodyCode = m.Groups[2].Value; + var baseClass = m.Groups[1].Value; + foreach (var field in Fields) { + bodyCode = field.ReplaceCodes(bodyCode); + } + bodyCode = Constructor.ReplaceCodes(bodyCode); + foreach (var method in Methods) { + bodyCode = method.ReplaceCodes(bodyCode, Fields); + } + return $"public class {ClassName}{baseClass}{{{bodyCode}}}"; + })); + return code; + } + public string ReplaceMethodParamenterType(string code, Dictionary classInfos) + { + code = Regex.Replace(code, classRegex2.Replace("{name}", ClassName), new MatchEvaluator(m => { + var bodyCode = m.Groups[2].Value; + var baseClass = m.Groups[1].Value; + + Dictionary temp = new Dictionary(); + foreach (var field in Fields) { + if (classInfos.ContainsKey(field.NewType ?? field.Type)) { + temp[field.FieldName] = classInfos[field.NewType ?? field.Type]; + } + } + foreach (var method in Methods) { + bodyCode = method.ReplaceMethodParamenterType(bodyCode, temp); + } + return $"public class {ClassName}{baseClass}{{{bodyCode}}}"; + })); + return code; + } + + public string GetMethodParamenterType(string methodName, int paramenterIndex) + { + var method = Methods.FirstOrDefault(q => q.MethodName == methodName); + if (method != null) { + if (paramenterIndex < method.Paramenters.Count) { + var p = method.Paramenters[paramenterIndex]; + return p.NewType ?? p.Type; + } + } + return null; + } + + public override string ToString() + { + return $"class: {ClassName}"; + } + } + public class ClassConstructor + { + private const string constructorRegex = @"public {name}\(([^)]*?)\)(.*?)\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; + + public string ClassName { get; set; } + + public List Paramenters { get; set; } + public List Variables { get; set; } + + public static ClassConstructor AnalysisCode(string code, string className) + { + ClassConstructor classConstructor = new ClassConstructor(); + classConstructor.ClassName = className; + var reg = constructorRegex.Replace("{name}", className); + var m = Regex.Match(code, reg); + classConstructor.Paramenters = ClassMethodParamenter.AnalysisCode(m.Groups[1].Value, m.Groups[3].Value); + classConstructor.Variables = ClassMethodVariable.AnalysisCode(m.Groups[3].Value, classConstructor.Paramenters); + return classConstructor; + } + + public string ReplaceCodes(string code) + { + code = Regex.Replace(code, constructorRegex.Replace("{name}", ClassName), new MatchEvaluator(m => { + var ParamenterCode = m.Groups[1].Value; + foreach (var paramenter in Paramenters) { + ParamenterCode = paramenter.ReplaceCodes(ParamenterCode); + } + var BodyCode = m.Groups[3].Value; + foreach (var variable in Variables) { + BodyCode = variable.ReplaceCodes(BodyCode); + } + return $"public {ClassName}({ParamenterCode}){m.Groups[2].Value}{{{BodyCode}}}"; + })); + return code; + } + } + + public class ClassField + { + public string Type { get; set; } + public string NewType { get; set; } + public string FieldName { get; set; } + public bool IsNewField { get; set; } + public bool HasForwardMethod { get; set; } + + public static List AnalysisCode(string code) + { + List classFields = new List(); + HashSet fields = new HashSet(); + var ms = Regex.Matches(code, @"public ([a-zA-Z_][a-zA-Z0-9_\<\>\[\]\?]*) ([a-zA-Z_@][a-zA-Z0-9_]*);"); + foreach (Match match in ms) { + ClassField field = new ClassField(); + field.Type = match.Groups[1].Value; + field.FieldName = match.Groups[2].Value; + fields.Add(field.FieldName); + classFields.Add(field); + } + ms = Regex.Matches(code, @"public ([a-zA-Z_][a-zA-Z0-9_\<\>\[\]\?]*) ([a-zA-Z_@][a-zA-Z0-9_]*) ="); + foreach (Match match in ms) { + ClassField field = new ClassField(); + field.Type = match.Groups[1].Value; + field.FieldName = match.Groups[2].Value; + fields.Add(field.FieldName); + classFields.Add(field); + } + ms = Regex.Matches(code, @"\bthis\.([a-zA-Z_@][a-zA-Z0-9_]*)[ \t\r\n,;)\[]"); + foreach (Match m in ms) { + if (fields.Add(m.Groups[1].Value)) { + ClassField field = new ClassField(); + field.Type = "object"; + field.FieldName = m.Groups[1].Value; + field.IsNewField = true; + classFields.Add(field); + } + } + + var nnMethods = TorchSharpInfo.Instance.nnMethods; + foreach (var method in nnMethods) { + var fieldType = method.ReturnType.Name; + var methodName = method.Name; + if (methodName == "ModuleDict" || methodName == "ModuleList") { continue; } + + var r = $@"this\.(\S+) = nn\.{methodName}\("; + var ms3 = Regex.Matches(code, r); + foreach (Match m in ms3) { + var name = m.Groups[1].Value; + var f = classFields.FirstOrDefault(q => q.FieldName == name); + if (f != null) { f.NewType = fieldType; f.HasForwardMethod = true; } + } + } + + var ms2 = Regex.Matches(code, @"this\.(\S+) = new ([a-zA-Z_][a-zA-Z0-9_]+)\("); + foreach (Match m2 in ms2) { + var name = m2.Groups[1].Value; + var typeName = m2.Groups[2].Value; + var f = classFields.FirstOrDefault(q => q.FieldName == name); + if (f != null) { f.NewType = typeName; } + } + + foreach (var field1 in classFields) { + if (field1.NewType != null) { continue; } + var name = field1.FieldName; + if (code.Contains($"if (this.{name})") || code.Contains($"if (!this.{name})")) { + field1.NewType = "bool"; + } else if (Regex.IsMatch(code, $@"this\.{name} (=|==|!=) (false|true)") || Regex.IsMatch(code, $@"(\(|&& |\|\| )!?this\.{name} (&&|\|\|)") || Regex.IsMatch(code, $@"(&&|\|\|) !?this\.{name}(\)| &&| \|\|)")) { + field1.NewType = "bool"; + } else if (Regex.IsMatch(code, $@"this\.{name} (=|==|!=|\+=) """) || Regex.IsMatch(code, $@"this\.{name}\.(startswith|endswith|upper|lower|replace|strip|lstrip|rstrip)\(")) { + field1.NewType = "string"; + } else if (Regex.IsMatch(code, $@"this\.{name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) (\d+\.\d+|\d+(\.\d+)?[Ee])")) { + field1.NewType = "double"; + } else if (Regex.IsMatch(code, $@"this\.{name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) \d+")) { + field1.NewType = "int"; + } else if (Regex.IsMatch(code, $@"this\.{name}\[[^\]]*?TensorIndex\.")) { + field1.NewType = "Tensor"; + } else if (field1.Type == "object" && Regex.IsMatch(name, @"^(optimizer|opt|.*(_optimizer|_opt))$", RegexOptions.IgnoreCase)) { + field1.NewType = "OptimizerHelper"; + } else if (field1.Type == "object" && Regex.IsMatch(name, @"^(scheduler|.*(_scheduler))$", RegexOptions.IgnoreCase)) { + field1.NewType = "LRScheduler"; + } else if (field1.Type == "object" && Regex.IsMatch(name, @"^(dataset|.*_dataset)$", RegexOptions.IgnoreCase)) { + field1.NewType = "Dataset"; + } else if (field1.Type == "object" && Regex.IsMatch(name, @"^(dataloader|loader|.*_loader)$", RegexOptions.IgnoreCase)) { + field1.NewType = "DataLoader"; + } else if (field1.Type == "object" && TorchUtil.isDoubleTypeByName(name)) { + field1.NewType = "double"; + } else if (field1.Type == "object" && TorchUtil.isIntTypeByName(name)) { + field1.NewType = "int"; + } else if (field1.Type == "object" && TorchUtil.isStringTypeByName(name)) { + field1.NewType = "string"; + //} else if (classMethodParamenter.Type == "object" && Regex.IsMatch(text, $@" [\+\-\*\/] {name}[ ,;)]")) { + // classMethodParamenter.NewType = "double"; + //} else if (classMethodParamenter.Type == "object" && Regex.IsMatch(text, $@"[(, ]{name} [\+\-\*\/] ")) { + // classMethodParamenter.NewType = "double"; + } else { + var type = TorchSharpInfo.Instance.FindTypeBy_nn(code, "this." + field1.FieldName); + if (type == null) { + type = TorchSharpInfo.Instance.FindTypeBy_torch(code, "this." + field1.FieldName); + } + if (type != null) { + field1.NewType = type; + } + } + } + return classFields; + } + public string AddNewField(string code) + { + if (IsNewField) { + return $"\r\n\t\t\tpublic {NewType ?? Type} {FieldName};{code}"; + } + return code; + } + + public string ReplaceCodes(string code) + { + if (NewType == null || NewType == Type) { return code; } + return code.Replace($"public {Type} {FieldName};", $"public {NewType} {FieldName};"); + } + + public override string ToString() + { + return $"field: {NewType ?? Type} {FieldName}"; + } + + } + + public class ClassMethod + { + private const string methodRegex = @"public (virtual) ([a-zA-Z_][a-zA-Z0-9_\[\]]*|Tuple<[a-zA-Z_][a-zA-Z0-9_<>, ]*>|\([^)]+?\)) ([a-zA-Z_@][a-zA-Z0-9_]*)\(([^)]*?)\) ?\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; + private const string methodRegex2 = @"public (virtual|static) ([a-zA-Z_][a-zA-Z0-9_\[\]]*|Tuple<[a-zA-Z_][a-zA-Z0-9_<>, ]*>|\([^)]+?\)) {name}\(([^)]*?)\) ?\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; + private const string methodRegex3 = @"public (static) ([a-zA-Z_][a-zA-Z0-9_\[\]]*|Tuple<[a-zA-Z_][a-zA-Z0-9_<>, ]*>|\([^)]+?\)) ([a-zA-Z_@][a-zA-Z0-9_]*)\(([^)]*?)\) ?\{(((?
\()|(?<-BR>\))|(?\{)|(?<-BR2>\})|[^(){}])+)\}"; + // public static int get_q_k(int input_size, int window_size, object stride, object device) + internal ClassInfo ClassInfo { get; set; } + public string MethodName { get; set; } + public string ReturnType { get; set; } + public string NewReturnType { get; set; } + public bool IsForwardMethod { get; set; } + + public List Paramenters { get; set; } = new List(); + public List Variables { get; set; } = new List(); + + public static List AnalysisCodeForStaticMethod(string code) + { + List classMethods = new List(); + var ms = Regex.Matches(code, methodRegex3); + foreach (Match m in ms) { + ClassMethod classMethod = new ClassMethod(); + + classMethod.ReturnType = m.Groups[2].Value; + classMethod.MethodName = m.Groups[3].Value; + classMethod.Paramenters = ClassMethodParamenter.AnalysisCode(m.Groups[4].Value, m.Groups[5].Value); + classMethod.Variables = ClassMethodVariable.AnalysisCode(m.Groups[5].Value, classMethod.Paramenters); + classMethods.Add(classMethod); + } + return classMethods; + } + public static List AnalysisCode(string code) + { + List classMethods = new List(); + var ms = Regex.Matches(code, methodRegex); + foreach (Match m in ms) { + ClassMethod classMethod = new ClassMethod(); + classMethod.ReturnType = m.Groups[2].Value; + classMethod.MethodName = m.Groups[3].Value; + classMethod.Paramenters = ClassMethodParamenter.AnalysisCode(m.Groups[4].Value, m.Groups[5].Value); + classMethod.Variables = ClassMethodVariable.AnalysisCode(m.Groups[5].Value, classMethod.Paramenters); + classMethod.IsForwardMethod = classMethod.MethodName == "forward"; + classMethods.Add(classMethod); + } + return classMethods; + } + + public string ReplaceCodes(string code, List fields = null) + { + code = Regex.Replace(code, methodRegex2.Replace("{name}", MethodName), new MatchEvaluator(m => { + var ParamenterCode = m.Groups[3].Value; + foreach (var paramenter in Paramenters) { + ParamenterCode = paramenter.ReplaceCodes(ParamenterCode); + } + var bodyCode = m.Groups[4].Value; + foreach (var variable in Variables) { + bodyCode = variable.ReplaceCodes(bodyCode); + } + if (fields != null) { + foreach (var field in fields) { + if (field.HasForwardMethod || IsForwardMethod) { + bodyCode = Regex.Replace(bodyCode, @$"\bthis\.{field.FieldName}\(", $"this.{field.FieldName}.forward("); + bodyCode = Regex.Replace(bodyCode, @$"\bthis\.{field.FieldName}(\[([a-zA-Z_][a-zA-Z_0-9]*|\^?[0-9]+)\])\(", $"this.{field.FieldName}$1.forward("); + } + } + } + if (NewReturnType == null) { + if (ReturnType.StartsWith("Tuple<")) { + NewReturnType = ReturnType.Replace("Tuple<", "("); + NewReturnType = NewReturnType.Substring(0, NewReturnType.Length - 1) + ")"; + if (IsForwardMethod) { + NewReturnType = NewReturnType.Replace("object", "Tensor"); + NewReturnType = NewReturnType.Replace("void", "Tensor"); + } + } else if (ReturnType == "void" || ReturnType == "object" || ReturnType == "object[]") { + var ms = Regex.Matches(bodyCode, "return ([^;]*);"); + var max = 0; + foreach (Match item in ms) { + if (item.Groups[1].Value.StartsWith('(')) { + var t = item.Groups[1].Value.Substring(1, item.Groups[1].Value.Length - 2); + var ms2 = TorchUtil.splitParamenters(t); + max = Math.Max(max, ms2.Count); + } else { + max = Math.Max(max, 1); + } + } + if (max == 1) { + NewReturnType = "object"; + var f = ms[0].Value; + if (f.StartsWith("this.")) { + if (fields != null) { + f = f.Substring(5); + var p = fields.FirstOrDefault(q => q.FieldName == f); + if (p != null) { + NewReturnType = p.NewType ?? p.Type; + } + } + } else { + var p = Paramenters.FirstOrDefault(q => q.ParamenterName == f); + if (p != null) { + NewReturnType = p.NewType ?? p.Type; + } + } + } else if (max > 1) { + NewReturnType = "("; + for (int i = 0; i < max; i++) { + if (i > 0) { NewReturnType += ","; } + NewReturnType += "object"; + } + NewReturnType += ")"; + } + if (IsForwardMethod) { + NewReturnType = (NewReturnType ?? ReturnType).Replace("object", "Tensor"); + NewReturnType = NewReturnType.Replace("void", "Tensor"); + } + } + } + return $"public {m.Groups[1].Value} {NewReturnType ?? ReturnType} {MethodName}({ParamenterCode}){{{bodyCode}}}"; + })); + return code; + } + + public string ReplaceMethodParamenterType(string code, Dictionary classInfos) + { + var paramenters = Paramenters.Where(q => q.NewType == null && q.Type == "object").ToList(); + if (paramenters.Count == 0) { return code; } + + code = Regex.Replace(code, methodRegex2.Replace("{name}", MethodName), new MatchEvaluator(m1 => { + var ParamenterCode = m1.Groups[3].Value; + var bodyCode = m1.Groups[4].Value; + + var reg = @"\bthis\.([a-zA-Z_][a-zA-Z_0-9]+)\.([^\(\.]+)\((((?
\()|(?<-BR>\))|[^()])+)\)"; + var ms = Regex.Matches(bodyCode, reg); + foreach (Match m in ms) { + var fieldName = m.Groups[1].Value; + if (classInfos.ContainsKey(fieldName) == false) continue; + var name = m.Groups[2].Value; + var ps = m.Groups[3].Value; + var ps2 = TorchUtil.splitParamenters(ps); + for (int i = paramenters.Count - 1; i >= 0; i--) { + var paramenter = paramenters[i]; + var index = ps2.IndexOf(paramenter.ParamenterName); + if (index >= 0) { + var type = classInfos[fieldName].GetMethodParamenterType(name, index); + if (type != null && type != "object") { + paramenter.NewType = type; + ParamenterCode = paramenter.ReplaceCodes(ParamenterCode); + this.ClassInfo.File.HasChange = true; + paramenters.RemoveAt(i); + } + } + } + } + return $"public {m1.Groups[1].Value} {NewReturnType ?? ReturnType} {MethodName}({ParamenterCode}){{{bodyCode}}}"; + })); + return code; + } + + public override string ToString() + { + return $"method: {NewReturnType ?? ReturnType} {MethodName}"; + } + } + + public class ClassMethodParamenter + { + public string ParamenterName { get; set; } + public string Type { get; set; } + public string NewType { get; set; } + public string DefaultValue { get; set; } + + public static List AnalysisCode(string code, string text) + { + var fieldsRegex = TorchSharpInfo.Instance.TensorFieldRegex; + var methodRegex = TorchSharpInfo.Instance.TensorMethodRegex; + + + List classMethodParamenters = new List(); + if (string.IsNullOrEmpty(code)) { return classMethodParamenters; } + + var strs = Regex.Matches(code, "(.*?) ([a-zA-Z_@][a-zA-Z_0-9]*)( = ([^,]+))?(,|$)"); + + foreach (Match str in strs) { + ClassMethodParamenter classMethodParamenter = new ClassMethodParamenter(); + classMethodParamenters.Add(classMethodParamenter); + classMethodParamenter.Type = str.Groups[1].Value.Trim(); + classMethodParamenter.ParamenterName = str.Groups[2].Value.Trim(); + var name = classMethodParamenter.ParamenterName; + //if (name == "inputs") { + + //} + if (str.Groups[3].Success) { + classMethodParamenter.DefaultValue = str.Groups[4].Value.Trim(); + + if (classMethodParamenter.DefaultValue == "true" || classMethodParamenter.DefaultValue == "false") { + classMethodParamenter.NewType = "bool"; + } else if (classMethodParamenter.DefaultValue.StartsWith("\"")) { + classMethodParamenter.NewType = "string"; + } else if (Regex.IsMatch(classMethodParamenter.DefaultValue, @"\-?\d+\.\d+")) { + classMethodParamenter.NewType = "double"; + } else if (Regex.IsMatch(classMethodParamenter.DefaultValue, @"\-?\d+")) { + classMethodParamenter.NewType = "int"; + } else if (classMethodParamenter.DefaultValue == "null") { + if (Regex.IsMatch(text, @$"{name} = {name} \?\? [a-zA-Z_][a-zA-Z_0-9]* [\+\-\*\/] [a-zA-Z_][a-zA-Z_0-9]*;")) { + classMethodParamenter.NewType = "int?"; + } + } + if (classMethodParamenter.NewType != null) { continue; } + } + if (text.Contains($"if ({name})") || text.Contains($"if (!{name})")) { + classMethodParamenter.NewType = "bool"; + } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name} (=|==|!=) (false|true)") || Regex.IsMatch(text, $@"(\(|&& |\|\| )!?{name} (&&|\|\|)") || Regex.IsMatch(text, $@"(&&|\|\|) !?{name}(\)| &&| \|\|)")) { + classMethodParamenter.NewType = "bool"; + } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name} (=|==|!=|\+=) """) || Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name}\.(split|startswith|endswith|upper|lower|replace|strip|lstrip|rstrip)\(")) { + classMethodParamenter.NewType = "string"; + } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) (\d+\.\d+|\d+(\.\d+)?[Ee])")) { + classMethodParamenter.NewType = "doulbe"; + } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) \d+")) { + classMethodParamenter.NewType = "int"; + } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name}\[[^\]]*?TensorIndex\.")) { + classMethodParamenter.NewType = "Tensor"; + } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name}\.{methodRegex}\(")) { + classMethodParamenter.NewType = "Tensor"; + } else if (Regex.IsMatch(text, $@"(^|[ \t(,;\[]){name}\.{fieldsRegex}[ ,;)\[]")) { + classMethodParamenter.NewType = "Tensor"; + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(label|pred|preds|target|targets|x_enc|x_mark_enc|x_dec|x_mark_dec)$", RegexOptions.IgnoreCase)) { + classMethodParamenter.NewType = "Tensor"; + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(dataset|.*_dataset)$", RegexOptions.IgnoreCase)) { + classMethodParamenter.NewType = "Dataset"; + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(loader|.*_loader)$", RegexOptions.IgnoreCase)) { + classMethodParamenter.NewType = "DataLoader"; + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(optimizer|opt|.*(_optimizer|_opt))$", RegexOptions.IgnoreCase)) { + classMethodParamenter.NewType = "OptimizerHelper"; + } else if (classMethodParamenter.Type == "object" && Regex.IsMatch(name, @"^(scheduler|.*(_scheduler))$", RegexOptions.IgnoreCase)) { + classMethodParamenter.NewType = "LRScheduler"; + } else if (classMethodParamenter.Type == "object" && TorchUtil.isDoubleTypeByName(name)) { + classMethodParamenter.NewType = "double"; + } else if (classMethodParamenter.Type == "object" && TorchUtil.isIntTypeByName(name)) { + if (classMethodParamenter.DefaultValue == "null") { + classMethodParamenter.NewType = "int?"; + } else { + classMethodParamenter.NewType = "int"; + } + } else if (classMethodParamenter.Type == "object" && TorchUtil.isStringTypeByName(name)) { + classMethodParamenter.NewType = "string"; + //} else if (classMethodParamenter.Type == "object" && Regex.IsMatch(text, $@" [\+\-\*\/] {name}[ ,;)]")) { + // classMethodParamenter.NewType = "double"; + //} else if (classMethodParamenter.Type == "object" && Regex.IsMatch(text, $@"[(, ]{name} [\+\-\*\/] ")) { + // classMethodParamenter.NewType = "double"; + } else { + var type = TorchSharpInfo.Instance.FindTypeBy_nn(text, classMethodParamenter.ParamenterName); + if (type == null) { + type = TorchSharpInfo.Instance.FindTypeBy_torch(text, classMethodParamenter.ParamenterName); + } + if (type != null) { + classMethodParamenter.NewType = type; + } + } + + } + return classMethodParamenters; + } + + public string ReplaceCodes(string code) + { + if (NewType == null || NewType == Type) { return code; } + return Regex.Replace(code, $@"\b{Type} {ParamenterName}\b", $"{NewType} {ParamenterName}"); + } + + public override string ToString() + { + return $"paramenter: {NewType ?? Type} {ParamenterName}"; + } + } + + public class ClassMethodVariable + { + public string Type { get; set; } + public string NewType { get; set; } + public string HiddenType { get; set; } + public string VariableName { get; set; } + + public static List AnalysisCode(string code, List paramenters) + { + List classMethodVariables = new List(); + var texts = code.Split(new char[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries); + + HashSet names = new HashSet(); + names.Add("_"); + foreach (var paramenter in paramenters) { + names.Add(paramenter.ParamenterName); + } + foreach (var text in texts) { + var m = Regex.Match(text, @"^[\t ]*([a-zA-Z_][a-zA-Z0-9_<>\[\]]*) ([a-zA-Z_][a-zA-Z0-9_]*)(;| = )"); + if (m.Success) { + if (names.Add(m.Groups[1].Value)) { + ClassMethodVariable classMethodVariable = new ClassMethodVariable(); + classMethodVariable.Type = m.Groups[1].Value; + classMethodVariable.VariableName = m.Groups[2].Value; + + classMethodVariables.Add(classMethodVariable); + } + continue; + } + m = Regex.Match(text, @"^[ \t]*([a-zA-Z_][a-zA-Z0-9_]*) = "); + if (m.Success) { + if (names.Add(m.Groups[1].Value)) { + ClassMethodVariable classMethodVariable = new ClassMethodVariable(); + classMethodVariable.VariableName = m.Groups[1].Value; + classMethodVariables.Add(classMethodVariable); + } + continue; + } + + m = Regex.Match(text, @"^[ \t]*\(([^)]+)\) = "); + if (m.Success) { + var str = m.Groups[1].Value; + var sp = str.Split(','); + foreach (var sp1 in sp) { + var s = sp1.Trim(); + if (names.Add(s)) { + ClassMethodVariable classMethodVariable = new ClassMethodVariable(); + classMethodVariable.VariableName = m.Groups[1].Value; + classMethodVariables.Add(classMethodVariable); + } + } + continue; + } + } + return classMethodVariables; + } + + public string ReplaceCodes(string code) + { + return code; + //if (Type != null && NewType != Type) { + // code = Regex.Replace(code, $@"\b{Type} {VariableName}", $"{NewType} {VariableName}"); + //} + //return code; + } + + public override string ToString() + { + return $"variable: {NewType ?? Type} {VariableName}"; + } + } + + + +} diff --git a/src/Extensions/TorchCs/Program.cs b/src/Extensions/TorchCs/Program.cs new file mode 100644 index 0000000..c3cfb2d --- /dev/null +++ b/src/Extensions/TorchCs/Program.cs @@ -0,0 +1,83 @@ +#region License +// Copyright 2023 ToolGood +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#endregion +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using TorchSharp.Data; + +namespace TorchCs +{ + public class Program + { + private const string usage = +@"Usage: + TorchCs [options] + +Options: + -d, --dir Convert all files in the directory + -n, --netstandard Generate netstandard.cs file, The parameter has -d or -dir is valid. +"; + static void Main(string[] args) + { + var options = ParseOptions(args); + if (options.Count == 0) { + Console.WriteLine(usage); + return; + } + if (options.ContainsKey("--dir")) { + Console.WriteLine("Conversion directory:" + options["--dir"].ToString()); + TorchUtil.ReplaceFolder(options["--dir"].ToString(), options.ContainsKey("--netstandard")); + } else { + foreach (var item in (List)options[""]) { + Console.WriteLine("Conversion file:" + item); + TorchUtil.ReplaceFile(item); + } + } + if (options.ContainsKey("--netstandard")) { + if (options.ContainsKey("--dir")) { + Console.WriteLine("Generate netstandard.cs file"); + TorchUtil.CreateNetstandardCode(Path.GetDirectoryName(options["--dir"].ToString())); + } + } + Console.WriteLine("Conversion completed!"); + } + + private static IDictionary ParseOptions(string[] args) + { + var result = new Dictionary(); + var files = new List(); + + int index = 0; + while (index < args.Length) { + var arg = args[index++]; + if (!arg.StartsWith('-')) { + files.Add(arg); + + } else if (arg == "-d" || arg == "--dir") { + result["--dir"] = args[index++]; + } else if (arg == "-n" || arg == "--netstandard") { + result["--netstandard"] = true; + } + } + result[""] = files; + return result; + } + + + } +} diff --git a/src/Extensions/TorchCs/Resources/netstandard.cs b/src/Extensions/TorchCs/Resources/netstandard.cs index 51901fb..ff0350f 100644 --- a/src/Extensions/TorchCs/Resources/netstandard.cs +++ b/src/Extensions/TorchCs/Resources/netstandard.cs @@ -16,7 +16,12 @@ using System.Runtime.CompilerServices; using System.Security.Cryptography; using static TorchSharp.torch; +using System.Linq; +using System.Text; +#pragma warning disable IDE1006 // 命名样式 +#pragma warning disable CS8981 // 该类型名称仅包含小写 ascii 字符。此类名称可能会成为该语言的保留值。 +#pragma warning disable CS8632 // 只能在 "#nullable" 注释上下文内的代码中使用可为 null 的引用类型的注释。 namespace System { public static partial class TorchExtension @@ -108,16 +113,147 @@ public static string rstrip(this string str) return str.TrimEnd(); } + public static T copy(this T obj) where T : ICloneable + { + return (T)obj.Clone(); + } public static void append(this ICollection list, T obj) { list.Add(obj); } + public static void remove(this ICollection list, T obj) + { + list.Remove(obj); + } + public static void extend(this ICollection list, params T[] objs) + { + foreach (var obj in objs) { + list.Add(obj); + } + } + public static int count(this ICollection list, T obj) + { + return list.Where(q => q.Equals(obj)).Count(); + } + public static int index(this ICollection list, T obj) + { + var index = -1; + foreach (var item in list) { + index++; + if (item.Equals(obj)) { + return index; + } + } + return -1; + } + public static void reverse(this ICollection list) + { + list = list.Reverse().ToList(); + } + public static void insert(this IList list, int index, T obj) + { + list.Insert(index, obj); + } + public static T pop(this IList list) + { + var last = list[list.Count - 1]; + list.RemoveAt(list.Count - 1); + return last; + } + public static ICollection copy(this ICollection list) + { + var newObj = new List(); + newObj.AddRange(list); + return newObj; + } + public static List copy(this List list) + { + var newObj = new List(); + newObj.AddRange(list); + return newObj; + } + public static ICollection keys(this IDictionary dict) { return dict.Keys; } + public static ICollection values(this IDictionary dict) + { + return dict.Values; + } + public static void clear(this IDictionary dict) + { + dict.Clear(); + } + public static T2 get(this IDictionary dict, T1 key) + { + if (dict.TryGetValue(key, out T2 result)) { + return result; + } + return default(T2); + } + public static T2 get(this IDictionary dict, T1 key, T2 def) + { + if (dict.TryGetValue(key, out T2 result)) { + return result; + } + return def; + } + public static bool has_key(this IDictionary dict, T1 key) + { + return (dict.ContainsKey(key)); + } + public static T2 pop(this IDictionary dict, T1 key) + { + if (dict.TryGetValue(key, out T2 result)) { + dict.Remove(key); + return result; + } + return default(T2); + } + public static T2 pop(this IDictionary dict, T1 key, T2 def) + { + if (dict.TryGetValue(key, out T2 result)) { + dict.Remove(key); + return result; + } + return def; + } + public static (T1, T2) popitem(this IDictionary dict) + { + T1 key = default(T1); + T2 val = default(T2); + foreach (var item in dict) { + key = item.Key; + val = item.Value; + } + if (dict.ContainsKey(key)) { + dict.Remove(key); + } + return (key, val); + } + + public static IDictionary copy(this IDictionary dict) + { + Dictionary copy = new Dictionary(); + foreach (var item in dict) { + copy[item.Key] = item.Value; + } + return copy; + } + public static Dictionary copy(this Dictionary dict) + { + Dictionary copy = new Dictionary(); + foreach (var item in dict) { + copy[item.Key] = item.Value; + } + return copy; + } + + + /// /// Simplify code, similar to python syntax /// python code : B, L = queries.shape @@ -304,6 +440,7 @@ public static string abspath(string path) { return Path.GetDirectoryName(path); } + public static long getsize(string path) { return new FileInfo(path).Length; @@ -334,6 +471,177 @@ public static void sleep(int s) } } + public class PythonFile + { + private System.IO.FileStream fileStream; + private bool bin; + public static PythonFile open(string file, string mode = "+", string encoding = "UTF-8") + { + PythonFile result = new PythonFile(); + + if (mode.Contains("+")) + { + result.fileStream = File.Open(file, FileMode.OpenOrCreate, FileAccess.ReadWrite); + if (mode.Contains("a")) + { + result.fileStream.Seek(0, SeekOrigin.End); + } + } else if (mode.Contains("a")) + { + result.fileStream = File.Open(file, FileMode.OpenOrCreate, FileAccess.Write); + result.fileStream.Seek(0, SeekOrigin.End); + } else if (mode.Contains("w")) + { + result.fileStream = File.Open(file, FileMode.OpenOrCreate, FileAccess.Write); + } else + { + result.fileStream = File.Open(file, FileMode.OpenOrCreate, FileAccess.Read); + } + result.bin = mode.Contains("b"); + return result; + } + public string[] readline(int size = 1) + { + var read = new System.IO.StreamReader(fileStream); + string[] result = new string[size]; + for (int i = 0; i < size; i++) + { + result[i] = read.ReadLine(); + } + read.ReadToEnd(); + return result; + } + public string readline() + { + var read = new System.IO.StreamReader(fileStream); + string result = read.ReadLine(); + read.ReadToEnd(); + return result; + } + public string read() + { + var read = new System.IO.StreamReader(fileStream); + var r = read.Read(); + read.ReadToEnd(); + return ((char) r).ToString(); + } + + public string read(int size = 1) + { + if (size <= 0) + { + var read = new System.IO.StreamReader(fileStream); + var r = read.ReadToEnd(); + read.ReadToEnd(); + return r; + } else + { + var read = new System.IO.StreamReader(fileStream); + StringBuilder stringBuilder = new StringBuilder(); + for (int i = 0; i < size; i++) + { + var r = read.Read(); + stringBuilder.Append((char) r); + } + read.ReadToEnd(); + return stringBuilder.ToString(); + } + } + + public void write(string txt) + { + var write = new System.IO.StreamWriter(fileStream); + write.Write(txt); + write.Close(); + } + + public void write(double num) + { + if (bin) + { + var write = new System.IO.BinaryWriter(fileStream); + write.Write(num); + write.Close(); + } else + { + var write = new System.IO.StreamWriter(fileStream); + write.Write(num.ToString()); + write.Close(); + } + } + public void write(float num) + { + if (bin) + { + var write = new System.IO.BinaryWriter(fileStream); + write.Write(num); + write.Close(); + } else + { + var write = new System.IO.StreamWriter(fileStream); + write.Write(num.ToString()); + write.Close(); + } + } + public void write(int num) + { + if (bin) + { + var write = new System.IO.BinaryWriter(fileStream); + write.Write(num); + write.Close(); + } else + { + var write = new System.IO.StreamWriter(fileStream); + write.Write(num.ToString()); + write.Close(); + } + } + public void write(long num) + { + if (bin) + { + var write = new System.IO.BinaryWriter(fileStream); + write.Write(num); + write.Close(); + } else + { + var write = new System.IO.StreamWriter(fileStream); + write.Write(num.ToString()); + write.Close(); + } + } + + public void seek(int offset, int whence = 0) + { + if (whence == 0) + { + fileStream.Seek(offset, SeekOrigin.Begin); + } else if (whence == 1) + { + fileStream.Seek(offset, SeekOrigin.Current); + } else if (whence == 2) + { + fileStream.Seek(offset, SeekOrigin.End); + } else + { + throw new Exception("whence is error."); + } + } + + public long tell() + { + return fileStream.Position; + } + + public void close() + { + fileStream.Close(); + } + } } +#pragma warning restore CS8632 // 只能在 "#nullable" 注释上下文内的代码中使用可为 null 的引用类型的注释。 +#pragma warning restore CS8981 // 该类型名称仅包含小写 ascii 字符。此类名称可能会成为该语言的保留值。 +#pragma warning restore IDE1006 // 命名样式 \ No newline at end of file diff --git a/src/Extensions/TorchCs/TorchCs.csproj b/src/Extensions/TorchCs/TorchCs.csproj index 27701f2..407feb8 100644 --- a/src/Extensions/TorchCs/TorchCs.csproj +++ b/src/Extensions/TorchCs/TorchCs.csproj @@ -1,9 +1,10 @@  + Exe net6.0 enable - enable + disable @@ -11,7 +12,7 @@ - + diff --git a/src/Extensions/TorchCs/TorchSharpInfo.cs b/src/Extensions/TorchCs/TorchSharpInfo.cs new file mode 100644 index 0000000..eed8e06 --- /dev/null +++ b/src/Extensions/TorchCs/TorchSharpInfo.cs @@ -0,0 +1,207 @@ +#region License +// Copyright 2023 ToolGood +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#endregion +using System.Reflection; +using System.Text.RegularExpressions; +using TorchSharp; + +namespace TorchCs +{ + public class TorchSharpInfo + { + private Dictionary dict=new Dictionary() { + {"Int64","long" }, + {"Int32","int" }, + {"String","string" }, + {"Single","float" }, + {"Double","double" }, + }; + public Type nnType; + public MethodInfo[] nnMethods; + public List nnModelNames; + + public Type torchType; + public MethodInfo[] torchMethods; + + public Type TensorType; + + public MethodInfo[] TensorMethods; + public string TensorFieldRegex; + public string TensorMethodRegex; + + private TorchSharpMethodList nn_methods; + private TorchSharpMethodList torch_methods; + + public static TorchSharpInfo Instance = new TorchSharpInfo(); + + private TorchSharpInfo() + { + nnType = typeof(TorchSharp.torch.nn); + nnMethods = nnType.GetMethods(System.Reflection.BindingFlags.Static | System.Reflection.BindingFlags.Public); + + nnModelNames = new List(); + foreach (var method in nnMethods) { + if (method.Name == "ModuleDict" || method.Name == "ModuleList") { continue; } + nnModelNames.Add(method.ReturnType.Name); + } + nnModelNames = nnModelNames.Distinct().ToList(); + + TensorType = typeof(TorchSharp.torch.Tensor); + var fields = TensorType.GetFields(); + var properties = TensorType.GetProperties(); + HashSet fs = new HashSet(); + foreach (var fieldInfo in fields) { fs.Add(fieldInfo.Name); } + foreach (var fieldInfo in properties) { fs.Add(fieldInfo.Name); } + fs.Remove("device"); + TensorFieldRegex = "(" + string.Join("|", fs) + ")"; + TensorMethods = TensorType.GetMethods(BindingFlags.Public | BindingFlags.Instance); + fs.Clear(); + foreach (var fieldInfo in TensorMethods) { fs.Add(fieldInfo.Name); } + TensorMethodRegex = "(" + string.Join("|", fs) + ")"; + + var torchType = typeof(TorchSharp.torch); + torchMethods = torchType.GetMethods(System.Reflection.BindingFlags.Static | System.Reflection.BindingFlags.Public); + + nn_methods = new TorchSharpMethodList(nnMethods); + torch_methods = new TorchSharpMethodList(torchMethods); + } + + public string FindTypeBy_nn(string code, string text) + { + var names = nn_methods.Select(q => q.MethodName).Distinct().ToList(); + string reg = $@"\bnn\.({string.Join("|", names)})\((((?
\()|(?<-BR>\))|[^()])+)\)"; + var ms = Regex.Matches(code, reg); + foreach (Match m in ms) { + if (m.Value.Contains(text) == false) { continue; } + var p = m.Groups[2].Value; + var type = FindTypeBy_nn(p, text); + if (type != null) { return type; } + var ps = TorchUtil.splitParamenters(p.Trim()); + if (ps.Contains(text) == false) { continue; } + + var methodName = m.Groups[1].Value; + var methods = nn_methods.Where(q => q.MethodName == methodName).ToList(); + foreach (var method in methods) { + if (method.Check(ps)) { + var index = ps.IndexOf(text); + var pi = method.Paramenters[index]; + if (pi.IsGenericType == false) { + if (dict.TryGetValue(pi.TypeName,out string t)) { + return t; + } + return pi.TypeName; + } + } + } + } + return null; + } + public string FindTypeBy_torch(string code, string text) + { + var names = torch_methods.Select(q => q.MethodName).Distinct().ToList(); + string reg = $@"\btorch\.({string.Join("|", names)})\((((?
\()|(?<-BR>\))|[^()])+)\)"; + var ms = Regex.Matches(code, reg); + foreach (Match m in ms) { + if (m.Value.Contains(text) == false) { continue; } + var p = m.Groups[2].Value; + var type = FindTypeBy_nn(p, text); + if (type != null) { return type; } + var ps = TorchUtil.splitParamenters(p.Trim()); + if (ps.Contains(text) == false) { continue; } + + var methodName = m.Groups[1].Value; + var methods = torch_methods.Where(q => q.MethodName == methodName).ToList(); + foreach (var method in methods) { + if (method.Check(ps)) { + var index = ps.IndexOf(text); + var pi = method.Paramenters[index]; + if (pi.IsGenericType == false) { + if (dict.TryGetValue(pi.TypeName, out string t)) { + return t; + } + return pi.TypeName; + } + } + } + } + return null; + } + + + + } + + public class TorchSharpMethodList : List + { + public TorchSharpMethodList(MethodInfo[] methods) + { + foreach (var method in methods) { + Add(new TorchSharpMethod(method)); + } + } + } + + public class TorchSharpMethod + { + public string MethodName { get; set; } + public string ReturnType { get; set; } + public List Paramenters { get; set; } + + public TorchSharpMethod() { } + public TorchSharpMethod(MethodInfo methodInfo) + { + MethodName = methodInfo.Name; + ReturnType = methodInfo.ReturnType.Name; + Paramenters = new List(); + var ps = methodInfo.GetParameters(); + for (int i = 0; i < ps.Length; i++) { + Paramenters.Add(new MethodParamenter(i, ps[i])); + } + } + public bool Check(List ps) + { + if (Paramenters.Count < ps.Count) { return false; } + foreach (var p in ps) { + if (p.Contains(":")) { + var name = p.Substring(0, p.IndexOf(':')); + if (Paramenters.Any(q => q.Name == name) == false) { + return false; + } + } + } + return true; + } + } + + public class MethodParamenter + { + public int Index { get; set; } + public string Name { get; set; } + public string TypeName { get; set; } + public bool IsGenericType { get; set; } + public bool IsOptional { get; set; } + + public MethodParamenter() { } + public MethodParamenter(int index, ParameterInfo parameter) + { + Index = index; + Name = parameter.Name; + TypeName = parameter.ParameterType.Name; + IsOptional = parameter.IsOptional; + IsGenericType = parameter.ParameterType.IsGenericType; + } + } + +} diff --git a/src/Extensions/TorchCs/TorchUtil.cs b/src/Extensions/TorchCs/TorchUtil.cs index 53711de..b74683b 100644 --- a/src/Extensions/TorchCs/TorchUtil.cs +++ b/src/Extensions/TorchCs/TorchUtil.cs @@ -19,6 +19,7 @@ using System.Text.RegularExpressions; using System.Xml.Linq; using TorchSharp; +using static TorchSharp.torch; namespace TorchCs { @@ -30,59 +31,102 @@ public class TorchUtil /// Convert all *.py.cs files in the folder ,Replace grammar rules ///
/// - public static void ReplaceFolder(string folder) + public static void ReplaceFolder(string folder, bool replaceStringToNetstandard = true) { var files = Directory.GetFiles(folder, "*.py.cs", SearchOption.AllDirectories); + HashSet classNames = new HashSet(); foreach (var file in files) { var text = File.ReadAllText(file); - File.WriteAllText(file, ReplaceCodes(text)); + getClassName(text, classNames); } + classNames.Remove("torch"); + classNames.Remove("nn"); + classNames.Remove("F"); + foreach (var file in files) { + var text = File.ReadAllText(file); + File.WriteAllText(file, ReplaceCodes(text, classNames, replaceStringToNetstandard)); + } + + var fileInfos = ClassFile.LoadFiles(folder); + var classInfos = new List(); + foreach (var file in fileInfos) { classInfos.AddRange(file.ClassInfos); } + bool IsChange; + do { + IsChange = false; + foreach (var fileInfo in fileInfos) { + fileInfo.LastChange = fileInfo.HasChange; + fileInfo.HasChange = false; + } + foreach (var fileInfo in fileInfos) { + var dict = fileInfo.MatchClassInfo(fileInfo.Code, classInfos); + foreach (var classInfo in fileInfo.ClassInfos) { + fileInfo.Code = classInfo.ReplaceMethodParamenterType(fileInfo.Code, dict); + } + if (fileInfo.HasChange) { + File.WriteAllText(fileInfo.FileName, fileInfo.Code); + IsChange = true; + } + } + } while (IsChange); } /// /// Convert file, Replace grammar rules /// /// - public static void ReplaceFile(string file) + public static void ReplaceFile(string file, bool replaceStringToNetstandard = false) { var text = File.ReadAllText(file); - File.WriteAllText(file, ReplaceCodes(text)); + File.WriteAllText(file, ReplaceCodes(text, null, replaceStringToNetstandard)); } /// /// Convert code, Replace grammar rules /// /// /// - public static string ReplaceCodes(string text) + public static string ReplaceCodes(string text, HashSet classNames = null, bool replaceToNetstandard = true) { // replace 'self' to 'this' text = Regex.Replace(text, @"\bself\.", "this."); // replace field type - text = Regex.Replace(text, @"(object|void) (\w+ = ""\S+?""[,;)])", "string $2"); - text = Regex.Replace(text, @"(object|void) (\w+ = \d+[,;)])", "int $2"); - text = Regex.Replace(text, @"(object|void) (\w+ = \d+\.\d+[,;)])", "double $2"); - text = Regex.Replace(text, @"(object|void) (\w+ = (true|false)[,;)])", "bool $2"); + text = Regex.Replace(text, @"(object|void|bool|int|double|string) (\w+ = ""\S+?""[,;)])", "string $2"); + text = Regex.Replace(text, @"(object|void|bool|int|double|string) (\w+ = \d+[,;)])", "int $2"); + text = Regex.Replace(text, @"(object|void|bool|int|double|string) (\w+ = (\d+\.\d+|\d+(\.\d+)?[Ee]-?\d+)[,;)])", "double $2"); + text = Regex.Replace(text, @"(object|void|bool|int|double|string) (\w+ = (true|false)[,;)])", "bool $2"); + text = Regex.Replace(text, @"\bvoid ([a-zA-Z_][a-zA-Z0-9_]*[ ,);])", "object $1"); // replace 'd_keys = d_keys or (d_model//n_heads)' to 'd_keys = d_keys ?? d_model / n_heads;' text = Regex.Replace(text, @"([a-zA-Z_0-9]+) = (\1 \|\| (.*?;))", "$1 = $1 ?? $3 //$2"); - + // replace throw new ValueError + text = text.Replace("throw new ValueError(", "throw new ArgumentException("); text = replaceNamespace(text); text = replaceConstructor(text); - text = replaceFieldType(text); + text = replaceListSlice(text); + text = replaceNewClass(text, classNames); text = replaceMethodParameterName(text); + // Replace type by class area and static method area + var classInfos = ClassInfo.AnalysisCode(text); + foreach (var classInfo in classInfos) { + text = classInfo.AddNewField(text); // Add missing fields + text = classInfo.ReplaceCodes(text); + } + // One file is a static class. There are only static methods in the static class, so I will deal with the static methods in the file. + var sss = ClassMethod.AnalysisCodeForStaticMethod(text); + foreach (var item in sss) { + text = item.ReplaceCodes(text); + } + + text = replaceFieldType(text); text = replaceMethodParamenterType(text); text = replaceMathMethod(text); text = replaceStringToEnum(text); text = replaceMethodAlias(text); - text = replaceForwardMethod(text); - text = replaceCallForwardMethod(text); - - text = replaceListSlice(text); - text = replaceTensorList(text); text = replaceIsType(text); - text = replaceStringToNetstandard(text); + if (replaceToNetstandard) { + text = replaceStringToNetstandard(text); + } text = text.Replace("using (var torch.no_grad())", "using (var _no_grad= torch.no_grad())"); text = text.Replace("using (var torch.cuda.amp.autocast())", "using (var _autocast= torch.cuda.amp.autocast())"); @@ -127,6 +171,7 @@ private static string replaceNamespace(string text) text = text.Replace("using optim = torch.optim;", "using optim = TorchSharp.torch.optim;"); text = text.Replace("using DataLoader = torch.utils.data.DataLoader;", "using DataLoader = TorchSharp.torch.utils.data.DataLoader;"); + text = text.Replace("using sys;", ""); text = text.Replace("using math;", ""); text = text.Replace("using os;", ""); text = text.Replace("using time;", ""); @@ -170,15 +215,21 @@ private static string replaceFieldType(string text) } var r = $@"this\.(\S+) = nn\.{methodName}\("; var ms = Regex.Matches(text, r); - if (ms.Count > 0) { - foreach (Match m in ms) { - var name = m.Groups[1].Value; - text = text.Replace($"public object {name};", $"public {fieldType} {name};"); - text = text.Replace($"public void {name};", $"public {fieldType} {name};"); - text = Regex.Replace(text, @$"\bthis\.{name}\(", $"this.{name}.forward("); - } + foreach (Match m in ms) { + var name = m.Groups[1].Value; + text = text.Replace($"public object {name};", $"public {fieldType} {name};"); + text = text.Replace($"public void {name};", $"public {fieldType} {name};"); + text = Regex.Replace(text, @$"\bthis\.{name}\(", $"this.{name}.forward("); } } + var ms2 = Regex.Matches(text, @"this\.(\S+) = new ([a-zA-Z_][a-zA-Z0-9_]+)\("); + foreach (Match m2 in ms2) { + var name = m2.Groups[1].Value; + var typeName = m2.Groups[2].Value; + text = text.Replace($"public object {name};", $"public {typeName} {name};"); + text = text.Replace($"public void {name};", $"public {typeName} {name};"); + } + text = replaceFieldType3(text); text = Regex.Replace(text, @"public (object|void) (\w+_len;)", "public int $2"); @@ -204,7 +255,22 @@ private static string replaceFieldType3(string text) if (ms.Count > 0) { foreach (Match m in ms) { var name = m.Groups[2].Value; - if (text.Contains($"this.{name} = {name};")) { + if (text.Contains($"if (this.{name})") || text.Contains($"if (!this.{name})") || text.Contains($"if (this.{name} == true)") || text.Contains($"if (this.{name} == false)")) { + text = text.Replace($"public object {name};", $"public bool {name};"); + text = text.Replace($"public void {name};", $"public bool {name};"); + } else if (text.Contains($"this.{name} = false") || text.Contains($"this.{name} = true")) { + text = text.Replace($"public object {name};", $"public bool {name};"); + text = text.Replace($"public void {name};", $"public bool {name};"); + } else if (Regex.IsMatch(text, $@"this\.{name} (=|==|!=|\+=) """) || Regex.IsMatch(text, $@"this\.{name}\.(startswith|endswith|upper|lower|replace|strip|lstrip|rstrip)\(")) { + text = text.Replace($"public object {name};", $"public string {name};"); + text = text.Replace($"public void {name};", $"public string {name};"); + } else if (Regex.IsMatch(text, $@"this\.{name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) \d+\.\d+")) { + text = text.Replace($"public object {name};", $"public doulbe {name};"); + text = text.Replace($"public void {name};", $"public doulbe {name};"); + } else if (Regex.IsMatch(text, $@"this\.{name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) \d+")) { + text = text.Replace($"public object {name};", $"public int {name};"); + text = text.Replace($"public void {name};", $"public int {name};"); + } else if (text.Contains($"this.{name} = {name};")) { if (Regex.IsMatch(text, @$"int {name}\b")) { text = text.Replace($"public object {name};", $"public int {name};"); text = text.Replace($"public void {name};", $"public int {name};"); @@ -221,21 +287,6 @@ private static string replaceFieldType3(string text) text = text.Replace($"public object {name};", $"public bool {name};"); text = text.Replace($"public void {name};", $"public bool {name};"); } - } else if (text.Contains($"if (this.{name})") || text.Contains($"if (!this.{name})") || text.Contains($"if (this.{name} == true)") || text.Contains($"if (this.{name} == false)")) { - text = text.Replace($"public object {name};", $"public bool {name};"); - text = text.Replace($"public void {name};", $"public bool {name};"); - } else if (text.Contains($"this.{name} = false") || text.Contains($"this.{name} = true")) { - text = text.Replace($"public object {name};", $"public bool {name};"); - text = text.Replace($"public void {name};", $"public bool {name};"); - } else if (Regex.IsMatch(text, $@"this\.{name} (=|==|!=|\+=) """) || Regex.IsMatch(text, $@"this\.{name}\.(startswith|endswith|upper|lower|replace|strip|lstrip|rstrip)\(")) { - text = text.Replace($"public object {name};", $"public string {name};"); - text = text.Replace($"public void {name};", $"public string {name};"); - } else if (Regex.IsMatch(text, $@"this\.{name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) \d+\.\d+")) { - text = text.Replace($"public object {name};", $"public doulbe {name};"); - text = text.Replace($"public void {name};", $"public doulbe {name};"); - } else if (Regex.IsMatch(text, $@"this\.{name} (=|==|!=|>|<|>=|<=|\+=|\-=|\*=|/=|%=) \d+")) { - text = text.Replace($"public object {name};", $"public int {name};"); - text = text.Replace($"public void {name};", $"public int {name};"); } } @@ -436,89 +487,6 @@ private static string replaceMathMethod(string text) text = Regex.Replace(text, @"\bmath\.inf\b", "double.PositiveInfinity"); return text; } - /// - /// Replace forward method's return type and forward method's parameter type - /// - /// - /// - private static string replaceForwardMethod(string text) - { - text = text.Replace(" Tuple", " (Tensor, Tensor)"); - text = text.Replace(" Tuple forward(", " (Tensor, Tensor) forward("); - text = text.Replace(" object[] forward(", " (Tensor, Tensor) forward("); - text = text.Replace(" Tuple> forward(", " (Tensor, List) forward("); - text = text.Replace(" object forward(", " Tensor forward("); - text = text.Replace(" void forward(", " Tensor forward("); - text = text.Replace(" forward(object x", " forward(Tensor x"); - text = text.Replace(" forward(object t", " forward(Tensor t"); - text = text.Replace(" forward(object queries, object keys, object values", " forward(Tensor queries, Tensor keys, Tensor values"); - return text; - } - /// - /// Replace common forward method calls - /// - /// - /// - private static string replaceCallForwardMethod(string text) - { - text = Regex.Replace(text, @"\bthis\.inner_attention\(", "this.inner_attention.forward("); - text = Regex.Replace(text, @"\bthis\.dropout\(", "this.dropout.forward("); - text = Regex.Replace(text, @"\bthis\.attention\(", "this.attention.forward("); - text = Regex.Replace(text, @"\bthis\.self_attention\(", "this.self_attention.forward("); - text = Regex.Replace(text, @"\bthis\.cross_attention\(", "this.cross_attention.forward("); - text = Regex.Replace(text, @"\bthis\.projection\(", "this.projection.forward("); - text = Regex.Replace(text, @"\bthis\.activation\(", "this.activation.forward("); - text = Regex.Replace(text, @"\bthis\.norm\(", "this.norm.forward("); - text = Regex.Replace(text, @"\bthis\.conv\(", "this.conv.forward("); - text = Regex.Replace(text, @"\bthis\.decomp\(", "this.decomp.forward("); - text = Regex.Replace(text, @"\bthis\.decomp1\(", "this.decomp1.forward("); - text = Regex.Replace(text, @"\bthis\.decomp2\(", "this.decomp2.forward("); - text = Regex.Replace(text, @"\bthis\.decomp3\(", "this.decomp3.forward("); - text = Regex.Replace(text, @"\bthis\.decomp4\(", "this.decomp4.forward("); - text = Regex.Replace(text, @"\bthis\.decomp5\(", "this.decomp5.forward("); - text = Regex.Replace(text, @"\bthis\.conv1\(", "this.conv1.forward("); - text = Regex.Replace(text, @"\bthis\.conv2\(", "this.conv2.forward("); - text = Regex.Replace(text, @"\bthis\.conv3\(", "this.conv3.forward("); - text = Regex.Replace(text, @"\bthis\.conv4\(", "this.conv4.forward("); - text = Regex.Replace(text, @"\bthis\.conv5\(", "this.conv5.forward("); - text = Regex.Replace(text, @"\bthis\.norm1\(", "this.norm1.forward("); - text = Regex.Replace(text, @"\bthis\.norm2\(", "this.norm2.forward("); - text = Regex.Replace(text, @"\bthis\.norm3\(", "this.norm3.forward("); - text = Regex.Replace(text, @"\bthis\.norm4\(", "this.norm4.forward("); - text = Regex.Replace(text, @"\bthis\.norm5\(", "this.norm5.forward("); - - text = Regex.Replace(text, @"\bthis\.downConv\(", "this.downConv.forward("); - text = Regex.Replace(text, @"\bthis\.maxPool\(", "this.maxPool.forward("); - text = Regex.Replace(text, @"\bthis\.avg\(", "this.avg.forward("); - text = Regex.Replace(text, @"\bthis\.layernorm\(", "this.layernorm.forward("); - text = Regex.Replace(text, @"\bthis\.tokenConv\(", "this.tokenConv.forward("); - - text = Regex.Replace(text, @"\bthis\.embedding\(", "this.embedding.forward("); - text = Regex.Replace(text, @"\bthis\.emb\(", "this.emb.forward("); - text = Regex.Replace(text, @"\bthis\.embed\(", "this.embed.forward("); - text = Regex.Replace(text, @"\bthis\.position_embedding\(", "this.position_embedding.forward("); - text = Regex.Replace(text, @"\bthis\.temporal_embedding\(", "this.temporal_embedding.forward("); - text = Regex.Replace(text, @"\bthis\.value_embedding\(", "this.value_embedding.forward("); - - text = Regex.Replace(text, @"\bthis\.month_embed\(", "this.month_embed.forward("); - text = Regex.Replace(text, @"\bthis\.day_embed\(", "this.day_embed.forward("); - text = Regex.Replace(text, @"\bthis\.hour_embed\(", "this.hour_embed.forward("); - text = Regex.Replace(text, @"\bthis\.minute_embed\(", "this.minute_embed.forward("); - text = Regex.Replace(text, @"\bthis\.weekday_embed\(", "this.weekday_embed.forward("); - - text = Regex.Replace(text, @"\bthis\.enc_embedding\(", "this.enc_embedding.forward("); - text = Regex.Replace(text, @"\bthis\.encoder\(", "this.encoder.forward("); - text = Regex.Replace(text, @"\bthis\.dec_embedding\(", "this.dec_embedding.forward("); - text = Regex.Replace(text, @"\bthis\.decoder\(", "this.decoder.forward("); - - text = Regex.Replace(text, @"\bthis\.query_projection\(", "this.query_projection.forward("); - text = Regex.Replace(text, @"\bthis\.key_projection\(", "this.key_projection.forward("); - text = Regex.Replace(text, @"\bthis\.value_projection\(", "this.value_projection.forward("); - text = Regex.Replace(text, @"\bthis\.out_projection\(", "this.out_projection.forward("); - - text = Regex.Replace(text, @"\bthis\.attn\(", "this.attn.forward("); - return text; - } /// /// Replace common Tensor list @@ -545,16 +513,19 @@ private static string replaceTensorList(string text) /// private static string replaceListSlice(string text) { - text = Regex.Replace(text, @"\[([^\[\]]*?)\]", new MatchEvaluator(m => { + text = Regex.Replace(text, @"\[(((?
\[)|(?<-BR>\])|[^\[\]])+)\]", new MatchEvaluator(m => { if (m.Groups[1].Value.Contains(":") == false) { return m.Value; } - var strs = m.Groups[1].Value.Split(','); + var ts = replaceListSlice(m.Groups[1].Value); // recurrence , exclude nesting + var strs = ts.Split(','); List list = new List(); foreach (var str in strs) { if (str.Trim() == "\":\"") { - list.Add("TensorIndex.Ellipsis"); - } else if (str.Trim() == "") { + list.Add("TensorIndex.Colon"); + } else if (str.Trim() == "") { // python code: torch.arange(B)[:, None, ] == torch.arange(B)[:, None] + //list.Add(""); + } else if ( str.Trim() == "null") { list.Add("TensorIndex.Null"); } else if (str.Contains(":")) { var ss = str.Trim().Split(':'); @@ -652,6 +623,136 @@ private static string replaceStringToNetstandard(string text) return text; } + /// + /// Add 'new' word to class initialization + /// + /// + /// + /// + private static string replaceNewClass(string text, HashSet classNames) + { + if (classNames == null) { return text; } + const string classRegex = @"using ([a-zA-Z_@][a-zA-Z0-9_]*) = ([a-zA-Z_@][a-zA-Z0-9_.@]*);"; + + List names = new List(); + var ms = Regex.Matches(text, classRegex); + foreach (Match m in ms) { + if (classNames.Contains(m.Groups[1].Value)) { + names.Add(m.Groups[1].Value); + } + } + if (names.Count == 0) { return text; } + + var namereg = string.Join("|", names); + text = Regex.Replace(text, $@"\b({namereg})\(", "new $1("); + text = Regex.Replace(text, @"\bnew new ", "new "); + return text; + } + /// + /// Get all type names, excluding static classes + /// + /// + /// + private static void getClassName(string text, HashSet classNames) + { + const string classRegex = @"public class ([a-zA-Z_][a-zA-Z0-9_]*)"; + var ms = Regex.Matches(text, classRegex); + foreach (Match m in ms) { + classNames.Add(m.Groups[1].Value); + } + } + /// + /// Split parameter, applicable to method definition and method call + /// + /// + /// + internal static List splitParamenters(string paramenters) + { + bool inText = false; + int bracketLayer = 0; // + + List result = new List(); + var index = 0; + string temp = ""; + while (index < paramenters.Length) { + var c = paramenters[index]; + if (inText) { + temp += c; + if (c == '\\') { + index++; + temp += paramenters[index]; + } else if (c == '"') { + inText = false; + } + } else if (c == '(' || c == '{' || c == '[' || c == '<') { + bracketLayer++; + temp += c; + } else if (c == ')' || c == '}' || c == ']' || c == '>') { + bracketLayer--; + temp += c; + } else if (c == ',' && bracketLayer == 0) { + result.Add(temp.Trim()); + temp = ""; + } else { + temp += c; + } + index++; + } + result.Add(temp.Trim()); + return result; + } + + /// + /// Judge whether it is a Double type according to the parameter name + /// + /// + /// + internal static bool isDoubleTypeByName(string name) + { + if (Regex.IsMatch(name, "^(dropout|lr|lr_step|factor|lr_max|num)$", RegexOptions.IgnoreCase)) { + return true; + } + if (Regex.IsMatch(name, "^.*(_dropout|_factor|_momentum|_lr|_min|_max)$", RegexOptions.IgnoreCase)) { + return true; + } + return false; + } + /// + /// Judge whether it is a Int type according to the parameter name + /// + /// + /// + internal static bool isIntTypeByName(string name) + { + if (Regex.IsMatch(name, "^(channels|index|length|step|epoch|stride|total_steps|d_k|d_v|d_q)$", RegexOptions.IgnoreCase)) { + return true; + } + if (Regex.IsMatch(name, "^.*(_len|_length|_in|_model|_out|_channels|_size|_dims|_count|_index|_epoch|_num|_side)$", RegexOptions.IgnoreCase)) { + return true; + } + if (Regex.IsMatch(name, "^(num_|n_).*$", RegexOptions.IgnoreCase)) { + return true; + } + if (Regex.IsMatch(name, "^.*(_num_|_len_).*$", RegexOptions.IgnoreCase)) { + return true; + } + return false; + } + /// + /// Judge whether it is a String type according to the parameter name + /// + /// + /// + internal static bool isStringTypeByName(string name) + { + if (Regex.IsMatch(name, "^(name|path|dir|file|device)$", RegexOptions.IgnoreCase)) { + return true; + } + if (Regex.IsMatch(name, "^.*(_path|_name|_dir|file|_str|_txt)$", RegexOptions.IgnoreCase)) { + return true; + } + return false; + } }