diff --git a/sanguo_data.py b/sanguo_data.py index cba5600..b7ddf10 100644 --- a/sanguo_data.py +++ b/sanguo_data.py @@ -15,7 +15,7 @@ def __init__(self, source = 'sanguo-utf8.txt', block_size = 192, training_set_ra self.decoder = None self.data = None - def ingest(self, gen_dataset=True): + def ingest(self, gen_dataset=True, gen_token_map=True): with open(self.source, 'r', encoding='utf-8') as f: self.text = f.read() print(f"Length of text: {len(self.text)}") # 606051 Chinese characters @@ -41,6 +41,9 @@ def ingest(self, gen_dataset=True): self.data = torch.tensor(self.encoder(self.text), dtype=torch.long) # print(self.data.shape, self.data.dtype) + if gen_token_map: + self.save_token_map() + if gen_dataset: self.gen_dataset()