forked from ClausewitzCPU0/SC2AI
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathch9_Building_Neural_Network_Training_data.py
311 lines (269 loc) · 12.5 KB
/
ch9_Building_Neural_Network_Training_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
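"""
应用深度学习 创建训练数据
Applying deep learning: create neural-network training data by recording the
bot's attack decisions together with a rendered top-down image of the game.
"""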
import sc2
from sc2 import run_game, maps, Race, Difficulty, position, Result
from sc2.player import Bot, Computer
from sc2.constants import NEXUS, PROBE, PYLON, ASSIMILATOR, GATEWAY, \
CYBERNETICSCORE, STALKER, STARGATE, VOIDRAY, OBSERVER, ROBOTICSFACILITY
import random
import cv2
import numpy as np
import time
# Set HEADLESS = True to skip the OpenCV preview window shown by intel()
# (useful when training on a Linux server without a display). The intel
# image is still rendered and recorded either way; only the window is skipped.
HEADLESS = False


class SentdeBot(sc2.BotAI):
    """Protoss bot that plays a void-ray build and records training data.

    Every step it scouts with an observer, manages its economy and production,
    renders a top-down "intel" image of the game and, while it has idle void
    rays, randomly picks one of four attack options. Each pick is stored as
    ``[one_hot_choice, intel_image]`` in ``self.train_data``; the collected
    pairs are written to ``train_data/<unix time>.npy`` only when the game
    ends in a victory.
    """

    # Own unit type -> [circle radius, colour] used by intel(). Colours are
    # passed straight to OpenCV, which interprets the tuples as (B, G, R).
    _DRAW_STYLES = {
        NEXUS: [15, (0, 255, 0)],
        PYLON: [3, (20, 235, 0)],
        PROBE: [1, (55, 200, 0)],
        ASSIMILATOR: [2, (55, 200, 0)],
        GATEWAY: [3, (200, 100, 0)],
        CYBERNETICSCORE: [3, (150, 150, 0)],
        STARGATE: [5, (255, 0, 0)],
        VOIDRAY: [3, (255, 100, 0)],
    }
    # Lower-cased unit names treated as enemy main bases / enemy workers.
    _MAIN_BASE_NAMES = frozenset({"nexus", "commandcenter", "hatchery"})
    _WORKER_NAMES = frozenset({"probe", "scv", "drone"})

    def __init__(self):
        super().__init__()
        # Estimated on_step iterations per game minute (the original author's
        # figure, marked "to be confirmed"); used to pace expansion/stargates.
        self.ITERATIONS_PER_MINUTE = 165
        self.MAX_WORKERS = 80  # hard cap on the number of probes
        # attack() makes no new decision until the iteration exceeds this.
        self.do_something_after = 0
        # Collected [one-hot choice, intel image] pairs.
        self.train_data = []
        # Latest on_step iteration and latest intel image (set by intel()).
        self.iteration = 0
        self.flipped = None

    def on_end(self, game_result):
        """Save the collected training data if the game was won.

        The output is NumPy's binary ``.npy`` format (unreadable in a text
        editor). It holds an (N, 2) object array whose rows are
        ``[one_hot_choice, intel_image]``; because it contains objects it must
        be read back with ``np.load(path, allow_pickle=True)``.
        """
        import os

        print('--- on_end called ---')
        print(game_result)
        if game_result != Result.Victory:  # only winning games are kept
            return
        # The two columns hold arrays of different shapes. Calling np.array()
        # on such a ragged list raises ValueError on NumPy >= 1.24, so fill an
        # explicit object array instead.
        data = np.empty((len(self.train_data), 2), dtype=object)
        for row, (choice, image) in enumerate(self.train_data):
            data[row, 0] = choice
            data[row, 1] = image
        os.makedirs("train_data", exist_ok=True)
        np.save("train_data/{}.npy".format(str(int(time.time()))), data)

    async def on_step(self, iteration: int):
        """Run every per-step routine; ``iteration`` is the game-step counter."""
        self.iteration = iteration
        await self.scout()
        await self.distribute_workers()
        await self.build_workers()
        await self.build_pylons()
        await self.build_assimilators()
        await self.expand()
        await self.offensive_force_buildings()
        await self.build_offensive_force()
        # Render the intel image before attack() so each recorded decision is
        # paired with the image of the step in which it was made.
        await self.intel()
        await self.attack()

    def random_location_variance(self, enemy_start_location):
        """Return a random scouting point near the enemy start location.

        Each coordinate is shifted by a random -20%..+19% of its own value and
        then clamped to the map bounds.

        :param enemy_start_location: enemy start location (indexable as x, y)
        :return: the scouting point as a ``position.Point2``
        """
        map_width = self.game_info.map_size[0]
        map_height = self.game_info.map_size[1]
        x = enemy_start_location[0]
        y = enemy_start_location[1]
        x += (random.randrange(-20, 20) / 100) * enemy_start_location[0]
        y += (random.randrange(-20, 20) / 100) * enemy_start_location[1]
        x = min(max(x, 0), map_width)
        y = min(max(y, 0), map_height)
        # Commands are issued with a Point2 rather than a bare (x, y) tuple.
        return position.Point2(position.Pointlike((x, y)))

    async def scout(self):
        """Keep an observer scouting near the enemy start location.

        An idle observer is sent to a random point near the enemy start. If
        there is no observer, one is trained from each idle, ready robotics
        facility that can afford it.
        """
        observers = self.units(OBSERVER)
        if len(observers) > 0:
            scout = observers[0]
            if scout.is_idle:
                move_to = self.random_location_variance(self.enemy_start_locations[0])
                print(move_to)
                await self.do(scout.move(move_to))
        else:
            for rf in self.units(ROBOTICSFACILITY).ready.noqueue:
                if self.can_afford(OBSERVER) and self.supply_left > 0:
                    await self.do(rf.train(OBSERVER))

    async def intel(self):
        """Render a top-down image of the game state into ``self.flipped``.

        The image has one pixel per map cell. Own units, enemy structures,
        enemy units and observers are drawn as filled circles, and
        resource/supply bars are drawn in the corner. Unless HEADLESS is set,
        a 2x upscaled copy is shown in an OpenCV window. (The method name is
        the original author's arbitrary choice.)
        """
        map_width = self.game_info.map_size[0]
        map_height = self.game_info.map_size[1]
        # Rows index y and columns index x, hence (height, width, channels).
        game_data = np.zeros((map_height, map_width, 3), np.uint8)

        def draw(pos, radius, color):
            cv2.circle(game_data, (int(pos[0]), int(pos[1])), radius, color, -1)

        # Own units and buildings.
        for unit_type, (radius, color) in self._DRAW_STYLES.items():
            for unit in self.units(unit_type).ready:
                draw(unit.position, radius, color)

        # Enemy buildings; main bases are drawn in a second pass so they end
        # up on top of other structures.
        enemy_structures = self.known_enemy_structures
        for building in enemy_structures:
            if building.name.lower() not in self._MAIN_BASE_NAMES:
                draw(building.position, 5, (200, 50, 212))
        for building in enemy_structures:
            if building.name.lower() in self._MAIN_BASE_NAMES:
                draw(building.position, 15, (0, 0, 255))

        # Enemy non-structure units: workers (probe/scv/drone) small, others larger.
        for enemy_unit in self.known_enemy_units:
            if enemy_unit.is_structure:
                continue
            if enemy_unit.name.lower() in self._WORKER_NAMES:
                draw(enemy_unit.position, 1, (55, 0, 155))
            else:
                draw(enemy_unit.position, 3, (50, 0, 215))

        # Observers as tiny dots so they hide as little scouted info as possible.
        for obs in self.units(OBSERVER).ready:
            draw(obs.position, 1, (255, 255, 255))

        self._draw_status_bars(game_data)

        # flipCode 0 flips the image vertically (around the x-axis).
        self.flipped = cv2.flip(game_data, 0)
        if not HEADLESS:
            resized = cv2.resize(self.flipped, dsize=None, fx=2, fy=2)
            cv2.imshow('Intel', resized)
            cv2.waitKey(1)  # 1 ms; lets the OpenCV window refresh

    def _draw_status_bars(self, game_data):
        """Draw resource/supply bars (each at most 50 px long) onto ``game_data``."""
        line_max = 50

        def ratio(numerator, denominator):
            # Clamp to [0, 1]; a non-positive denominator (e.g. no supply cap
            # or no supply used) yields 0 instead of dividing by zero.
            if denominator <= 0:
                return 0.0
            return min(max(numerator / denominator, 0.0), 1.0)

        supply_used = self.supply_cap - self.supply_left
        bars = (
            # (row, fraction of line_max, colour)
            (19, ratio(len(self.units(VOIDRAY)), supply_used), (250, 250, 200)),  # void rays / supply used
            (15, self.supply_cap / 200.0, (220, 200, 200)),  # plausible supply: supply cap / 200
            (11, ratio(self.supply_left, self.supply_cap), (150, 150, 150)),  # supply left / supply cap
            (7, ratio(self.vespene, 1500), (210, 200, 0)),  # gas / 1500
            (3, ratio(self.minerals, 1500), (0, 255, 25)),  # minerals / 1500
        )
        for row, fraction, color in bars:
            cv2.line(game_data, (0, row), (int(line_max * fraction), row), color, 3)

    async def build_workers(self):
        """Train probes from idle, ready nexuses.

        Stops once there are 24 probes per nexus or MAX_WORKERS probes in
        total. ``noqueue`` selects structures whose production queue is empty.
        """
        nexuses = self.units(NEXUS)
        probe_count = len(self.units(PROBE))
        if len(nexuses) * 24 <= probe_count or probe_count >= self.MAX_WORKERS:
            return
        for nexus in nexuses.ready.noqueue:
            if self.can_afford(PROBE):
                await self.do(nexus.train(PROBE))

    async def build_pylons(self):
        """Build a pylon when fewer than 5 supply is left and none is pending."""
        if self.supply_left >= 5 or self.already_pending(PYLON):
            return
        nexuses = self.units(NEXUS).ready
        if nexuses.exists and self.can_afford(PYLON):
            # Placement is simply "near the first nexus"; a learned placement
            # policy could replace this later.
            await self.build(PYLON, near=nexuses.first)

    async def build_assimilators(self):
        """Build an assimilator on each free geyser within 15 of a ready nexus."""
        for nexus in self.units(NEXUS).ready:
            for vespene in self.state.vespene_geyser.closer_than(15.0, nexus):
                # Skip occupied geysers before selecting a worker for them.
                if self.units(ASSIMILATOR).closer_than(1.0, vespene).exists:
                    continue
                if not self.can_afford(ASSIMILATOR):
                    break
                worker = self.select_build_worker(vespene.position)
                if worker is None:
                    break
                await self.do(worker.build(ASSIMILATOR, vespene))

    async def expand(self):
        """Simplified expansion rule: about one nexus per estimated game minute."""
        minutes = self.iteration / self.ITERATIONS_PER_MINUTE
        if self.units(NEXUS).amount < minutes and self.can_afford(NEXUS):
            await self.expand_now()

    async def offensive_force_buildings(self):
        """Build the tech/production chain next to a random ready pylon.

        Gateway -> cybernetics core -> one robotics facility (for observers)
        and stargates (for void rays).
        """
        if not self.units(PYLON).ready.exists:
            return
        pylon = self.units(PYLON).ready.random

        # A single gateway is only needed to unlock the cybernetics core.
        if self.units(GATEWAY).ready.exists and not self.units(CYBERNETICSCORE):
            if self.can_afford(CYBERNETICSCORE) and not self.already_pending(CYBERNETICSCORE):
                await self.build(CYBERNETICSCORE, near=pylon)
        elif len(self.units(GATEWAY)) < 1:
            if self.can_afford(GATEWAY) and not self.already_pending(GATEWAY):
                await self.build(GATEWAY, near=pylon)

        if self.units(CYBERNETICSCORE).ready.exists:
            # One robotics facility, used to produce observers.
            if len(self.units(ROBOTICSFACILITY)) < 1:
                if self.can_afford(ROBOTICSFACILITY) and not self.already_pending(ROBOTICSFACILITY):
                    await self.build(ROBOTICSFACILITY, near=pylon)
            # About one stargate per estimated game minute (an early, aggressive schedule).
            if len(self.units(STARGATE)) < self.iteration / self.ITERATIONS_PER_MINUTE:
                if self.can_afford(STARGATE) and not self.already_pending(STARGATE):
                    await self.build(STARGATE, near=pylon)

    async def build_offensive_force(self):
        """Train void rays (the only combat unit) from idle, ready stargates."""
        for sg in self.units(STARGATE).ready.noqueue:
            if self.can_afford(VOIDRAY) and self.supply_left > 0:
                await self.do(sg.train(VOIDRAY))

    def find_target(self, state):
        """Return a random known enemy unit, else structure, else the enemy start.

        Synchronous helper (not a coroutine). ``state`` is unused; it is kept
        for interface compatibility.
        """
        if len(self.known_enemy_units) > 0:
            return random.choice(self.known_enemy_units)
        elif len(self.known_enemy_structures) > 0:
            return random.choice(self.known_enemy_structures)
        else:
            return self.enemy_start_locations[0]

    async def attack(self):
        """Randomly pick one of four attack options and record the choice.

        Choice index (also the position of the 1 in the recorded label):
        0 = no attack, wait 20-164 iterations before deciding again;
        1 = attack the known enemy unit closest to a random own nexus;
        2 = attack a random known enemy structure;
        3 = attack the enemy start location.
        """
        idle_voidrays = self.units(VOIDRAY).idle
        if len(idle_voidrays) == 0:
            return
        choice = random.randrange(0, 4)
        if self.iteration <= self.do_something_after:
            return

        target = None
        if choice == 0:
            wait = random.randrange(20, 165)
            self.do_something_after = self.iteration + wait
        elif choice == 1:
            nexuses = self.units(NEXUS)
            # random.choice() on an empty collection would raise IndexError.
            if len(self.known_enemy_units) > 0 and len(nexuses) > 0:
                target = self.known_enemy_units.closest_to(random.choice(nexuses))
        elif choice == 2:
            if len(self.known_enemy_structures) > 0:
                target = random.choice(self.known_enemy_structures)
        elif choice == 3:
            target = self.enemy_start_locations[0]

        if target is not None:
            for vr in idle_voidrays:
                await self.do(vr.attack(target))

        # One-hot label for the network output, e.g. [1, 0, 0, 0] is choice 0
        # (no attack). NOTE: choices 1/2 are recorded even when no target was
        # available and nothing was ordered.
        y = np.zeros(4)
        y[choice] = 1
        print(y)
        if self.flipped is not None:
            self.train_data.append([y, self.flipped])


def main():
    """Play one game of SentdeBot against a Medium-difficulty Protoss computer."""
    game_map = maps.get("AutomatonLE")
    players = [
        Bot(Race.Protoss, SentdeBot()),
        Computer(Race.Protoss, Difficulty.Medium),
    ]
    # realtime=False runs the simulation as fast as possible instead of at
    # normal game speed.
    run_game(game_map, players, realtime=False)


if __name__ == '__main__':
    main()