-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtic_plot.py
92 lines (81 loc) · 2.65 KB
/
tic_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import numpy as np
import matplotlib.pyplot as plt
def plot_grid(grid=None, heatmap=None, cmap='jet', clim=(-1,1)):
    """Draw a 3x3 board on matplotlib's current axes.

    Parameters
    ----------
    grid : np.array, optional
        Board of moves. Cells equal to 1 are drawn as black 'x' markers and
        cells equal to -1 as hollow 'o' markers; other values draw nothing.
        By default None (no markers).
    heatmap : np.array, optional
        Per-cell values shown as a colour image with a colorbar. Every
        non-NaN cell is also annotated with its value to four decimals.
        By default None (no heatmap).
    cmap : str, optional
        Colormap name forwarded to ``plt.imshow``, by default 'jet'.
    clim : tuple, optional
        ``(vmin, vmax)`` colour limits applied to the heatmap image,
        by default (-1, 1).
    """
    marker_size = 2000
    axes = plt.gca()
    axes.set_aspect('equal')
    # Cell centres sit on integer coordinates 0..2; the y-axis is flipped so
    # row 0 is drawn at the top, matching the array's (row, column) layout.
    plt.xlim((-0.5, 2.5))
    plt.ylim((2.5, -0.5))

    if grid is not None:
        # Transposing makes np.nonzero yield (column, row) = (x, y) pairs.
        marker_styles = (
            (1, dict(marker="x", c='k')),
            (-1, dict(marker="o", edgecolors='k', facecolors='none')),
        )
        for player, style in marker_styles:
            if player in grid:
                plt.scatter(*np.nonzero(grid.T == player), s=marker_size, **style)

    if heatmap is not None:
        # Iterate the transpose so the index pair is already (x, y).
        for (col, row), value in np.ndenumerate(heatmap.T):
            if np.isnan(value):
                continue
            plt.text(col - 0.2, row + 0.05, '{:.4f}'.format(value))
        plt.imshow(heatmap, interpolation='none', aspect='equal', cmap=cmap)
        plt.colorbar()
        plt.clim(clim)

    # Ticks at the half-integers mark the cell borders; plt.grid() turns them
    # into board lines while the tick marks and labels themselves stay hidden.
    plt.xticks(np.arange(0.5, 2.5))
    # NOTE(review): this yields 0.49, 1.49, 2.49 (unlike the x-ticks 0.5, 1.5);
    # the last line nearly coincides with the bottom border. The reason for
    # the 0.49 offset is not visible here — confirm before changing.
    plt.yticks(np.arange(0.49, 2.5))
    plt.tick_params(left=False, right=False, bottom=False,
                    labelleft=False, labelbottom=False)
    plt.grid()
def render_grid(grid):
    """Print a text rendering of a board.

    Two views are printed. First come the raw cell values, one row per line,
    formatted to two decimals. Then comes the board with ``1`` shown as
    ``X``, ``-1`` as ``O`` and ``0`` as ``-``. Each row is framed by ``|``,
    and a blank line follows the board.

    Parameters
    ----------
    grid : np.array
        The 2-D board to print (any number of rows/columns). Cells are
        truncated with ``int`` before the symbol lookup, so values outside
        (-1, 0, 1) raise ``KeyError``.
    """
    board = np.asarray(grid)
    print("\n".join([" ".join(["{:.2f}".format(x) for x in row]) for row in board]))
    value2player = {0: '-', 1: 'X', -1: 'O'}
    # Read cells from the argument itself (not a module-level global) and
    # follow the board's real shape instead of a hard-coded 3x3.
    for row in board:
        print('|' + ' '.join(value2player[int(cell)] for cell in row) + '|')
    print()
if __name__ == "__main__":
    # Demo 1: a sparse board, printed as text and drawn both as moves and as
    # its own heatmap.
    state = np.zeros((3, 3))
    state[1, 2] = 1
    state[2, 1] = -1
    render_grid(state)
    plot_grid(state, state)
    plt.show()

    # Demo 2: a board with moves only.
    state = np.array([[ 1,  1, 0],
                      [ 0, -1, 0],
                      [-1,  0, 1]])
    render_grid(state)
    plot_grid(state)
    plt.show()

    # Demo 3: scores on the empty cells; occupied cells stay NaN, so
    # plot_grid does not annotate them.
    state = np.array([0, 0, 1, 0, -1, -1, 0, 0, 0]).reshape((3, 3))
    # np.nan rather than np.NaN: the capitalised alias was removed in NumPy 2.0.
    # Named `heatmap` rather than `map` to avoid shadowing the builtin.
    heatmap = np.full((3, 3), np.nan)
    heatmap[state == 0] = 0.123456789
    heatmap[1, 0] = -0.3456789
    heatmap[0, 1] = 0.56789
    render_grid(state)
    plot_grid(state, heatmap, cmap='jet', clim=(-1, 1))
    plt.show()

    # Demo 4: a heatmap only, no move markers.
    state = np.array([0, 0, -1, 0, 1, 1, 0, 0, 0]).reshape((3, 3))
    render_grid(state)
    plot_grid(None, state)
    plt.show()