From aac73de77d9dcd599d167634e6ca9871dcc7aeeb Mon Sep 17 00:00:00 2001 From: Kensuke Matsuzaki Date: Sat, 25 Jan 2025 14:27:56 +0900 Subject: [PATCH 1/6] Export onnx --- export_onnx.py | 27 +++++++++++++++++++++++++++ requirements-ort-dml.txt | 3 +++ 2 files changed, 30 insertions(+) create mode 100644 export_onnx.py create mode 100644 requirements-ort-dml.txt diff --git a/export_onnx.py b/export_onnx.py new file mode 100644 index 0000000..4af4b40 --- /dev/null +++ b/export_onnx.py @@ -0,0 +1,27 @@ +import torch + +from nn.utility import load_network + +net = load_network("model/sl-model.bin", False) + +board_size = 9 +input_tensor = torch.rand((1, 6, board_size, board_size), dtype=torch.float32) + +torch.onnx.export( + net, + (input_tensor,), + "sl-model.onnx", + input_names=["input"], + output_names=["policy", "value",], + dynamic_axes={ + "input": { + 0: "batch", + }, + "policy": { + 0: "batch", + }, + "value": { + 0: "batch", + }, + } +) diff --git a/requirements-ort-dml.txt b/requirements-ort-dml.txt new file mode 100644 index 0000000..c1092eb --- /dev/null +++ b/requirements-ort-dml.txt @@ -0,0 +1,3 @@ +onnx +onnxcli +onnxruntime-directml From 1630fe8895a95290d462c735c729fe050be20dc5 Mon Sep 17 00:00:00 2001 From: Kensuke Matsuzaki Date: Sat, 25 Jan 2025 14:28:19 +0900 Subject: [PATCH 2/6] Use ONNXRuntime to inference model --- mcts/tree.py | 3 +-- nn/network/dual_net.py | 13 ++++++----- nn/utility.py | 51 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 7 deletions(-) diff --git a/mcts/tree.py b/mcts/tree.py index 01fb935..1c82229 100644 --- a/mcts/tree.py +++ b/mcts/tree.py @@ -6,7 +6,6 @@ import copy import time import numpy as np -import torch from board.constant import PASS, RESIGN from board.coordinate import Coordinate @@ -277,7 +276,7 @@ def process_mini_batch(self, board: GoBoard, use_logit: bool=False): # pylint: d board (GoBoard): 碁盤の情報。 use_logit (bool): Policyの出力をlogitにするフラグ """ - input_planes = torch.Tensor(np.array(self.batch_queue.input_plane)) + input_planes = np.array(self.batch_queue.input_plane) if use_logit: raw_policy, value_data = self.network.inference_with_policy_logits(input_planes) diff --git a/nn/network/dual_net.py b/nn/network/dual_net.py index bb3a09e..8f551bd 100644 --- a/nn/network/dual_net.py +++ b/nn/network/dual_net.py @@ -1,6 +1,7 @@ """Dual Networkの実装。 """ +import numpy as np from typing import Tuple from torch import nn import torch @@ -78,7 +79,7 @@ def forward_with_softmax(self, input_plane: torch.Tensor) -> Tuple[torch.Tensor, return self.softmax(policy), self.softmax(value) - def inference(self, input_plane: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def inference(self, input_plane: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """前向き伝搬処理を実行する。探索用に使うメソッドのため、デバイス間データ転送も内部処理する。 Args: @@ -87,12 +88,13 @@ def inference(self, input_plane: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens Returns: Tuple[torch.Tensor, torch.Tensor]: Policy, Valueの推論結果。 """ + input_plane = torch.Tensor(input_plane) policy, value = self.forward(input_plane.to(self.device)) - return self.softmax(policy).cpu(), self.softmax(value).cpu() + return self.softmax(policy).detach().cpu().numpy(), self.softmax(value).detach().cpu().numpy() - def inference_with_policy_logits(self, input_plane: torch.Tensor) \ - -> Tuple[torch.Tensor, torch.Tensor]: + def inference_with_policy_logits(self, input_plane: np.ndarray) \ + -> Tuple[np.ndarray, np.ndarray]: """前向き伝搬処理を実行する。Gumbel AlphaZero用の探索に使うメソッドのため、 デバイス間データ転送も内部処理する。 @@ -102,8 +104,9 @@ def inference_with_policy_logits(self, input_plane: torch.Tensor) \ Returns: Tuple[torch.Tensor, torch.Tensor]: Policy, Valueの推論結果。 """ + input_plane = torch.Tensor(input_plane) policy, value = self.forward(input_plane.to(self.device)) - return policy.cpu(), self.softmax(value).cpu() + return policy.detach().cpu().numpy(), self.softmax(value).detach().cpu().numpy() def make_common_blocks(num_blocks: int, num_filters: int) -> torch.nn.Sequential: diff --git a/nn/utility.py b/nn/utility.py index 00a6af1..bb8e90d 100644 --- a/nn/utility.py +++ b/nn/utility.py @@ -4,6 +4,7 @@ import time import torch import numpy as np +import onnxruntime as ort from common.print_console import print_err from nn.network.dual_net import DualNet @@ -146,6 +147,8 @@ def load_network(model_file_path: str, use_gpu: bool) -> DualNet: Returns: DualNet: パラメータロード済みのニューラルネットワーク。 """ + if model_file_path.endswith(".onnx"): + return OrtWrapper(model_file_path, use_gpu) device = get_torch_device(use_gpu=use_gpu) network = DualNet(device) network.to(device) @@ -157,3 +160,51 @@ def load_network(model_file_path: str, use_gpu: bool) -> DualNet: torch.set_grad_enabled(False) return network + +class OrtWrapper: + def __init__(self, model_file_path: str, use_gpu: bool): + providers = [] + if use_gpu: + providers += [ + "DmlExecutionProvider", + "CUDAExecutionProvider", + ] + providers.append("CPUExecutionProvider") + self.ort_sess = ort.InferenceSession(model_file_path, providers=providers) + + + def _run(self, input): + outputs = self.ort_sess.run(None, {'input': input}) + return outputs[0], outputs[1] + + + def _softmax(self, x): + return np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True) + + + def inference(self, input_plane: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """前向き伝搬処理を実行する。探索用に使うメソッドのため、デバイス間データ転送も内部処理する。 + + Args: + input_plane (np.ndarray): 入力特徴テンソル。 + + Returns: + Tuple[np.ndarray, np.ndarray]: Policy, Valueの推論結果。 + """ + policy, value = self._run(input_plane) + return self._softmax(policy), self._softmax(value) + + + def inference_with_policy_logits(self, input_plane: np.ndarray) \ + -> Tuple[np.ndarray, np.ndarray]: + """前向き伝搬処理を実行する。Gumbel AlphaZero用の探索に使うメソッドのため、 + デバイス間データ転送も内部処理する。 + + Args: + input_plane (np.ndarray): 入力特徴テンソル。 + + Returns: + Tuple[np.ndarray, np.ndarray]: Policy, Valueの推論結果。 + """ + policy, value = self._run(input_plane) + return policy, self._softmax(value) From b6d6a3cbd01ac941b79acb38dbf5073b3c779641 Mon Sep 17 00:00:00 2001 From: Kensuke Matsuzaki Date: Sun, 26 Jan 2025 23:15:22 +0900 Subject: [PATCH 3/6] Add __init__.py --- .gitignore | 4 +++- animation/__init__.py | 0 board/__init__.py | 0 common/__init__.py | 0 graph/__init__.py | 0 gtp/__init__.py | 0 mcts/__init__.py | 0 nn/__init__.py | 0 selfplay/__init__.py | 0 sgf/__init__.py | 0 10 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 animation/__init__.py create mode 100644 board/__init__.py create mode 100644 common/__init__.py create mode 100644 graph/__init__.py create mode 100644 gtp/__init__.py create mode 100644 mcts/__init__.py create mode 100644 nn/__init__.py create mode 100644 selfplay/__init__.py create mode 100644 sgf/__init__.py diff --git a/.gitignore b/.gitignore index 56d6129..c2311d6 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,6 @@ **/*.bin **/*.ckpt archive/* -.coverage \ No newline at end of file +.coverage +.eggs/ +*.egg-info/ diff --git a/animation/__init__.py b/animation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/board/__init__.py b/board/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/graph/__init__.py b/graph/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gtp/__init__.py b/gtp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mcts/__init__.py b/mcts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nn/__init__.py b/nn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/selfplay/__init__.py b/selfplay/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sgf/__init__.py b/sgf/__init__.py new file mode 100644 index 0000000..e69de29 From a67f1bebd3c5a779994ca6a606d3ac800e64df3c Mon Sep 17 00:00:00 2001 From: Kensuke Matsuzaki Date: Sun, 26 Jan 2025 23:41:08 +0900 Subject: [PATCH 4/6] Fix mypy errors --- board/coordinate.py | 4 +-- board/go_board.py | 20 ++++++------- board/handicap.py | 4 +-- board/pattern.py | 12 ++++---- board/record.py | 18 ++++++------ board/string.py | 52 ++++++++++++++++----------------- board/zobrist_hash.py | 4 +-- common/print_console.py | 6 ++-- get_final_status.py | 11 ++++--- gtp/client.py | 60 +++++++++++++++++++------------------- gtp/gogui.py | 10 +++---- mcts/batch_data.py | 4 +-- mcts/dump.py | 8 ++--- mcts/node.py | 50 ++++++++++++++++--------------- mcts/sequential_halving.py | 8 ++--- mcts/time_manager.py | 16 +++++----- mcts/tree.py | 24 +++++++-------- nn/data_generator.py | 10 +++---- nn/feature.py | 6 ++-- nn/learn.py | 9 +++--- nn/policy_player.py | 3 +- nn/utility.py | 14 ++++----- selfplay/worker.py | 2 +- sgf/reader.py | 14 ++++++--- sgf/selfplay_record.py | 14 ++++----- 25 files changed, 197 insertions(+), 186 deletions(-) diff --git a/board/coordinate.py b/board/coordinate.py index d0b1f28..6d2c8ed 100644 --- a/board/coordinate.py +++ b/board/coordinate.py @@ -38,9 +38,9 @@ def convert_from_gtp_format(self, pos: str) -> int: x_coord = i y_coord = self.board_size - int(pos[1:]) - pos = x_coord + OB_SIZE + (y_coord + OB_SIZE) * self.board_size_with_ob + p = x_coord + OB_SIZE + (y_coord + OB_SIZE) * self.board_size_with_ob - return pos + return p def convert_to_gtp_format(self, pos: int) -> str: """プログラム内部の座標からGTP形式に変換する。 diff --git a/board/go_board.py b/board/go_board.py index 725d2bc..c7a9c36 100644 --- a/board/go_board.py +++ b/board/go_board.py @@ -1,6 +1,6 @@ """碁盤のデータ定義と操作処理。 """ -from typing import List, Tuple, NoReturn +from typing import Deque, List, Tuple from collections import deque import numpy as np @@ -106,7 +106,7 @@ def get_cross4(pos: int) -> List[int]: self.clear() - def clear(self) -> NoReturn: + def clear(self) -> None: """盤面の初期化 """ self.moves = 1 @@ -128,7 +128,7 @@ def clear(self) -> NoReturn: self.strings.clear() self.record.clear() - def put_stone(self, pos: int, color: Stone) -> NoReturn: + def put_stone(self, pos: int, color: Stone) -> None: """指定された座標に指定された色の石を石を置く。 Args: @@ -184,7 +184,7 @@ def put_stone(self, pos: int, color: Stone) -> NoReturn: self.record.save(self.moves, color, pos, self.positional_hash) self.moves += 1 - def put_handicap_stone(self, pos: int, color: Stone) -> NoReturn: + def put_handicap_stone(self, pos: int, color: Stone) -> None: """指定された座標に指定された色の置き石を置く。 Args: @@ -408,7 +408,7 @@ def get_all_legal_pos(self, color: Stone) -> List[int]: """ return [pos for pos in self.onboard_pos if self.is_legal(pos, color)] - def display(self, sym: int=0) -> NoReturn: + def display(self, sym: int=0) -> None: """盤面を表示する。 """ print_err(self.get_board_string(sym=sym)) @@ -440,7 +440,7 @@ def get_board_string(self, sym: int=0) -> str: return board_string - def display_self_atari(self, color: Stone) -> NoReturn: + def display_self_atari(self, color: Stone) -> None: """アタリに突っ込んだ時に取られる石の数を表示する。取られない場合は0。デバッグ用。 Args: @@ -457,7 +457,7 @@ def display_self_atari(self, color: Stone) -> NoReturn: self_atari_string += '\n' print_err(self_atari_string) - def get_board_size(self) -> NoReturn: + def get_board_size(self) -> int: """碁盤の大きさを取得する。 Returns: @@ -508,7 +508,7 @@ def get_symmetrical_coordinate(self, pos: int, sym: int) -> int: """ return self.sym_map[sym][pos] - def set_komi(self, komi: float) -> NoReturn: + def set_komi(self, komi: float) -> None: """コミを設定する。 Args: @@ -535,7 +535,7 @@ def get_to_move(self) -> Stone: last_move_color, _, _ = self.record.get(self.moves - 1) return Stone.get_opponent_color(last_move_color) - def get_move_history(self) -> List[Tuple[Stone, int, np.array]]: + def get_move_history(self) -> List[Tuple[Stone, int, np.ndarray]]: """着手の履歴を取得する。 Returns: @@ -579,7 +579,7 @@ def count_score(self) -> int: # pylint: disable=R0912 for pos in self.onboard_pos: # pylint: disable=R1702 if board[pos] is Stone.EMPTY: pos_list = [] - pos_queue = deque() + pos_queue: Deque[int] = deque() pos_queue.append(pos) color = Stone.EMPTY while pos_queue: diff --git a/board/handicap.py b/board/handicap.py index a977d99..284ed41 100644 --- a/board/handicap.py +++ b/board/handicap.py @@ -1,6 +1,6 @@ """置き石の座標。 """ -from typing import List +from typing import List, Optional handicap_coordinate_map = { @@ -67,7 +67,7 @@ } -def get_handicap_coordinates(size: int, handicaps: int) -> List[int]: +def get_handicap_coordinates(size: int, handicaps: int) -> Optional[List[str]]: """置き石の座標リストを取得する。 Args: diff --git a/board/pattern.py b/board/pattern.py index 6e9d350..707d444 100644 --- a/board/pattern.py +++ b/board/pattern.py @@ -33,7 +33,7 @@ class Pattern: """配石パターンクラス。 """ - def __init__(self, board_size: int, pos_func: Callable[[int], int]): + def __init__(self, board_size: int, pos_func: Callable[[int, int], int]): """Patternクラスのコンストラクタ。 Args: @@ -99,7 +99,7 @@ def __init__(self, board_size: int, pos_func: Callable[[int], int]): self.clear() - def clear(self) -> NoReturn: + def clear(self) -> None: """周囲の石のパターンを初期状態にする。 """ board_start = OB_SIZE @@ -116,7 +116,7 @@ def clear(self) -> NoReturn: self.pat3[self.POS(board_start, y_pos)] = \ self.pat3[self.POS(board_start, y_pos)] | 0x0cc3 - def remove_stone(self, pos: int) -> NoReturn: + def remove_stone(self, pos: int) -> None: """周囲の石のパターンから石を取り除く。 Args: @@ -125,7 +125,7 @@ def remove_stone(self, pos: int) -> NoReturn: for i, shift in enumerate(self.update_pos): self.pat3[pos + shift] = self.pat3[pos + shift] & pattern_mask[i][0] - def put_stone(self, pos: int, color: Stone) -> NoReturn: + def put_stone(self, pos: int, color: Stone) -> None: """周囲の石のパターンの石を追加する。 Args: @@ -161,7 +161,7 @@ def get_eye_color(self, pos: int) -> Stone: """ return self.eye[self.pat3[pos]] - def display(self, pos: int) -> NoReturn: + def display(self, pos: int) -> None: """指定した座標の周囲の石のパターンを表示する。(デバッグ用) Args: @@ -299,7 +299,7 @@ def get_pat3_symmetry8(pat3: int) -> List[int]: return symmetries -def copy_pattern(dst: Pattern, src: Pattern) -> NoReturn: +def copy_pattern(dst: Pattern, src: Pattern) -> None: """配石パターンのデータをコピーする。 Args: diff --git a/board/record.py b/board/record.py index e8658c1..018f152 100644 --- a/board/record.py +++ b/board/record.py @@ -1,6 +1,6 @@ """着手の履歴の保持。 """ -from typing import NoReturn, Tuple +from typing import Tuple import numpy as np from board.constant import PASS, MAX_RECORDS @@ -19,7 +19,7 @@ def __init__(self): self.hash_value = np.zeros(shape=MAX_RECORDS, dtype=np.uint64) self.handicap_pos = [] - def clear(self) -> NoReturn: + def clear(self) -> None: """データを初期化する。 """ self.color = [Stone.EMPTY] * MAX_RECORDS @@ -27,7 +27,7 @@ def clear(self) -> NoReturn: self.hash_value.fill(0) self.handicap_pos = [] - def save(self, moves: int, color: Stone, pos: int, hash_value: np.array) -> NoReturn: + def save(self, moves: int, color: Stone, pos: int, hash_value: np.ndarray) -> None: """着手の履歴の記録する。 Args: @@ -43,7 +43,7 @@ def save(self, moves: int, color: Stone, pos: int, hash_value: np.array) -> NoRe else: print_err("Cannot save move record.") - def save_handicap(self, pos: int) -> NoReturn: + def save_handicap(self, pos: int) -> None: """置き石の座標を記録する。 Args: @@ -51,7 +51,7 @@ def save_handicap(self, pos: int) -> NoReturn: """ self.handicap_pos.append(pos) - def has_same_hash(self, hash_value: np.array) -> bool: + def has_same_hash(self, hash_value: np.ndarray) -> bool: """同じハッシュ値があるかを確認する。 Args: @@ -60,9 +60,9 @@ def has_same_hash(self, hash_value: np.array) -> bool: Returns: bool: 同じハッシュ値がある場合はTrue、なければFalse。 """ - return np.any(self.hash_value == hash_value) + return np.any(self.hash_value == hash_value).item() - def get(self, moves: int) -> Tuple[Stone, int, np.array]: + def get(self, moves: int) -> Tuple[Stone, int, np.ndarray]: """指定した着手を取得する。 Args: @@ -73,7 +73,7 @@ def get(self, moves: int) -> Tuple[Stone, int, np.array]: """ return (self.color[moves], self.pos[moves], self.hash_value[moves]) - def get_hash_history(self) -> np.array: + def get_hash_history(self) -> np.ndarray: """ハッシュ値の履歴を取得する。 Returns: @@ -82,7 +82,7 @@ def get_hash_history(self) -> np.array: return self.hash_value -def copy_record(dst: Record, src: Record) -> NoReturn: +def copy_record(dst: Record, src: Record) -> None: """着手履歴をコピーする。 Args: diff --git a/board/string.py b/board/string.py index 3f49ba5..1c5517a 100644 --- a/board/string.py +++ b/board/string.py @@ -1,6 +1,6 @@ """連の定義と処理の実装。 """ -from typing import Callable, List, NoReturn +from typing import Callable, List from board.constant import STRING_END, LIBERTY_END, NEIGHBOR_END, OB_SIZE from board.coordinate import Coordinate from board.stone import Stone @@ -15,7 +15,7 @@ def __init__(self, board_size: int): Args: board_size (int): 碁盤のサイズ。 """ - self.color = 0 + self.color = Stone.EMPTY self.libs = 0 self.lib = [0] * ((board_size + 2) ** 2) self.neighbors = 0 @@ -24,7 +24,7 @@ def __init__(self, board_size: int): self.size = 0 self.flag = False - def initialize(self, pos: int, color: Stone) -> NoReturn: + def initialize(self, pos: int, color: Stone) -> None: """連の生成処理。 Args: @@ -68,17 +68,17 @@ def has_neighbor(self, neighbor: int) -> bool: """ return self.neighbor[neighbor] != 0 - def remove(self) -> NoReturn: + def remove(self) -> None: """連を削除する。 """ self.flag = False - def add_stone(self) -> NoReturn: + def add_stone(self) -> None: """連を構成する石の数を1つ増やす。 """ self.size += 1 - def add_size(self, size: int) -> NoReturn: + def add_size(self, size: int) -> None: """連を構成する石の個数を加算する。 Args: @@ -158,7 +158,7 @@ def add_liberty(self, pos: int, head: int) -> int: return pos - def remove_liberty(self, pos: int) -> NoReturn: + def remove_liberty(self, pos: int) -> None: """指定した座標の呼吸点を取り除く。 Args: @@ -175,7 +175,7 @@ def remove_liberty(self, pos: int) -> NoReturn: self.lib[pos] = 0 self.libs -= 1 - def add_neighbor(self, string_id: int) -> NoReturn: + def add_neighbor(self, string_id: int) -> None: """隣接する敵連IDを追加する。 Args: @@ -195,7 +195,7 @@ def add_neighbor(self, string_id: int) -> NoReturn: self.neighbor[neighbor] = string_id self.neighbors += 1 - def remove_neighbor(self, remove_id: int) -> NoReturn: + def remove_neighbor(self, remove_id: int) -> None: """指定した隣接する敵連IDを除去する。 Args: @@ -266,7 +266,7 @@ def __init__(self, board_size: int, pos_func: Callable[[int, int], int], \ self.POS = pos_func # pylint: disable=C0103 self.get_neighbor4 = get_neighbor4 - def clear(self) -> NoReturn: + def clear(self) -> None: """全ての連を削除する。 """ self.string_id = [0] * len(self.string_id) @@ -274,7 +274,7 @@ def clear(self) -> NoReturn: for string in self.string: string.remove() - def remove_liberty(self, pos: int, lib: int) -> NoReturn: + def remove_liberty(self, pos: int, lib: int) -> None: """指定した座標の連の呼吸点を除去する。 Args: @@ -283,7 +283,7 @@ def remove_liberty(self, pos: int, lib: int) -> NoReturn: """ self.string[self.get_id(pos)].remove_liberty(lib) - def remove_string(self, board: List[Stone], remove_pos: int) -> NoReturn: + def remove_string(self, board: List[Stone], remove_pos: int) -> List[int]: """連を盤上から除去する。 Args: @@ -364,7 +364,7 @@ def get_num_liberties(self, pos: int) -> int: """ return self.string[self.get_id(pos)].get_num_liberties() - def make_string(self, board: List[Stone], pos: int, color: Stone) -> NoReturn: + def make_string(self, board: List[Stone], pos: int, color: Stone) -> None: """連を作成する。 Args: @@ -394,7 +394,7 @@ def make_string(self, board: List[Stone], pos: int, color: Stone) -> NoReturn: self.string[string_id].add_neighbor(neighbor_id) self.string[neighbor_id].add_neighbor(string_id) - def _add_stone_to_string(self, string_id: int, pos: int) -> NoReturn: + def _add_stone_to_string(self, string_id: int, pos: int) -> None: """指定した座標を連に追加する。 Args: @@ -416,7 +416,7 @@ def _add_stone_to_string(self, string_id: int, pos: int) -> NoReturn: self.string[string_id].add_stone() - def add_stone(self, board: List[Stone], pos: int, color: Stone, string_id: int) -> NoReturn: + def add_stone(self, board: List[Stone], pos: int, color: Stone, string_id: int) -> None: """連に石を1つ追加する。 Args: @@ -440,8 +440,8 @@ def add_stone(self, board: List[Stone], pos: int, color: Stone, string_id: int) self.string[string_id].add_neighbor(neighbor_id) self.string[neighbor_id].add_neighbor(string_id) - def connect_string(self, board: List[Stone], pos: int, \ - color: Stone, ids: List[int]) -> NoReturn: + def connect_string(self, board: List[Stone], pos: int, + color: Stone, ids: List[int]) -> None: """連を接続する。 Args: @@ -457,7 +457,7 @@ def connect_string(self, board: List[Stone], pos: int, \ if len(unique_ids) > 1: self._merge_string(unique_ids[0], unique_ids[1:]) - def _merge_string(self, dst_id: int, src_ids: List[int]) -> NoReturn: + def _merge_string(self, dst_id: int, src_ids: List[int]) -> None: """複数の連を接続する。 Args: @@ -470,7 +470,7 @@ def _merge_string(self, dst_id: int, src_ids: List[int]) -> NoReturn: self._merge_neighbor(dst_id, src_id) self.string[src_id].remove() - def _merge_stones(self, dst_id: int, src_id: int) -> NoReturn: + def _merge_stones(self, dst_id: int, src_id: int) -> None: """連を構成する石の座標を連結する。 Args: @@ -500,7 +500,7 @@ def _merge_stones(self, dst_id: int, src_id: int) -> NoReturn: self.string[dst_id].add_size(self.string[src_id].get_size()) - def _merge_liberty(self, dst_id: int, src_id: int) -> NoReturn: + def _merge_liberty(self, dst_id: int, src_id: int) -> None: """連が持つ呼吸点の座標を連結する。 Args: @@ -519,7 +519,7 @@ def _merge_liberty(self, dst_id: int, src_id: int) -> NoReturn: self.string[dst_id].libs += 1 src_lib = self.string[src_id].lib[src_lib] - def _merge_neighbor(self, dst_id: int, src_id: int) -> NoReturn: + def _merge_neighbor(self, dst_id: int, src_id: int) -> None: """隣接する敵連IDを連結する。 Args: @@ -545,7 +545,7 @@ def _merge_neighbor(self, dst_id: int, src_id: int) -> NoReturn: neighbor = self.string[src_id].neighbor[neighbor] - def _remove_neighbor_string(self, neighbor_id: int, remove_id: int) -> NoReturn: + def _remove_neighbor_string(self, neighbor_id: int, remove_id: int) -> None: """指定した敵連IDを削除する。 Args: @@ -554,7 +554,7 @@ def _remove_neighbor_string(self, neighbor_id: int, remove_id: int) -> NoReturn: """ self.string[neighbor_id].remove_neighbor(remove_id) - def _add_neighbor(self, neighbor_id: int, add_id: int) -> NoReturn: + def _add_neighbor(self, neighbor_id: int, add_id: int) -> None: """隣接する敵連IDを追加する。 Args: @@ -563,7 +563,7 @@ def _add_neighbor(self, neighbor_id: int, add_id: int) -> NoReturn: """ self.string[neighbor_id].add_neighbor(add_id) - def display(self) -> NoReturn: + def display(self) -> None: """盤上に存在する全ての連の情報を表示する。(デバッグ用) """ coordinate = Coordinate(self.board_size) @@ -597,7 +597,7 @@ def display(self) -> NoReturn: print_err(f"\tNeighbor {len(neighbors)} : {neighbors}") -def copy_string(dst: String, src: String) -> NoReturn: +def copy_string(dst: String, src: String) -> None: """連の情報をコピーする。 Args: @@ -614,7 +614,7 @@ def copy_string(dst: String, src: String) -> NoReturn: dst.flag = src.flag -def copy_strings(dst: StringData, src: StringData) -> NoReturn: +def copy_strings(dst: StringData, src: StringData) -> None: """全ての連の情報をコピーする。ただし、存在しない場合は存在フラグをオフにするだけにする。 Args: diff --git a/board/zobrist_hash.py b/board/zobrist_hash.py index 098cb43..2c66169 100644 --- a/board/zobrist_hash.py +++ b/board/zobrist_hash.py @@ -10,7 +10,7 @@ size=[4, (BOARD_SIZE + OB_SIZE * 2) ** 2], dtype=np.uint64) -def affect_stone_hash(hash_value: np.array, pos: int, color: Stone) -> np.array: +def affect_stone_hash(hash_value: np.ndarray, pos: int, color: Stone) -> np.ndarray: """1つの石のハッシュ値を作用させる。 Args: @@ -24,7 +24,7 @@ def affect_stone_hash(hash_value: np.array, pos: int, color: Stone) -> np.array: return hash_value ^ hash_bit_mask[color.value][pos] -def affect_string_hash(hash_value: np.array, pos_list: List[int], color: Stone) -> np.array: +def affect_string_hash(hash_value: np.ndarray, pos_list: List[int], color: Stone) -> np.ndarray: """複数の石のハッシュ値を作用させる。 Args: diff --git a/common/print_console.py b/common/print_console.py index 3b165d2..66bc4d7 100644 --- a/common/print_console.py +++ b/common/print_console.py @@ -1,9 +1,9 @@ """コンソール出力のラッパー """ -from typing import Any, NoReturn +from typing import Any import sys -def print_out(message: Any) -> NoReturn: +def print_out(message: Any) -> None: """メッセージを標準出力に出力する。 Args: @@ -11,7 +11,7 @@ def print_out(message: Any) -> NoReturn: """ print(message) -def print_err(message: Any) -> NoReturn: +def print_err(message: Any) -> None: """メッセージを標準エラー出力に出力する。 Args: diff --git a/get_final_status.py b/get_final_status.py index 41cbd55..86eb399 100644 --- a/get_final_status.py +++ b/get_final_status.py @@ -6,7 +6,7 @@ import os import math import subprocess -from typing import NoReturn +from typing import List import click WORKER_THREAD = 4 @@ -43,6 +43,9 @@ def get_gnugo_judgment(filename: str, is_japanese_rule: bool) -> str: with subprocess.Popen(gnugo_command, stdin=subprocess.PIPE, \ stdout=subprocess.PIPE, encoding='utf-8') as process: + assert process.stdin is not None + assert process.stdout is not None + process.stdin.write("\n".join(exec_commands)) process.stdin.flush() process.stdout.flush() @@ -64,7 +67,7 @@ def get_gnugo_judgment(filename: str, is_japanese_rule: bool) -> str: return responses[2] -def adjust_by_gnugo_judgment(filename: str) -> NoReturn: +def adjust_by_gnugo_judgment(filename: str) -> None: """_summary_ Args: @@ -88,7 +91,7 @@ def adjust_by_gnugo_judgment(filename: str) -> NoReturn: with open(filename, encoding="utf-8", mode="w") as out_file: out_file.write(adjusted_sgf) -def judgment_worker(kifu_list: str) -> NoReturn: +def judgment_worker(kifu_list: List[str]) -> None: """_summary_ Args: @@ -100,7 +103,7 @@ def judgment_worker(kifu_list: str) -> NoReturn: @click.command() @click.option('--kifu-dir', type=click.STRING, default='archive', help='') -def adjust_result(kifu_dir: str) -> NoReturn: +def adjust_result(kifu_dir: str) -> None: """_summary_ Args: diff --git a/gtp/client.py b/gtp/client.py index fc46824..26634a2 100644 --- a/gtp/client.py +++ b/gtp/client.py @@ -3,7 +3,7 @@ import os import random import sys -from typing import List, NoReturn +from typing import List, NoReturn, Tuple from program import PROGRAM_NAME, VERSION, PROTOCOL_VERSION from board.constant import PASS, RESIGN @@ -115,7 +115,7 @@ def __init__(self, board_size: int, superko: bool, model_file_path: str, \ print_err(f"Failed to load {model_file_path}") - def _known_command(self, command: str) -> NoReturn: + def _known_command(self, command: str) -> None: """known_commandコマンドを処理する。 対応しているコマンドの場合は'true'を表示し、対応していないコマンドの場合は'unknown command'を表示する @@ -127,7 +127,7 @@ def _known_command(self, command: str) -> NoReturn: else: respond_failure("unknown command") - def _list_commands(self) -> NoReturn: + def _list_commands(self) -> None: """list_commandsコマンドを処理する。 対応している全てのコマンドを表示する。 """ @@ -136,7 +136,7 @@ def _list_commands(self) -> NoReturn: response += '\n' + command respond_success(response) - def _komi(self, s_komi: str) -> NoReturn: + def _komi(self, s_komi: str) -> None: """komiコマンドを処理する。 入力されたコミを設定する。 @@ -147,7 +147,7 @@ def _komi(self, s_komi: str) -> NoReturn: self.board.set_komi(komi) respond_success("") - def _play(self, color: str, pos: str) -> NoReturn: + def _play(self, color: str, pos: str) -> None: """playコマンドを処理する。 入力された座標に指定された色の石を置く。 @@ -173,7 +173,7 @@ def _play(self, color: str, pos: str) -> NoReturn: respond_success("") - def _undo(self) -> NoReturn: + def _undo(self) -> None: """undoコマンドを処理する。 """ history = self.board.get_move_history() @@ -187,7 +187,7 @@ def _undo(self) -> NoReturn: respond_success("") - def _genmove(self, color: str) -> NoReturn: + def _genmove(self, color: str) -> None: """genmoveコマンドを処理する。 入力された手番で思考し、着手を生成する。 @@ -231,7 +231,7 @@ def _genmove(self, color: str) -> NoReturn: respond_success(self.coordinate.convert_to_gtp_format(pos)) - def _boardsize(self, size: str) -> NoReturn: + def _boardsize(self, size: str) -> None: """boardsizeコマンドを処理する。 指定したサイズの碁盤に設定する。 @@ -244,7 +244,7 @@ def _boardsize(self, size: str) -> NoReturn: self.time_manager.initialize() respond_success("") - def _clear_board(self) -> NoReturn: + def _clear_board(self) -> None: """clear_boardコマンドを処理する。 盤面を初期化する。 """ @@ -252,7 +252,7 @@ def _clear_board(self) -> NoReturn: self.time_manager.initialize() respond_success("") - def _time_settings(self, arg_list: List[str]) -> NoReturn: + def _time_settings(self, arg_list: List[str]) -> None: """time_settingsコマンドを処理する。 持ち時間のみを設定する。 @@ -264,7 +264,7 @@ def _time_settings(self, arg_list: List[str]) -> NoReturn: self.time_manager.set_remaining_time(Stone.WHITE, time) respond_success("") - def _time_left(self, arg_list: List[str]) -> NoReturn: + def _time_left(self, arg_list: List[str]) -> None: """time_leftコマンドを処理する。 指定した手番の残りの時間を設定する。 @@ -281,19 +281,19 @@ def _time_left(self, arg_list: List[str]) -> NoReturn: self.time_manager.set_remaining_time(color, float(arg_list[1])) respond_success("") - def _get_komi(self) -> NoReturn: + def _get_komi(self) -> None: """get_komiコマンドを処理する。 """ respond_success(str(self.board.get_komi())) - def _showboard(self) -> NoReturn: + def _showboard(self) -> None: """showboardコマンドを処理する。 現在の盤面を表示する。 """ self.board.display() respond_success("") - def _loadsgf(self, arg_list: List[str]) -> NoReturn: + def _loadsgf(self, arg_list: List[str]) -> None: """loadsgfコマンドを処理する。 指定したSGFファイルの指定手番まで進めた局面にする。 @@ -311,7 +311,7 @@ def _loadsgf(self, arg_list: List[str]) -> NoReturn: moves = int(arg_list[1]) self._load_sgf_data(sgf_data, moves) - def _readsgf(self, arg_list: List[str]) -> NoReturn: + def _readsgf(self, arg_list: List[str]) -> None: """tamago-readsgfコマンドを処理する。 指定したSGF文字列の局面にする。 @@ -322,7 +322,7 @@ def _readsgf(self, arg_list: List[str]) -> NoReturn: sgf_data = SGFReader(sgf_text, board_size=self.board.get_board_size(), literal=True) self._load_sgf_data(sgf_data) - def _load_sgf_data(self, sgf_data: SGFReader, moves: int=9999) -> NoReturn: + def _load_sgf_data(self, sgf_data: SGFReader, moves: int=9999) -> None: """SGFデータを読み込み、指定手番まで進めた局面にする。 Args: @@ -339,7 +339,7 @@ def _load_sgf_data(self, sgf_data: SGFReader, moves: int=9999) -> NoReturn: respond_success("") - def _fixed_handicap(self, handicaps: str) -> NoReturn: + def _fixed_handicap(self, handicaps: str) -> None: """fixed_handicapコマンドを処理する。 指定した数の置き石を置く。 @@ -365,7 +365,7 @@ def _fixed_handicap(self, handicaps: str) -> NoReturn: respond_success(" ".join(handicap_list)) - def _decode_analyze_arg(self, arg_list: List[str]) -> (Stone, float): + def _decode_analyze_arg(self, arg_list: List[str]) -> Tuple[Stone, float]: """analyzeコマンド(lz-analyze, cgos-analyze)の引数を解釈する。 不正な引数の場合は更新間隔として負値を返す。 @@ -376,7 +376,7 @@ def _decode_analyze_arg(self, arg_list: List[str]) -> (Stone, float): (Stone, float): 手番の色、更新間隔(秒) """ to_move = self.board.get_to_move() - interval = 0 + interval = 0.0 error_value = (to_move, -1.0) # 受けつける形式の例 # lz-analyze B 10 @@ -404,19 +404,19 @@ def _decode_analyze_arg(self, arg_list: List[str]) -> (Stone, float): return error_value return (to_move, interval) - def _analyze_or_animate(self, mode: str, arg_list: List[str]) -> NoReturn: + def _analyze_or_animate(self, mode: str, arg_list: List[str]) -> None: if max(self.animation_pv_wait, self.animation_move_wait) >= 0: self._animate(arg_list, self.animation_pv_wait, self.animation_move_wait) else: self._analyze(mode, arg_list) - def _animate(self, arg_list: List[str], pv_wait: float, move_wait: float) -> NoReturn: + def _animate(self, arg_list: List[str], pv_wait: float, move_wait: float) -> None: to_move, _ = self._decode_analyze_arg(arg_list) respond_success("", ongoing=True) animate_mcts(self.mcts, self.board, to_move, pv_wait, move_wait) print_out("") - def _analyze(self, mode: str, arg_list: List[str]) -> NoReturn: + def _analyze(self, mode: str, arg_list: List[str]) -> None: """analyzeコマンド(lz-analyze, cgos-analyze)を実行する。 Args: @@ -437,7 +437,7 @@ def _analyze(self, mode: str, arg_list: List[str]) -> NoReturn: } self.mcts.ponder(self.board, to_move, analysis_query) - def _genmove_analyze(self, mode: str, arg_list: List[str]) -> NoReturn: + def _genmove_analyze(self, mode: str, arg_list: List[str]) -> None: """genmove_analyzeコマンド(lz-genmove_analyze, cgos-genmove_analyze)を実行する。 Args: @@ -475,7 +475,7 @@ def _genmove_analyze(self, mode: str, arg_list: List[str]) -> NoReturn: print_out(f"play {self.coordinate.convert_to_gtp_format(pos)}\n") - def _dump_tree(self) -> NoReturn: + def _dump_tree(self) -> None: """tamago-dump_treeコマンドを実行する。現在のMCTSツリーの状態をJSON形式で出力する。 """ json_str = self.mcts.dump_to_json(self.board, self.superko) @@ -598,7 +598,7 @@ def run(self) -> NoReturn: # pylint: disable=R0912,R0915 else: respond_failure("unknown_command") -def respond_success(response: str, ongoing: bool = False) -> NoReturn: +def respond_success(response: str, ongoing: bool = False) -> None: """コマンド処理成功時の応答メッセージを表示する。 Args: @@ -608,7 +608,7 @@ def respond_success(response: str, ongoing: bool = False) -> NoReturn: terminator = "" if ongoing else '\n' print(f"={gtp_command_id} " + response + terminator) -def respond_failure(response: str) -> NoReturn: +def respond_failure(response: str) -> None: """コマンド処理失敗時の応答メッセージを表示する。 Args: @@ -616,25 +616,25 @@ def respond_failure(response: str) -> NoReturn: """ print(f"?{gtp_command_id} " + response + '\n') -def _version() -> NoReturn: +def _version() -> None: """versionコマンドを処理する。 プログラムのバージョンを表示する。 """ respond_success(VERSION) -def _protocol_version() -> NoReturn: +def _protocol_version() -> None: """protocol_versionコマンドを処理する。 GTPのプロトコルバージョンを表示する。 """ respond_success(PROTOCOL_VERSION) -def _name() -> NoReturn: +def _name() -> None: """nameコマンドを処理する。 プログラム名を表示する。 """ respond_success(PROGRAM_NAME) -def _quit() -> NoReturn: +def _quit() -> None: """quitコマンドを処理する。 プログラムを終了する。 """ diff --git a/gtp/gogui.py b/gtp/gogui.py index c468342..31a1520 100644 --- a/gtp/gogui.py +++ b/gtp/gogui.py @@ -45,11 +45,11 @@ def display_policy_distribution(model: DualNet, board: GoBoard, color: Stone) -> str: 表示用文字列。 """ board_size = board.get_board_size() - input_plane = generate_input_planes(board, color) - input_plane = torch.tensor(input_plane.reshape(1, 6, board_size, board_size)) #pylint: disable=E1121 + input_plane_data = generate_input_planes(board, color) + input_plane = torch.tensor(input_plane_data.reshape(1, 6, board_size, board_size)) #pylint: disable=E1121 policy, _ = model.inference(input_plane) - max_policy, min_policy = 0, 1 + max_policy, min_policy = 0.0, 1.0 log_policies = [math.log(policy[0][i]) for i in range(board_size * board_size)] for i, log_policy in enumerate(log_policies): @@ -86,8 +86,8 @@ def display_policy_score(model: DualNet, board: GoBoard, color: Stone) -> str: str: 表示用文字列。 """ board_size = board.get_board_size() - input_plane = generate_input_planes(board, color) - input_plane = torch.tensor(input_plane.reshape(1, 6, board_size, board_size)) #pylint: disable=E1121 + input_plane_data = generate_input_planes(board, color) + input_plane = torch.tensor(input_plane_data.reshape(1, 6, board_size, board_size)) #pylint: disable=E1121 policy_predict, _ = model.inference(input_plane) policies = [policy_predict[0][i] for i in range(board_size ** 2)] response = "" diff --git a/mcts/batch_data.py b/mcts/batch_data.py index b3b749c..63836f3 100644 --- a/mcts/batch_data.py +++ b/mcts/batch_data.py @@ -14,7 +14,7 @@ def __init__(self): self.path = [] self.node_index = [] - def push(self, input_plane: np.array, path: List[Tuple[int, int]], node_index: int): + def push(self, input_plane: np.ndarray, path: List[Tuple[int, int]], node_index: int) -> None: """キューにデータをプッシュする。 Args: @@ -26,7 +26,7 @@ def push(self, input_plane: np.array, path: List[Tuple[int, int]], node_index: i self.path.append(path) self.node_index.append(node_index) - def clear(self): + def clear(self) -> None: """キューのデータを全て削除する。 """ self.input_plane = [] diff --git a/mcts/dump.py b/mcts/dump.py index 222a3e4..bf4d989 100644 --- a/mcts/dump.py +++ b/mcts/dump.py @@ -1,5 +1,5 @@ import json -from typing import Any, Tuple, List, Dict, NoReturn +from typing import Any, Tuple, List, Dict from program import PROGRAM_NAME, VERSION, PROTOCOL_VERSION from board.go_board import GoBoard, copy_board @@ -32,7 +32,7 @@ def dump_mcts_to_json(tree_dict: Dict[str, Any], board: GoBoard, superko: bool) } return json.dumps(state) -def enrich_mcts_dict(state: Dict[str, Any]) -> NoReturn: +def enrich_mcts_dict(state: Dict[str, Any]) -> None: """MCTSの状態を表す辞書に便利項目をいろいろ追加する。 Args: @@ -62,7 +62,7 @@ def enrich_mcts_dict(state: Dict[str, Any]) -> NoReturn: assert child_index < tree["num_nodes"], "Child index must be less than num_nodes." # 「親は子より前」「兄弟は order の小さい方が前」を保証したリスト - sorted_indices_list = [] + sorted_indices_list: List[Any] = [] tree["sorted_indices_list"] = sorted_indices_list # expanded_children_index, sorted_indices_list, 兄弟内 order @@ -142,7 +142,7 @@ def _recovered_move_history(converted_move_history: List[Tuple[str, int]]) -> Li def _stone_to_str(color: Stone) -> str: return 'black' if color == Stone.BLACK else 'white' -def _str_to_stone(color_str: str) -> str: +def _str_to_stone(color_str: str) -> Stone: return Stone.BLACK if color_str == 'black' else Stone.WHITE def _get_updated_board_string(root_board: GoBoard, initial_move_color: Stone, gtp_moves_along_path: List[str]) -> str: diff --git a/mcts/node.py b/mcts/node.py index 2a07da0..d22d712 100644 --- a/mcts/node.py +++ b/mcts/node.py @@ -1,10 +1,12 @@ """モンテカルロ木探索で使用するノードの実装。 """ import json -from typing import Any, Callable, Dict, List, NoReturn +from typing import Any, Callable, Dict, List, Self import numpy as np import torch + +from board.coordinate import Coordinate from board.constant import BOARD_SIZE from board.go_board import GoBoard from common.print_console import print_err @@ -38,7 +40,7 @@ def __init__(self, num_actions: int=MAX_ACTIONS): self.noise = np.zeros(num_actions, dtype=np.float64) self.num_children = 0 - def expand(self, policy: Dict[int, float]) -> NoReturn: + def expand(self, policy: Dict[int, float]) -> None: """ノードを展開し、初期化する。 Args: @@ -59,7 +61,7 @@ def expand(self, policy: Dict[int, float]) -> NoReturn: self.set_policy(policy) - def set_policy(self, policy_map: Dict[int, float]) -> NoReturn: + def set_policy(self, policy_map: Dict[int, float]) -> None: """着手候補の座標とPolicyの値を設定する。 Args: @@ -73,7 +75,7 @@ def set_policy(self, policy_map: Dict[int, float]) -> NoReturn: self.num_children = index - def add_virtual_loss(self, index) -> NoReturn: + def add_virtual_loss(self, index) -> None: """Virtual Lossを加算する。 Args: @@ -83,7 +85,7 @@ def add_virtual_loss(self, index) -> NoReturn: self.children_virtual_loss[index] += 1 - def update_policy(self, policy: Dict[int, float]) -> NoReturn: + def update_policy(self, policy: Dict[int, float]) -> None: """Policyを更新する。 Args: @@ -93,7 +95,7 @@ def update_policy(self, policy: Dict[int, float]) -> NoReturn: self.children_policy[i] = policy[self.action[i]] - def set_leaf_value(self, index: int, value: float) -> NoReturn: + def set_leaf_value(self, index: int, value: float) -> None: """末端のValueを設定する。 Args: @@ -101,12 +103,12 @@ def set_leaf_value(self, index: int, value: float) -> NoReturn: value (float): 設定するValueの値。 Returns: - NoReturn: _description_ + None: _description_ """ self.children_value[index] = value - def set_raw_value(self, value: float) -> NoReturn: + def set_raw_value(self, value: float) -> None: """ノードに対応する局面のValueを設定する。 Args: @@ -115,7 +117,7 @@ def set_raw_value(self, value: float) -> NoReturn: self.raw_value = value - def update_child_value(self, index: int, value: float) -> NoReturn: + def update_child_value(self, index: int, value: float) -> None: """子ノードにValueを加算し、Virtual Lossを元に戻す。 Args: @@ -127,7 +129,7 @@ def update_child_value(self, index: int, value: float) -> NoReturn: self.children_virtual_loss[index] -= 1 - def update_node_value(self, value: float) -> NoReturn: + def update_node_value(self, value: float) -> None: """ノードにValueを加算し、Virtual Lossを元に戻す。 Args: @@ -154,7 +156,7 @@ def select_next_action(self, cgos_mode: bool) -> int: # PASSのValueを0.1だけ引く pucb_values[self.num_children - 1] -= 0.1 - return np.argmax(pucb_values[:self.num_children]) + return int(np.argmax(pucb_values[:self.num_children])) def get_num_children(self) -> int: @@ -172,7 +174,7 @@ def get_best_move_index(self) -> int: Returns: int: 探索回数最大の子ノードのインデックス。 """ - return np.argmax(self.children_visits[:self.num_children]) + return int(np.argmax(self.children_visits[:self.num_children])) def get_best_move(self) -> int: @@ -208,7 +210,7 @@ def get_child_index(self, index: int) -> int: return self.children_index[index] - def set_child_index(self, index: int, child_index: int) -> NoReturn: + def set_child_index(self, index: int, child_index: int) -> None: """指定した子ノードの遷移先のインデックスを設定する。 Args: @@ -251,7 +253,7 @@ def _make_serializable(self, dic): val = val.item() dic[key] = val - def print_search_result(self, board: GoBoard, pv_dict: Dict[str, List[str]]) -> NoReturn: + def print_search_result(self, board: GoBoard, pv_dict: Dict[str, List[str]]) -> None: """探索結果を表示する。探索した手の探索回数とValueの平均値を表示する。 Args: @@ -272,13 +274,13 @@ def print_search_result(self, board: GoBoard, pv_dict: Dict[str, List[str]]) -> print_err(msg) - def set_gumbel_noise(self) -> NoReturn: + def set_gumbel_noise(self) -> None: """Gumbelノイズを設定する。 """ self.noise = np.random.gumbel(loc=0.0, scale=1.0, size=self.noise.size) - def calculate_completed_q_value(self, use_mixed_value :bool=True) -> np.array: + def calculate_completed_q_value(self, use_mixed_value: bool=True) -> np.ndarray: """Completed-Q valueを計算する。 Args: @@ -305,7 +307,7 @@ def calculate_completed_q_value(self, use_mixed_value :bool=True) -> np.array: return np.where(self.children_visits[:self.num_children] > 0, q_value, value) - def calculate_improved_policy(self) -> np.array: + def calculate_improved_policy(self) -> np.ndarray: """Improved Policyを計算する。 Returns: @@ -343,7 +345,7 @@ def select_move_by_sequential_halving_for_root(self, count_threshold: int) -> in evaluation_value = np.where(counts >= count_threshold, -10000.0, \ self.children_policy[:self.num_children] + self.noise[:self.num_children] \ + sigma_base * q_mean) - return np.argmax(evaluation_value) + return int(np.argmax(evaluation_value)) def select_move_by_sequential_halving_for_node(self) -> int: @@ -358,7 +360,7 @@ def select_move_by_sequential_halving_for_node(self) -> int: evaluation_value = improved_policy \ - (self.children_visits[:self.num_children] / (1.0 + self.node_visits)) - return np.argmax(evaluation_value) + return int(np.argmax(evaluation_value)) def calculate_value_evaluation(self, index: int) -> float: @@ -375,7 +377,7 @@ def calculate_value_evaluation(self, index: int) -> float: return self.children_value_sum[index] / self.children_visits[index] - def print_all_node_info(self) -> NoReturn: + def print_all_node_info(self) -> None: """子ノードの情報を全て表示する。 """ msg = "" @@ -396,8 +398,8 @@ def print_all_node_info(self) -> NoReturn: print_err(msg) - def get_analysis(self, board: GoBoard, mode: str, \ - pv_lists_func: Callable[[List[str], int], List[str]]) -> str: # pylint: disable=R0914 + def get_analysis(self, board: GoBoard, mode: str, + pv_lists_func: Callable[[Self, Coordinate], Dict[str, List[str]]]) -> str: # pylint: disable=R0914 """解析結果文字列を生成する。 Args: @@ -412,8 +414,8 @@ def get_analysis(self, board: GoBoard, mode: str, \ return self.get_analysis_from_status_list(mode, children_status_list) - def get_analysis_status_list(self, board: GoBoard, \ - pv_lists_func: Callable[[List[str], int], List[str]]): + def get_analysis_status_list(self, board: GoBoard, + pv_lists_func: Callable[[Self, Coordinate], Dict[str, List[str]]]): sorted_list = [] for i in range(self.num_children): sorted_list.append((self.children_visits[i], i)) diff --git a/mcts/sequential_halving.py b/mcts/sequential_halving.py index 776c82d..486c369 100644 --- a/mcts/sequential_halving.py +++ b/mcts/sequential_halving.py @@ -1,11 +1,11 @@ """Sequential Halving """ -from typing import Dict, Tuple +from typing import Dict, List, Tuple import math def get_sequence_of_considered_visits(max_num_considered_actions: int, \ - num_simulations: int) -> Tuple[int]: + num_simulations: int) -> Tuple[int, ...]: """探索回数に対応する探索回数閾値の列を取得する。 Args: @@ -18,7 +18,7 @@ def get_sequence_of_considered_visits(max_num_considered_actions: int, \ if max_num_considered_actions <= 1: return tuple(range(num_simulations)) log2max = int(math.ceil(math.log2(max_num_considered_actions))) - sequence = [] + sequence: List[int] = [] visits = [0] * max_num_considered_actions num_considered = max_num_considered_actions @@ -44,7 +44,7 @@ def get_candidates_and_visit_pairs(max_num_considered_actions: int, \ Returns: Dict[int, int]: 探索幅をキー、探索回数をバリューに持つ辞書。 """ - visit_dict = {} + visit_dict: Dict[int, int] = {} visit_list = get_sequence_of_considered_visits(max_num_considered_actions, num_simulations) max_count = max(visit_list) count_list = [0] * (max_count + 1) diff --git a/mcts/time_manager.py b/mcts/time_manager.py index 6a08aa2..4e6a241 100644 --- a/mcts/time_manager.py +++ b/mcts/time_manager.py @@ -1,7 +1,7 @@ """探索時間を制御する処理。 """ from enum import Enum -from typing import NoReturn +# from typing import NoReturn import time from board.stone import Stone @@ -36,10 +36,10 @@ def __init__(self, mode: TimeControl, constant_visits: int=CONST_VISITS, constan self.constant_visits = constant_visits self.constant_time = constant_time self.default_time = remaining_time - self.search_speed = VISITS_PER_SEC + self.search_speed = float(VISITS_PER_SEC) self.remaining_time = [remaining_time] * 2 - self.time_limit = 0 - self.start_time = 0 + self.time_limit = 0.0 + self.start_time = 0.0 def initialize(self): @@ -48,7 +48,7 @@ def initialize(self): self.remaining_time = [self.default_time] * 2 - def set_search_speed(self, visits: int, consumption_time: float) -> NoReturn: + def set_search_speed(self, visits: int, consumption_time: float) -> None: """探索速度を設定する。 Args: @@ -83,7 +83,7 @@ def get_num_visits_threshold(self, color: Stone) -> int: return int(self.constant_visits) - def set_remaining_time(self, color: Stone, remaining_time: float) -> NoReturn: + def set_remaining_time(self, color: Stone, remaining_time: float) -> None: """残り時間を設定する。 Args: @@ -109,7 +109,7 @@ def substract_consumption_time(self, color: Stone, consumption_time: float): self.remaining_time[1] -= consumption_time - def set_mode(self, mode:TimeControl) -> NoReturn: + def set_mode(self, mode:TimeControl) -> None: """思考時間管理の設定を変更する。 Args: @@ -118,7 +118,7 @@ def set_mode(self, mode:TimeControl) -> NoReturn: self.mode = mode - def start_timer(self) -> NoReturn: + def start_timer(self) -> None: """思考時間の計測を開始する。 """ self.start_time = time.time() diff --git a/mcts/tree.py b/mcts/tree.py index 01fb935..3234df2 100644 --- a/mcts/tree.py +++ b/mcts/tree.py @@ -1,6 +1,6 @@ """モンテカルロ木探索の実装。 """ -from typing import Any, Dict, List, NoReturn, Tuple, Callable +from typing import Any, Dict, List, Tuple, Callable import sys import select import copy @@ -46,7 +46,7 @@ def __init__(self, network: DualNet, tree_size: int=MCTS_TREE_SIZE, \ self.to_move = Stone.BLACK - def _initialize_search(self, board: GoBoard, color: Stone) -> NoReturn: + def _initialize_search(self, board: GoBoard, color: Stone) -> None: self.num_nodes = 0 self.current_root = self.expand_node(board, color) input_plane = generate_input_planes(board, color, 0) @@ -105,7 +105,7 @@ def search_best_move(self, board: GoBoard, color: Stone, time_manager: TimeManag return next_move - def ponder(self, board: GoBoard, color: Stone, analysis_query: Dict[str, Any]) -> NoReturn: + def ponder(self, board: GoBoard, color: Stone, analysis_query: Dict[str, Any]) -> None: """探索回数の制限なく探索を実行する。 Args: @@ -128,7 +128,7 @@ def ponder(self, board: GoBoard, color: Stone, analysis_query: Dict[str, Any]) - def search(self, board: GoBoard, color: Stone, time_manager: TimeManager, \ - analysis_query: Dict[str, Any]) -> NoReturn: # pylint: disable=R0914 + analysis_query: Dict[str, Any]) -> None: # pylint: disable=R0914 """探索を実行する。 Args: board (GoBoard): 現在の局面情報。 @@ -174,7 +174,7 @@ def search(self, board: GoBoard, color: Stone, time_manager: TimeManager, \ sys.stdout.flush() - def search_with_callback(self, board: GoBoard, color: Stone, callback: Callable[[Tuple[int, int]], bool]) -> NoReturn: + def search_with_callback(self, board: GoBoard, color: Stone, callback: Callable[[List[Tuple[int, int]]], bool]) -> None: """探索を実行し、探索系列をコールバック関数へ渡す動作をくり返す。 コールバック関数の戻り値が真になれば終了する。 Args: @@ -187,7 +187,7 @@ def search_with_callback(self, board: GoBoard, color: Stone, callback: Callable[ self._initialize_search(board, color) search_board = copy.deepcopy(board) while True: - path = [] + path: List[Tuple[int, int]] = [] copy_board(dst=search_board, src=board) self.search_mcts(search_board, color, self.current_root, path) finished = callback(path) @@ -197,7 +197,7 @@ def search_with_callback(self, board: GoBoard, color: Stone, callback: Callable[ def search_mcts(self, board: GoBoard, color: Stone, current_index: int, \ - path: List[Tuple[int, int]]) -> NoReturn: + path: List[Tuple[int, int]]) -> None: """モンテカルロ木探索を実行する。 Args: @@ -244,7 +244,7 @@ def search_mcts(self, board: GoBoard, color: Stone, current_index: int, \ self.search_mcts(board, color, next_node_index, path) - def expand_node(self, board: GoBoard, color: Stone) -> NoReturn: + def expand_node(self, board: GoBoard, color: Stone) -> int: """ノードを展開する。 Args: @@ -357,7 +357,7 @@ def generate_move_with_sequential_halving(self, board: GoBoard, color: Stone, \ def search_by_sequential_halving(self, board: GoBoard, color: Stone, \ - threshold: int) -> NoReturn: + threshold: int) -> None: """指定された探索回数だけSequential Halving探索を実行する。 Args: @@ -385,7 +385,7 @@ def search_by_sequential_halving(self, board: GoBoard, color: Stone, \ def search_sequential_halving(self, board: GoBoard, color: Stone, current_index: int, \ - path: List[Tuple[int, int]], count_threshold: int) -> NoReturn: # pylint: disable=R0913 + path: List[Tuple[int, int]], count_threshold: int) -> None: # pylint: disable=R0913 """Sequential Halving探索を実行する。 Args: @@ -438,7 +438,7 @@ def get_pv_lists(self, root: MCTSNode, coord: Coordinate) -> Dict[str, List[str] Returns: Dict[str, List[str]]: 各手の最善応手系列を記録した辞書。 """ - pv_dict = {} + pv_dict: Dict[str, List[str]] = {} for i in range(root.num_children): if root.children_visits[i] > 0: @@ -448,7 +448,7 @@ def get_pv_lists(self, root: MCTSNode, coord: Coordinate) -> Dict[str, List[str] return pv_dict - def get_best_move_sequence(self, pv_list: List[str], index: int) -> List[str]: + def get_best_move_sequence(self, pv_list: List[int], index: int) -> List[int]: """最善応手系列を取得する。 Args: diff --git a/nn/data_generator.py b/nn/data_generator.py index 823b17b..e659e0a 100644 --- a/nn/data_generator.py +++ b/nn/data_generator.py @@ -3,7 +3,7 @@ import glob import os import random -from typing import List, NoReturn +from typing import List import numpy as np from board.go_board import GoBoard from board.stone import Stone @@ -13,8 +13,8 @@ from learning_param import BATCH_SIZE, DATA_SET_SIZE -def _save_data(save_file_path: str, input_data: np.ndarray, policy_data: np.ndarray,\ - value_data: np.ndarray, kifu_counter: int) -> NoReturn: +def _save_data(save_file_path: str, input_data: List[np.ndarray], policy_data: List[np.ndarray],\ + value_data: List[int], kifu_counter: int) -> None: """学習データをnpzファイルとして出力する。 Args: @@ -34,7 +34,7 @@ def _save_data(save_file_path: str, input_data: np.ndarray, policy_data: np.ndar # pylint: disable=R0914 def generate_supervised_learning_data(program_dir: str, kifu_dir: str, \ - board_size: int=9) -> NoReturn: + board_size: int=9) -> None: """教師あり学習のデータを生成して保存する。 Args: @@ -87,7 +87,7 @@ def generate_supervised_learning_data(program_dir: str, kifu_dir: str, \ def generate_reinforcement_learning_data(program_dir: str, kifu_dir_list: List[str], \ - board_size: int=9) -> NoReturn: + board_size: int=9) -> None: """強化学習で使用するデータを生成し、保存する。 Args: diff --git a/nn/feature.py b/nn/feature.py index 7a8706c..ce2fb00 100644 --- a/nn/feature.py +++ b/nn/feature.py @@ -96,7 +96,7 @@ def generate_rl_target_data(board: GoBoard, improved_policy_data: str, sym: int= coord = board.coordinate.convert_from_gtp_format(pos) target_data[coord] = float(target) - target = [target_data[board.get_symmetrical_coordinate(pos, sym)] for pos in board.onboard_pos] - target.append(target_data[PASS]) + target_list = [target_data[board.get_symmetrical_coordinate(pos, sym)] for pos in board.onboard_pos] + target_list.append(target_data[PASS]) - return np.array(target) + return np.array(target_list) diff --git a/nn/learn.py b/nn/learn.py index a77ebb4..5db6ed9 100644 --- a/nn/learn.py +++ b/nn/learn.py @@ -1,6 +1,5 @@ """深層学習の実装。 """ -from typing import NoReturn import glob import os import time @@ -19,7 +18,7 @@ def train_on_cpu(program_dir: str, board_size: int, batch_size: \ - int, epochs: int) -> NoReturn: # pylint: disable=R0914,R0915 + int, epochs: int) -> None: # pylint: disable=R0914,R0915 """教師あり学習を実行し、学習したモデルを保存する。 Args: @@ -124,7 +123,7 @@ def train_on_cpu(program_dir: str, board_size: int, batch_size: \ def train_on_gpu(program_dir: str, board_size: int, batch_size: int, \ - epochs: int) -> NoReturn: # pylint: disable=R0914,R0915 + epochs: int) -> None: # pylint: disable=R0914,R0915 """教師あり学習を実行し、学習したモデルを保存する。 Args: @@ -232,7 +231,7 @@ def train_on_gpu(program_dir: str, board_size: int, batch_size: int, \ def train_with_gumbel_alphazero_on_cpu(program_dir: str, board_size: int, \ - batch_size: int) -> NoReturn: # pylint: disable=R0914,R0915 + batch_size: int) -> None: # pylint: disable=R0914,R0915 """教師あり学習を実行し、学習したモデルを保存する。CPUで実行。 Args: @@ -316,7 +315,7 @@ def train_with_gumbel_alphazero_on_cpu(program_dir: str, board_size: int, \ def train_with_gumbel_alphazero_on_gpu(program_dir: str, board_size: int, \ - batch_size: int) -> NoReturn: # pylint: disable=R0914,R0915 + batch_size: int) -> None: # pylint: disable=R0914,R0915 """教師あり学習を実行し、学習したモデルを保存する。GPUで実行。 Args: diff --git a/nn/policy_player.py b/nn/policy_player.py index 37594ce..a61d574 100644 --- a/nn/policy_player.py +++ b/nn/policy_player.py @@ -1,6 +1,7 @@ """Policy Networkのみを使用した着手生成処理 """ import random +from typing import Any, List import torch @@ -29,7 +30,7 @@ def generate_move_from_policy(network: DualNet, board: GoBoard, color: Stone) -> policy = policy[0].numpy().tolist() # 合法手のみ候補手としてピックアップ - candidates = [{"pos": pos, "policy": policy[i]} \ + candidates: List[Any] = [{"pos": pos, "policy": policy[i]} \ for i, pos in enumerate(board.onboard_pos) if board.is_legal(pos, color)] # パスは候補手確定 diff --git a/nn/utility.py b/nn/utility.py index 00a6af1..ebc5266 100644 --- a/nn/utility.py +++ b/nn/utility.py @@ -1,6 +1,6 @@ """深層学習に関するユーティリティ。 """ -from typing import NoReturn, Dict, List, Tuple +from typing import Dict, List, Protocol, Tuple import time import torch import numpy as np @@ -41,7 +41,7 @@ def _calculate_losses(loss: Dict[str, float], iteration: int) \ def print_learning_process(loss_data: Dict[str, float], epoch: int, index: int, \ - iteration: int, start_time: float) -> NoReturn: + iteration: int, start_time: float) -> None: """学習経過情報を表示する。 Args: @@ -60,7 +60,7 @@ def print_learning_process(loss_data: Dict[str, float], epoch: int, index: int, def print_evaluation_information(loss_data: Dict[str, float], epoch: int, \ - iteration: int, start_time: float) -> NoReturn: + iteration: int, start_time: float) -> None: """テストデータの評価情報を表示する。 Args: @@ -77,7 +77,7 @@ def print_evaluation_information(loss_data: Dict[str, float], epoch: int, \ print_err(f"\tvalue loss : {value_loss:6f}") -def save_model(network: torch.nn.Module, path: str) -> NoReturn: +def save_model(network: torch.nn.Module, path: str) -> None: """ニューラルネットワークのパラメータを保存する。 Args: @@ -122,14 +122,14 @@ def split_train_test_set(file_list: List[str], train_data_ratio: float) \ return train_data_set, test_data_set -def apply_softmax(logits: np.array) -> np.array: +def apply_softmax(logits: np.ndarray) -> np.ndarray: """Softmax関数を適用する。 Args: - logits (np.array): Softmax関数の入力値。 + logits (np.ndarray): Softmax関数の入力値。 Returns: - np.array: Softmax関数適用後の値。 + np.ndarray: Softmax関数適用後の値。 """ shift_exp = np.exp(logits - np.max(logits)) diff --git a/selfplay/worker.py b/selfplay/worker.py index ac18750..ce2de07 100644 --- a/selfplay/worker.py +++ b/selfplay/worker.py @@ -19,7 +19,7 @@ # pylint: disable=R0913,R0914 def selfplay_worker(save_dir: str, model_file_path: str, index_list: List[int], \ - size: int, visits: int, use_gpu: bool) -> NoReturn: + size: int, visits: int, use_gpu: bool) -> None: """自己対戦実行ワーカ。 Args: diff --git a/sgf/reader.py b/sgf/reader.py index f5481cc..5c42c21 100644 --- a/sgf/reader.py +++ b/sgf/reader.py @@ -1,6 +1,6 @@ """SGF形式のファイル読み込み処理。 """ -from typing import NoReturn +from typing import Generator, List, Optional, Tuple from board.coordinate import Coordinate from board.constant import PASS, OB_SIZE from board.stone import Stone @@ -33,6 +33,12 @@ class SGFReader: # pylint: disable=R0902 """SGFファイル読み込み。 """ + move: List[Tuple[int, int, Stone]] + event: Optional[str] + black_player_name: Optional[str] + white_player_name: Optional[str] + application: Optional[str] + copyright: Optional[str] def __init__(self, filename_or_text: str, board_size: int, literal: bool=False): # pylint: disable=R0912 """コンストラクタ @@ -43,7 +49,7 @@ def __init__(self, filename_or_text: str, board_size: int, literal: bool=False): """ self.board_size = board_size self.board_size_with_ob = board_size + OB_SIZE * 2 - self.move = [0] * board_size * board_size * 3 + self.move = [(0, 0, Stone.EMPTY)] * board_size * board_size * 3 self.komi = 7.0 self.result = MatchResult.DRAW self.comment = [""] * board_size * board_size * 3 @@ -289,7 +295,7 @@ def _get_move(self, sgf_text: str, cursor: int, color: Stone) -> int: return cursor + tmp_cursor - def get_moves(self) -> int: + def get_moves(self) -> Generator[int, None, None]: """最初から1つずつ着手を取得する。 Yields: @@ -368,7 +374,7 @@ def get_comment(self, index: int) -> str: """ return self.comment[index] - def display(self) -> NoReturn: + def display(self) -> None: """読み込んだSGFファイルの情報を表示する。(デバッグ用) """ message = "" diff --git a/sgf/selfplay_record.py b/sgf/selfplay_record.py index 014600d..a0acb1a 100644 --- a/sgf/selfplay_record.py +++ b/sgf/selfplay_record.py @@ -23,18 +23,18 @@ def __init__(self, save_dir: str, coord: Coordinate): """ self.record_moves = 0 self.color = [Stone.EMPTY] * MAX_RECORDS - self.pos = [0] * MAX_RECORDS + self.pos = ["0"] * MAX_RECORDS self.coord = coord self.policy_target = [""] * MAX_RECORDS self.save_dir = save_dir self.file_index = 1 - def clear(self) -> NoReturn: + def clear(self) -> None: """レコードの初期化。 """ self.record_moves = 0 - def set_index(self, index: int) -> NoReturn: + def set_index(self, index: int) -> None: """ファイルのインデックスを設定する。 Args: @@ -42,7 +42,7 @@ def set_index(self, index: int) -> NoReturn: """ self.file_index = index - def save_record(self, root: MCTSNode, pos: int, color: Stone) -> NoReturn: + def save_record(self, root: MCTSNode, pos: int, color: Stone) -> None: """着手とImproved Policyを記録する。 Args: @@ -57,14 +57,14 @@ def save_record(self, root: MCTSNode, pos: int, color: Stone) -> NoReturn: policy_target = f"{root.get_num_children()}" for i in range(root.get_num_children()): - pos = self.coord.convert_to_gtp_format(root.get_child_move(i)) - policy_target += f" {pos}:{improved_policy[i]:.3e}" + p = self.coord.convert_to_gtp_format(root.get_child_move(i)) + policy_target += f" {p}:{improved_policy[i]:.3e}" self.policy_target[self.record_moves] = policy_target self.record_moves += 1 - def write_record(self, winner: Stone, komi: float, is_resign: bool, score: float) -> NoReturn: + def write_record(self, winner: Stone, komi: float, is_resign: bool, score: float) -> None: """自己対戦のファイルを出力する。 Args: From 42cf46fbd259cf5f6b2574517ba729a181c04de4 Mon Sep 17 00:00:00 2001 From: Kensuke Matsuzaki Date: Sun, 26 Jan 2025 23:41:08 +0900 Subject: [PATCH 5/6] Fix mypy errors --- nn/utility.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/nn/utility.py b/nn/utility.py index f1b4684..42bdfea 100644 --- a/nn/utility.py +++ b/nn/utility.py @@ -10,6 +10,14 @@ from nn.network.dual_net import DualNet +class GoNet(Protocol): + def inference(self, input_plane: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + pass + + def inference_with_policy_logits(self, input_plane: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + pass + + def get_torch_device(use_gpu: bool) -> torch.device: """torch.deviceを取得する。 From d8a824f092412f57716d8b4732d02e7bfca00f37 Mon Sep 17 00:00:00 2001 From: Kensuke Matsuzaki Date: Tue, 28 Jan 2025 22:27:18 +0900 Subject: [PATCH 6/6] Fix missing implementation for ONNX Runtime support --- export_onnx.py | 61 +++++++++++++++++++++++++++--------------- gtp/gogui.py | 6 ++--- mcts/tree.py | 2 +- nn/network/dual_net.py | 6 ++--- nn/policy_player.py | 8 +++--- nn/utility.py | 7 ++--- selfplay/worker.py | 2 ++ 7 files changed, 55 insertions(+), 37 deletions(-) diff --git a/export_onnx.py b/export_onnx.py index 4af4b40..9dbd32f 100644 --- a/export_onnx.py +++ b/export_onnx.py @@ -1,27 +1,46 @@ +import click import torch +from board.constant import BOARD_SIZE from nn.utility import load_network +from nn.network.dual_net import DualNet -net = load_network("model/sl-model.bin", False) +@click.command() +@click.option('--size', type=click.IntRange(2, BOARD_SIZE), default=BOARD_SIZE, + help=f"碁盤のサイズを指定。デフォルトは{BOARD_SIZE}。") +@click.option('--model-path', type=click.STRING, + help=f"使用するPyTorchモデルのパスを指定する。プログラムのホームディレクトリの相対パスで指定。") +@click.option('--output-path', type=click.STRING, + help=f"保存するONNXモデルのパスを指定する。プログラムのホームディレクトリの相対パスで指定。") +def convert_to_onnx(model_path: str, output_path: str, size: int) -> None: + """Convert to ONNX format. + """ + net = load_network(model_path, False) + if not isinstance(net, torch.nn.Module): + print("Model is not instance of torch.nn.Module.") + return -board_size = 9 -input_tensor = torch.rand((1, 6, board_size, board_size), dtype=torch.float32) + input_tensor = torch.rand((1, 6, size, size), dtype=torch.float32) -torch.onnx.export( - net, - (input_tensor,), - "sl-model.onnx", - input_names=["input"], - output_names=["policy", "value",], - dynamic_axes={ - "input": { - 0: "batch", - }, - "policy": { - 0: "batch", - }, - "value": { - 0: "batch", - }, - } -) + torch.onnx.export( + net, + (input_tensor,), + output_path, + input_names=["input"], + output_names=["policy", "value",], + dynamic_axes={ + "input": { + 0: "batch", + }, + "policy": { + 0: "batch", + }, + "value": { + 0: "batch", + }, + } + ) + + +if __name__ == "__main__": + convert_to_onnx() # pylint: disable=E1120 diff --git a/gtp/gogui.py b/gtp/gogui.py index 31a1520..fd586ec 100644 --- a/gtp/gogui.py +++ b/gtp/gogui.py @@ -7,7 +7,7 @@ from board.go_board import GoBoard from board.stone import Stone from nn.feature import generate_input_planes -from nn.network.dual_net import DualNet +from nn.utility import DualNet class GoguiAnalyzeCommand: # pylint: disable=R0903 """Gogui解析コマンドの基本情報クラス。 @@ -46,7 +46,7 @@ def display_policy_distribution(model: DualNet, board: GoBoard, color: Stone) -> """ board_size = board.get_board_size() input_plane_data = generate_input_planes(board, color) - input_plane = torch.tensor(input_plane_data.reshape(1, 6, board_size, board_size)) #pylint: disable=E1121 + input_plane = input_plane_data.reshape(1, 6, board_size, board_size) #pylint: disable=E1121 policy, _ = model.inference(input_plane) max_policy, min_policy = 0.0, 1.0 @@ -87,7 +87,7 @@ def display_policy_score(model: DualNet, board: GoBoard, color: Stone) -> str: """ board_size = board.get_board_size() input_plane_data = generate_input_planes(board, color) - input_plane = torch.tensor(input_plane_data.reshape(1, 6, board_size, board_size)) #pylint: disable=E1121 + input_plane = input_plane_data.reshape(1, 6, board_size, board_size) #pylint: disable=E1121 policy_predict, _ = model.inference(input_plane) policies = [policy_predict[0][i] for i in range(board_size ** 2)] response = "" diff --git a/mcts/tree.py b/mcts/tree.py index 21ec21c..43fb111 100644 --- a/mcts/tree.py +++ b/mcts/tree.py @@ -13,7 +13,7 @@ from board.stone import Stone from common.print_console import print_err from nn.feature import generate_input_planes -from nn.network.dual_net import DualNet +from nn.utility import DualNet from mcts.batch_data import BatchQueue from mcts.constant import NOT_EXPANDED, PLAYOUTS, NN_BATCH_SIZE, \ MAX_CONSIDERED_NODES, RESIGN_THRESHOLD, MCTS_TREE_SIZE diff --git a/nn/network/dual_net.py b/nn/network/dual_net.py index 8f551bd..9099d4f 100644 --- a/nn/network/dual_net.py +++ b/nn/network/dual_net.py @@ -88,8 +88,7 @@ def inference(self, input_plane: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: Returns: Tuple[torch.Tensor, torch.Tensor]: Policy, Valueの推論結果。 """ - input_plane = torch.Tensor(input_plane) - policy, value = self.forward(input_plane.to(self.device)) + policy, value = self.forward(torch.Tensor(input_plane).to(self.device)) return self.softmax(policy).detach().cpu().numpy(), self.softmax(value).detach().cpu().numpy() @@ -104,8 +103,7 @@ def inference_with_policy_logits(self, input_plane: np.ndarray) \ Returns: Tuple[torch.Tensor, torch.Tensor]: Policy, Valueの推論結果。 """ - input_plane = torch.Tensor(input_plane) - policy, value = self.forward(input_plane.to(self.device)) + policy, value = self.forward(torch.Tensor(input_plane).to(self.device)) return policy.detach().cpu().numpy(), self.softmax(value).detach().cpu().numpy() diff --git a/nn/policy_player.py b/nn/policy_player.py index a61d574..7e848af 100644 --- a/nn/policy_player.py +++ b/nn/policy_player.py @@ -3,13 +3,11 @@ import random from typing import Any, List -import torch - from board.constant import PASS from board.go_board import GoBoard from board.stone import Stone from nn.feature import generate_input_planes -from nn.network.dual_net import DualNet +from nn.utility import DualNet def generate_move_from_policy(network: DualNet, board: GoBoard, color: Stone) -> int: """Policy Networkを使用して着手を生成する。 @@ -24,10 +22,10 @@ def generate_move_from_policy(network: DualNet, board: GoBoard, color: Stone) -> """ board_size = board.get_board_size() input_plane = generate_input_planes(board, color) - input_data = torch.tensor(input_plane.reshape(1, 6, board_size, board_size)) #pylint: disable=E1121 + input_data = input_plane.reshape(1, 6, board_size, board_size) #pylint: disable=E1121 policy, _ = network.inference(input_data) - policy = policy[0].numpy().tolist() + policy = policy[0].tolist() # 合法手のみ候補手としてピックアップ candidates: List[Any] = [{"pos": pos, "policy": policy[i]} \ diff --git a/nn/utility.py b/nn/utility.py index 42bdfea..9e1db65 100644 --- a/nn/utility.py +++ b/nn/utility.py @@ -7,10 +7,10 @@ import onnxruntime as ort from common.print_console import print_err -from nn.network.dual_net import DualNet +from nn.network.dual_net import DualNet as TorchDualNet -class GoNet(Protocol): +class DualNet(Protocol): def inference(self, input_plane: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: pass @@ -158,7 +158,7 @@ def load_network(model_file_path: str, use_gpu: bool) -> DualNet: if model_file_path.endswith(".onnx"): return OrtWrapper(model_file_path, use_gpu) device = get_torch_device(use_gpu=use_gpu) - network = DualNet(device) + network = TorchDualNet(device) network.to(device) try: network.load_state_dict(torch.load(model_file_path)) @@ -169,6 +169,7 @@ def load_network(model_file_path: str, use_gpu: bool) -> DualNet: return network + class OrtWrapper: def __init__(self, model_file_path: str, use_gpu: bool): providers = [] diff --git a/selfplay/worker.py b/selfplay/worker.py index ce2de07..5d6a2c3 100644 --- a/selfplay/worker.py +++ b/selfplay/worker.py @@ -15,6 +15,7 @@ from mcts.tree import MCTSTree from mcts.time_manager import TimeManager, TimeControl from nn.utility import load_network +from nn.network.dual_net import DualNet from learning_param import SELF_PLAY_VISITS # pylint: disable=R0913,R0914 @@ -34,6 +35,7 @@ def selfplay_worker(save_dir: str, model_file_path: str, index_list: List[int], init_board = GoBoard(board_size=size, komi=7.0, check_superko=True) record = SelfPlayRecord(save_dir, board.coordinate) network = load_network(model_file_path=model_file_path, use_gpu=use_gpu) + assert isinstance(network, DualNet) network.training = False np.random.seed(random.choice(index_list))