-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprogramfragmentprocessor.cpp
More file actions
167 lines (134 loc) · 6.17 KB
/
programfragmentprocessor.cpp
File metadata and controls
167 lines (134 loc) · 6.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
#include "programfragmentprocessor.h"
// Default constructor; performs no initialisation work of its own.
ProgramFragmentProcessor::ProgramFragmentProcessor() = default;
// Dispatches a code-fragment request.
//
// Reads the string fields "language", "code" and "action" from `fragmentObj`
// (missing fields read as empty strings) and returns a JSON object with:
//   "status"           - "success", or "error" whenever "error" is present
//   "timestamp"        - current local time, QDateTime default text format
//   "validationResult" - for action "validate" (see validateCode)
//   "networkStructure" - for action "extract-structure" on PyTorch code
//   "error"            - human-readable message for unsupported input
QJsonObject ProgramFragmentProcessor::processFragment(const QJsonObject& fragmentObj) {
    const QString language = fragmentObj["language"].toString();
    const QString code = fragmentObj["code"].toString();
    const QString action = fragmentObj["action"].toString();

    QJsonObject result;
    result["timestamp"] = QDateTime::currentDateTime().toString();

    QString error;
    if (action == "validate") {
        result["validationResult"] = validateCode(code);
    } else if (action == "extract-structure") {
        // Accept both "import torch.nn as nn" and the common "from torch import nn".
        const bool isPython = language.compare(QLatin1String("python"), Qt::CaseInsensitive) == 0;
        const bool usesTorchNn = code.contains("torch.nn") || code.contains("from torch import nn");
        if (isPython && usesTorchNn) {
            result["networkStructure"] = extractPyTorchStructure(code);
        } else {
            error = "仅支持解析 PyTorch 代码";
        }
    } else {
        error = "不支持的操作: " + action;
    }

    // Keep "status" consistent with the presence of an error message; previously
    // a failed request still reported "success".
    if (error.isEmpty()) {
        result["status"] = "success";
    } else {
        result["status"] = "error";
        result["error"] = error;
    }
    return result;
}
// Extracts a flat layer list from PyTorch model code using regex heuristics.
//
// Recognised constructors (integer literal arguments only):
//   nn.Linear(in, out[, ...])                    -> {"layerType": "Dense", "inputSize", "neurons"}
//   nn.Conv2d(in, out, [kernel_size=]k[, ...])   -> {"layerType": "Conv2d", "inputSize", "neurons", "kernelSize"}
// Each entry also carries "activationFunction" (see extractActivationFunction).
// Layers are returned in the order they appear in `code`.
QJsonArray ProgramFragmentProcessor::extractPyTorchStructure(const QString& code) {
    // A single alternation pattern makes globalMatch() report layers in source
    // order. Separate passes per layer type would list every Linear layer
    // before every Conv2d layer.
    //   groups 1,2   -> Linear in/out features
    //   groups 3,4,5 -> Conv2d in/out channels and kernel size
    const QRegularExpression layerPattern(
        "nn\\.(?:"
        "Linear\\s*\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*[,)]"
        "|"
        "Conv2d\\s*\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*,\\s*(?:kernel_size\\s*=\\s*)?(\\d+)"
        ")");

    QJsonArray layersArray;
    QRegularExpressionMatchIterator it = layerPattern.globalMatch(code);
    while (it.hasNext()) {
        const QRegularExpressionMatch match = it.next();
        QJsonObject layerObj;
        if (match.capturedStart(1) >= 0) {
            layerObj["layerType"] = "Dense";
            layerObj["inputSize"] = match.captured(1).toInt();
            layerObj["neurons"] = match.captured(2).toInt();
        } else {
            layerObj["layerType"] = "Conv2d";
            layerObj["inputSize"] = match.captured(3).toInt();  // in_channels (simplified mapping)
            layerObj["neurons"] = match.captured(4).toInt();    // out_channels (simplified mapping)
            layerObj["kernelSize"] = match.captured(5).toInt();
        }
        // Activation is guessed from the text starting at this layer's definition.
        layerObj["activationFunction"] = extractActivationFunction(code, match.capturedStart());
        layersArray.append(layerObj);
    }
    return layersArray;
}
// Guesses the activation applied to a layer defined at `startPos` in `code`.
//
// Looks only at the 200 characters starting at `startPos` and returns the name
// ("relu", "sigmoid", "softmax", "tanh") of the activation whose token occurs
// EARLIEST in that window, i.e. closest to the layer. Returns "" if none is found.
// This is a textual heuristic; it does not follow the model's forward() data flow.
QString ProgramFragmentProcessor::extractActivationFunction(const QString& code, int startPos) {
    const QString snippet = code.mid(startPos, 200);

    struct ActivationToken {
        const char* token;
        const char* name;
    };
    // Table order only breaks ties, which distinct tokens cannot produce.
    static const ActivationToken kTokens[] = {
        {"nn.ReLU", "relu"},       {"F.relu", "relu"},        {"torch.relu", "relu"},
        {"nn.Sigmoid", "sigmoid"}, {"F.sigmoid", "sigmoid"},  {"torch.sigmoid", "sigmoid"},
        {"nn.Softmax", "softmax"}, {"F.softmax", "softmax"},  {"torch.softmax", "softmax"},
        {"nn.Tanh", "tanh"},       {"F.tanh", "tanh"},        {"torch.tanh", "tanh"},
    };

    QString best = "";
    auto bestPos = snippet.size();  // sentinel: nothing found yet
    for (const ActivationToken& entry : kTokens) {
        const auto pos = snippet.indexOf(QLatin1String(entry.token));
        if (pos >= 0 && pos < bestPos) {
            bestPos = pos;
            best = QLatin1String(entry.name);
        }
    }
    return best;
}
// 简单的代码验证
QJsonObject ProgramFragmentProcessor::validateCode(const QString& code) {
QJsonObject validationResult;
validationResult["valid"] = true;
QJsonArray errors;
// 基本语法检查
if (code.trimmed().isEmpty()) {
errors.append("代码不能为空");
validationResult["valid"] = false;
validationResult["errors"] = errors;
return validationResult;
}
// 检查Python缩进错误 (简单检查:每行缩进必须是4的倍数)
QStringList lines = code.split('\n');
QRegularExpression indentRegex("^(\\s*)");
for (int i = 0; i < lines.size(); i++) {
QString line = lines[i].trimmed();
if (line.isEmpty() || line.startsWith("#")) continue; // 跳过空行和注释
// 获取当前行的缩进空格数
QRegularExpressionMatch match = indentRegex.match(lines[i]);
int indent = match.captured(1).length();
// 检查缩进是否是4的倍数
if (indent % 4 != 0) {
errors.append(QString("第 %1 行:缩进必须是4个空格的倍数").arg(i + 1));
validationResult["valid"] = false;
}
// 简单的缩进层次检查
if (line.endsWith(":")) {
// 冒号行后应该增加缩进
if (i + 1 < lines.size()) {
QRegularExpressionMatch nextMatch = indentRegex.match(lines[i + 1]);
int nextIndent = nextMatch.captured(1).length();
if (nextIndent <= indent) {
errors.append(QString("第 %1 行:冒号后需要增加缩进").arg(i + 1));
validationResult["valid"] = false;
}
}
}
}
// 检查常见的PyTorch相关错误
if (code.contains("nn.") || code.contains("torch.")) {
// 检查块引用错误
if (!code.contains("import torch") && !code.contains("import torch.nn")) {
errors.append("使用PyTorch模块但未导入torch或torch.nn");
validationResult["valid"] = false;
}
// 检查常见的层定义错误
QRegularExpression layerRegex("self\\.(conv|fc|pool|lstm|gru|rnn)\\d+\\s*=\\s*nn\\.");
QRegularExpressionMatchIterator it = layerRegex.globalMatch(code);
while (it.hasNext()) {
QRegularExpressionMatch match = it.next();
QString layerDef = match.captured(0);
// 检查是否缺少括号
if (!layerDef.contains("(") || !layerDef.contains(")")) {
errors.append("层定义语法错误:缺少括号");
validationResult["valid"] = false;
}
}
}
validationResult["errors"] = errors;
return validationResult;
}