forked from imsatoshi/GeneTrader
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_results.py
79 lines (65 loc) · 3.04 KB
/
plot_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import matplotlib.pyplot as plt
import re
from collections import defaultdict
import csv
# Read the whole fitness log into memory.
with open('fitness_log.txt', 'r') as file:
    data = file.read()

# Pattern for one decimal number: optional minus sign, digits with an optional
# fraction (or a bare fraction such as ".5"), and an optional exponent.
# A looser character class like "[-\d.]+" stops at the "e" of exponent
# notation, so a value logged as "1e-05" would be recorded as 1.0; it also
# captures trailing punctuation ("0.55.") or a lone "-", which makes float()
# raise ValueError and abort the script.
_NUMBER = r'-?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'

# Compiled once, outside the per-line loop.
_FITNESS_RE = re.compile(r'Generation: (\d+).+Final Fitness: (' + _NUMBER + r')')
_PROFIT_RE = re.compile(r'Total Profit %: (' + _NUMBER + r')')
_WIN_RATE_RE = re.compile(r'Win Rate: (' + _NUMBER + r')')

# Parse the log into: generation number -> every fitness / profit / win-rate
# value found on that generation's lines.
generation_data = defaultdict(lambda: {'fitnesses': [], 'profits': [], 'win_rates': []})
for line in data.split('\n'):
    match = _FITNESS_RE.search(line)
    if not match:
        # Only lines carrying both a generation number and a final fitness
        # contribute data; everything else is skipped.
        continue

    gen = int(match.group(1))
    fit = float(match.group(2))
    generation_data[gen]['fitnesses'].append(fit)

    # Profit and win rate are optional on a line; record them when present.
    profit_match = _PROFIT_RE.search(line)
    if profit_match:
        profit = float(profit_match.group(1))
        generation_data[gen]['profits'].append(profit)

    win_rate_match = _WIN_RATE_RE.search(line)
    if win_rate_match:
        win_rate = float(win_rate_match.group(1))
        generation_data[gen]['win_rates'].append(win_rate)
# Extract the per-generation maxima, in ascending generation order.
# A generation's 'profits' or 'win_rates' list can be empty when those fields
# were absent from its log lines; max() raises ValueError on an empty
# sequence, so fall back to NaN (written to the CSV as "nan") instead of
# aborting the whole script.
generations = sorted(generation_data.keys())
_missing = float('nan')
max_fitnesses = [max(generation_data[gen]['fitnesses'], default=_missing) for gen in generations]
max_profits = [max(generation_data[gen]['profits'], default=_missing) for gen in generations]
max_win_rates = [max(generation_data[gen]['win_rates'], default=_missing) for gen in generations]
# Save the per-generation maxima to a CSV file: one header row, then one row
# per generation.
csv_header = ['Generation', 'Max Fitness', 'Max Profit (%)', 'Max Win Rate']
csv_rows = zip(generations, max_fitnesses, max_profits, max_win_rates)
with open('genetic_algorithm_results.csv', 'w', newline='') as csvfile:
    csvwriter = csv.writer(csvfile)
    csvwriter.writerow(csv_header)
    csvwriter.writerows(csv_rows)
# Create three vertically stacked subplots, one per metric.
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 18))

# One entry per subplot:
# (axes, key in generation_data, per-generation maxima, legend label,
#  y-axis label, title)
panels = [
    (ax1, 'fitnesses', max_fitnesses, 'Max Fitness',
     'Fitness', 'Fitness vs Generation'),
    (ax2, 'profits', max_profits, 'Max Profit (%)',
     'Total Profit (%)', 'Total Profit (%) vs Generation'),
    (ax3, 'win_rates', max_win_rates, 'Max Win Rate',
     'Win Rate', 'Win Rate vs Generation'),
]

for axis, key, maxima, legend_label, y_label, title in panels:
    # Scatter every recorded value of this metric at its generation's x position.
    for gen in generations:
        values = generation_data[gen][key]
        axis.scatter([gen] * len(values), values, alpha=0.5)
    # Overlay a red line connecting the per-generation maxima.
    axis.plot(generations, maxima, color='red', linewidth=2, label=legend_label)
    axis.set_xlabel('Generation')
    axis.set_ylabel(y_label)
    axis.set_title(title)
    axis.legend()
# Let matplotlib adjust the spacing between the subplots.
fig.tight_layout()

# Save the figure to a local PNG file.
fig.savefig('genetic_algorithm_results.png', dpi=300, bbox_inches='tight')

# Report (in Chinese) that the CSV data and the image were saved.
for message in (
    "数据已保存到 genetic_algorithm_results.csv",
    "图像已保存到 genetic_algorithm_results.png",
):
    print(message)