From d40fa04e4a20dbffcec0be8628b0072feb167a49 Mon Sep 17 00:00:00 2001 From: jacksona Date: Thu, 14 Feb 2019 11:21:21 -0800 Subject: [PATCH 01/16] first cut at training with BT over many blocks --- preprocessing.py | 34 ++++++++++++++++++++++++++++++++++ train.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/preprocessing.py b/preprocessing.py index 4f4b1d079..e4eae7faa 100644 --- a/preprocessing.py +++ b/preprocessing.py @@ -261,6 +261,40 @@ def get_tpu_bt_input_tensors(games, games_nr, batch_size, num_repeats=1, return dataset +def get_many_tpu_bt_many_input_tensors(games, games_nr, batch_size, + start_at, num_datasets, + moves=2**21 + window_size=500e3 + window_increment=25000): + datasets = [] + for i in range(num_datasets): + # TODO(amj) mixin calibration games with some math. (from start_at that + # is proportionally along compared to last_game_number? comparing + # timestamps?) + dataset = games.moves_from_games(start_at + (i * window_increment), + start_at + (i * window_increment) + window_size, + moves=moves, + shuffle=True, + column_family=TFEXAMPLE, + column='example') + dataset = dataset.repeat(1) + dataset = dataset.batch(batch_size) + dataset = dataset.filter(lambda t: tf.equal(tf.shape(t)[0], batch_size)) + dataset = dataset.map( + functools.partial(batch_parse_tf_example, batch_size)) + if random_rotation: + # Unbatch the dataset so we can rotate it + dataset = dataset.apply(tf.contrib.data.unbatch()) + dataset = dataset.apply(tf.contrib.data.map_and_batch( + _random_rotation_pure_tf, + batch_size, + drop_remainder=True)) + + dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE) + datasets.append(dataset) + return datasets + + def make_dataset_from_selfplay(data_extracts): ''' Returns an iterable of tf.Examples. diff --git a/train.py b/train.py index b0f433a6a..ac7d0af5e 100644 --- a/train.py +++ b/train.py @@ -132,6 +132,41 @@ def after_run(self, run_context, run_values): self.before_weights = None +def train_many(start_at=1000000, num_datasets=3): + """ Trains on a set of bt_datasets, skipping eval for now. + (from preprocessing.get_many_tpu_bt_input_tensors) + """ + if not FLAGS.use_tpu and FLAGS.use_bt: + raise ValueError("Only tpu & bt mode supported") + + tf.logging.set_verbosity(tf.logging.INFO) + estimator = dual_net.get_estimator() + effective_batch_size = FLAGS.train_batch_size * FLAGS.num_tpu_cores + + def _input_fn(params): + games = bigtable_input.GameQueue( + FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table) + games_nr = bigtable_input.GameQueue( + FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table + '-nr') + + datasets = preprocessing.get_many_tpu_bt_input_tensors( + games, games_nr, params['batch_size'], + start_at=start_at num_datasets=num_datasets) + + d = datasets[0] + for d_next in datasets[1:]: + d.concatenate(d_next) + return d + hooks = [] + + steps = num_datasets * (2**21) + logging.info("Training, steps = %s, batch = %s -> %s examples", + steps or '?', effective_batch_size, + (steps * effective_batch_size) if steps else '?') + + estimator.train(_input_fn, steps=steps, hooks=hooks) + + def train(*tf_records: "Records to train on"): """Train on examples.""" tf.logging.set_verbosity(tf.logging.INFO) From 32fdca9585de77d2f4a31ff0520264d8c501376c Mon Sep 17 00:00:00 2001 From: jacksona Date: Thu, 14 Feb 2019 11:26:19 -0800 Subject: [PATCH 02/16] bad at math --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index ac7d0af5e..c7989b5bc 100644 --- a/train.py +++ b/train.py @@ -159,7 +159,7 @@ def _input_fn(params): return d hooks = [] - steps = num_datasets * (2**21) + steps = num_datasets * FLAGS.steps_to_train logging.info("Training, steps = %s, batch = %s -> %s examples", steps or '?', effective_batch_size, (steps * effective_batch_size) if steps else '?') From 39807ee4b4ce4d787208990bf2d9916fb02aa015 Mon Sep 17 00:00:00 2001 From: jacksona Date: Thu, 14 Feb 2019 13:42:08 -0800 Subject: [PATCH 03/16] try it this way, concat inside get_many --- preprocessing.py | 44 ++++++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/preprocessing.py b/preprocessing.py index e4eae7faa..403186979 100644 --- a/preprocessing.py +++ b/preprocessing.py @@ -263,36 +263,40 @@ def get_tpu_bt_input_tensors(games, games_nr, batch_size, num_repeats=1, def get_many_tpu_bt_many_input_tensors(games, games_nr, batch_size, start_at, num_datasets, - moves=2**21 - window_size=500e3 + moves=2**21, + window_size=500e3, window_increment=25000): - datasets = [] + dataset = None for i in range(num_datasets): # TODO(amj) mixin calibration games with some math. (from start_at that # is proportionally along compared to last_game_number? comparing # timestamps?) - dataset = games.moves_from_games(start_at + (i * window_increment), + ds = games.moves_from_games(start_at + (i * window_increment), start_at + (i * window_increment) + window_size, moves=moves, shuffle=True, column_family=TFEXAMPLE, column='example') - dataset = dataset.repeat(1) - dataset = dataset.batch(batch_size) - dataset = dataset.filter(lambda t: tf.equal(tf.shape(t)[0], batch_size)) - dataset = dataset.map( - functools.partial(batch_parse_tf_example, batch_size)) - if random_rotation: - # Unbatch the dataset so we can rotate it - dataset = dataset.apply(tf.contrib.data.unbatch()) - dataset = dataset.apply(tf.contrib.data.map_and_batch( - _random_rotation_pure_tf, - batch_size, - drop_remainder=True)) - - dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE) - datasets.append(dataset) - return datasets + ds = dataset.repeat(1) + if datasets: + dataset = dataset.concatenate(ds) + else: + dataset = ds + + dataset = dataset.batch(batch_size) + dataset = dataset.filter(lambda t: tf.equal(tf.shape(t)[0], batch_size)) + dataset = dataset.map( + functools.partial(batch_parse_tf_example, batch_size)) + if random_rotation: + # Unbatch the dataset so we can rotate it + dataset = dataset.apply(tf.contrib.data.unbatch()) + dataset = dataset.apply(tf.contrib.data.map_and_batch( + _random_rotation_pure_tf, + batch_size, + drop_remainder=True)) + + dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE) + return dataset def make_dataset_from_selfplay(data_extracts): From 3a3f2a9be38c051ea529b0d9b8f1b4396d8adb4c Mon Sep 17 00:00:00 2001 From: jacksona Date: Thu, 14 Feb 2019 13:42:20 -0800 Subject: [PATCH 04/16] syntax --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index c7989b5bc..0bde2e4ea 100644 --- a/train.py +++ b/train.py @@ -151,11 +151,11 @@ def _input_fn(params): datasets = preprocessing.get_many_tpu_bt_input_tensors( games, games_nr, params['batch_size'], - start_at=start_at num_datasets=num_datasets) + start_at=start_at, num_datasets=num_datasets) d = datasets[0] for d_next in datasets[1:]: - d.concatenate(d_next) + d = d.concatenate(d_next) return d hooks = [] From 145a6fc7aaa37331f90fbcef811ce910e91293dc Mon Sep 17 00:00:00 2001 From: jacksona Date: Thu, 14 Feb 2019 13:42:47 -0800 Subject: [PATCH 05/16] typo --- preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/preprocessing.py b/preprocessing.py index 403186979..9fac87f95 100644 --- a/preprocessing.py +++ b/preprocessing.py @@ -261,7 +261,7 @@ def get_tpu_bt_input_tensors(games, games_nr, batch_size, num_repeats=1, return dataset -def get_many_tpu_bt_many_input_tensors(games, games_nr, batch_size, +def get_many_tpu_bt_input_tensors(games, games_nr, batch_size, start_at, num_datasets, moves=2**21, window_size=500e3, From 68f3e792ebc25312db161841734203de2c3feff0 Mon Sep 17 00:00:00 2001 From: jacksona Date: Thu, 14 Feb 2019 13:43:55 -0800 Subject: [PATCH 06/16] another typo --- preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/preprocessing.py b/preprocessing.py index 9fac87f95..ffd20fdfe 100644 --- a/preprocessing.py +++ b/preprocessing.py @@ -275,7 +275,7 @@ def get_many_tpu_bt_input_tensors(games, games_nr, batch_size, start_at + (i * window_increment) + window_size, moves=moves, shuffle=True, - column_family=TFEXAMPLE, + column_family=bigtable_input.TFEXAMPLE, column='example') ds = dataset.repeat(1) if datasets: From 73891969c63de352aa8094863a169efebea7f7aa Mon Sep 17 00:00:00 2001 From: jacksona Date: Thu, 14 Feb 2019 13:47:59 -0800 Subject: [PATCH 07/16] help i cant type --- preprocessing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/preprocessing.py b/preprocessing.py index ffd20fdfe..eb6a0b396 100644 --- a/preprocessing.py +++ b/preprocessing.py @@ -277,8 +277,8 @@ def get_many_tpu_bt_input_tensors(games, games_nr, batch_size, shuffle=True, column_family=bigtable_input.TFEXAMPLE, column='example') - ds = dataset.repeat(1) - if datasets: + ds = ds.repeat(1) + if dataset: dataset = dataset.concatenate(ds) else: dataset = ds From b68bc5ecd16d588c8322861e990c7f346ce68a92 Mon Sep 17 00:00:00 2001 From: jacksona Date: Thu, 14 Feb 2019 14:06:18 -0800 Subject: [PATCH 08/16] move batching inside the loop --- preprocessing.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/preprocessing.py b/preprocessing.py index eb6a0b396..ee2ea423d 100644 --- a/preprocessing.py +++ b/preprocessing.py @@ -278,12 +278,9 @@ def get_many_tpu_bt_input_tensors(games, games_nr, batch_size, column_family=bigtable_input.TFEXAMPLE, column='example') ds = ds.repeat(1) - if dataset: - dataset = dataset.concatenate(ds) - else: - dataset = ds + ds = ds.batch(batch_size) + dataset = dataset.concatenate(ds) if dataset else ds - dataset = dataset.batch(batch_size) dataset = dataset.filter(lambda t: tf.equal(tf.shape(t)[0], batch_size)) dataset = dataset.map( functools.partial(batch_parse_tf_example, batch_size)) From 936638a1cbfa5fd3cec015c4e03ea345bef7db75 Mon Sep 17 00:00:00 2001 From: jacksona Date: Thu, 14 Feb 2019 16:16:55 -0800 Subject: [PATCH 09/16] collapse from key,data to just data. make rotation always on --- preprocessing.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/preprocessing.py b/preprocessing.py index ee2ea423d..62adcd21e 100644 --- a/preprocessing.py +++ b/preprocessing.py @@ -278,19 +278,18 @@ def get_many_tpu_bt_input_tensors(games, games_nr, batch_size, column_family=bigtable_input.TFEXAMPLE, column='example') ds = ds.repeat(1) - ds = ds.batch(batch_size) + ds = ds.map(lambda row_name, s: s) dataset = dataset.concatenate(ds) if dataset else ds - dataset = dataset.filter(lambda t: tf.equal(tf.shape(t)[0], batch_size)) + dataset = dataset.batch(batch_size,drop_remainder=False) dataset = dataset.map( functools.partial(batch_parse_tf_example, batch_size)) - if random_rotation: - # Unbatch the dataset so we can rotate it - dataset = dataset.apply(tf.contrib.data.unbatch()) - dataset = dataset.apply(tf.contrib.data.map_and_batch( - _random_rotation_pure_tf, - batch_size, - drop_remainder=True)) + # Unbatch the dataset so we can rotate it + dataset = dataset.apply(tf.contrib.data.unbatch()) + dataset = dataset.apply(tf.contrib.data.map_and_batch( + _random_rotation_pure_tf, + batch_size, + drop_remainder=True)) dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE) return dataset From 9091ccad68ec3c15595d57041629e1459ee35a71 Mon Sep 17 00:00:00 2001 From: jacksona Date: Thu, 14 Feb 2019 16:17:10 -0800 Subject: [PATCH 10/16] fix double concat. --- train.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/train.py b/train.py index 0bde2e4ea..6d7fc366b 100644 --- a/train.py +++ b/train.py @@ -149,16 +149,11 @@ def _input_fn(params): games_nr = bigtable_input.GameQueue( FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table + '-nr') - datasets = preprocessing.get_many_tpu_bt_input_tensors( + return preprocessing.get_many_tpu_bt_input_tensors( games, games_nr, params['batch_size'], start_at=start_at, num_datasets=num_datasets) - d = datasets[0] - for d_next in datasets[1:]: - d = d.concatenate(d_next) - return d hooks = [] - steps = num_datasets * FLAGS.steps_to_train logging.info("Training, steps = %s, batch = %s -> %s examples", steps or '?', effective_batch_size, From 0cab4e938810e85ba8806e116c200a506b390a7a Mon Sep 17 00:00:00 2001 From: jacksona Date: Fri, 15 Feb 2019 11:17:15 -0800 Subject: [PATCH 11/16] PR comments. --- preprocessing.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/preprocessing.py b/preprocessing.py index 62adcd21e..20dbce91f 100644 --- a/preprocessing.py +++ b/preprocessing.py @@ -262,25 +262,25 @@ def get_tpu_bt_input_tensors(games, games_nr, batch_size, num_repeats=1, def get_many_tpu_bt_input_tensors(games, games_nr, batch_size, - start_at, num_datasets, - moves=2**21, - window_size=500e3, - window_increment=25000): + start_at, num_datasets, + moves=2**21, + window_size=500e3, + window_increment=25000): dataset = None for i in range(num_datasets): # TODO(amj) mixin calibration games with some math. (from start_at that # is proportionally along compared to last_game_number? comparing # timestamps?) ds = games.moves_from_games(start_at + (i * window_increment), - start_at + (i * window_increment) + window_size, - moves=moves, - shuffle=True, - column_family=bigtable_input.TFEXAMPLE, - column='example') - ds = ds.repeat(1) - ds = ds.map(lambda row_name, s: s) + start_at + (i * window_increment) + window_size, + moves=moves, + shuffle=True, + column_family=bigtable_input.TFEXAMPLE, + column='example') dataset = dataset.concatenate(ds) if dataset else ds + dataset = dataset.repeat(1) + dataset = dataset.map(lambda row_name, s: s) dataset = dataset.batch(batch_size,drop_remainder=False) dataset = dataset.map( functools.partial(batch_parse_tf_example, batch_size)) From fc0e3a0286b8d61963a2a2a063399a99556fca6a Mon Sep 17 00:00:00 2001 From: jacksona Date: Wed, 20 Mar 2019 12:27:33 -0700 Subject: [PATCH 12/16] extract to flags --- train.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 4cce99b5e..ae982bb69 100644 --- a/train.py +++ b/train.py @@ -54,6 +54,15 @@ flags.DEFINE_bool('freeze', False, 'Whether to freeze the graph at the end of training.') +flags.DEFINE_bool('train_many', False, + 'Whether to run train repeatedly, automatically incrementing the window') + +flags.DEFINE_integer('window_start_at', 10000000, + 'The game number to start the window at (when used with `many`)') + +flags.DEFINE_integer('num_repeats', 3, + 'Used with `many`. The number of times to increment the window and re-train.') + flags.register_multi_flags_validator( ['use_bt', 'use_tpu'], @@ -242,8 +251,12 @@ def main(argv): tf_records = argv[1:] logging.info("Training on %s records: %s to %s", len(tf_records), tf_records[0], tf_records[-1]) - with utils.logged_timer("Training"): - train(*tf_records) + if FLAGS.train_many: + with utils.logged_timer("Training"): + train_many(FLAGS.start_at, FLAGS.num_datasets) + else: + with utils.logged_timer("Training"): + train(*tf_records) if FLAGS.export_path: dual_net.export_model(FLAGS.export_path) if FLAGS.freeze: From 8518517c7b1d5e932d02432e4fd4d91894fd70a0 Mon Sep 17 00:00:00 2001 From: jacksona Date: Wed, 20 Mar 2019 12:49:59 -0700 Subject: [PATCH 13/16] add moar params --- train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index ae982bb69..26df2483a 100644 --- a/train.py +++ b/train.py @@ -148,7 +148,7 @@ def after_run(self, run_context, run_values): self.before_weights = None -def train_many(start_at=1000000, num_datasets=3): +def train_many(start_at=1000000, num_datasets=3, moves=2**24): """ Trains on a set of bt_datasets, skipping eval for now. (from preprocessing.get_many_tpu_bt_input_tensors) """ @@ -167,6 +167,8 @@ def _input_fn(params): return preprocessing.get_many_tpu_bt_input_tensors( games, games_nr, params['batch_size'], + moves=moves, + window_size=FLAGS.window_size, start_at=start_at, num_datasets=num_datasets) hooks = [] @@ -253,7 +255,7 @@ def main(argv): len(tf_records), tf_records[0], tf_records[-1]) if FLAGS.train_many: with utils.logged_timer("Training"): - train_many(FLAGS.start_at, FLAGS.num_datasets) + train_many(FLAGS.start_at, FLAGS.num_datasets) else: with utils.logged_timer("Training"): train(*tf_records) From b29875ae9340db2a9a5597036813802bfb46dbbe Mon Sep 17 00:00:00 2001 From: jacksona Date: Wed, 20 Mar 2019 12:55:44 -0700 Subject: [PATCH 14/16] fix main code branch --- train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 26df2483a..bdd3b7b53 100644 --- a/train.py +++ b/train.py @@ -172,6 +172,7 @@ def _input_fn(params): start_at=start_at, num_datasets=num_datasets) hooks = [] + steps = num_datasets * FLAGS.steps_to_train logging.info("Training, steps = %s, batch = %s -> %s examples", steps or '?', effective_batch_size, @@ -250,13 +251,13 @@ def _input_fn(): def main(argv): """Train on examples and export the updated model weights.""" - tf_records = argv[1:] - logging.info("Training on %s records: %s to %s", - len(tf_records), tf_records[0], tf_records[-1]) if FLAGS.train_many: with utils.logged_timer("Training"): train_many(FLAGS.start_at, FLAGS.num_datasets) else: + tf_records = argv[1:] + logging.info("Training on %s records: %s to %s", + len(tf_records), tf_records[0], tf_records[-1]) with utils.logged_timer("Training"): train(*tf_records) if FLAGS.export_path: From e90ce649532b0205da7855503680654792b053ba Mon Sep 17 00:00:00 2001 From: jacksona Date: Wed, 20 Mar 2019 12:56:20 -0700 Subject: [PATCH 15/16] lint --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index bdd3b7b53..c6f21f8a1 100644 --- a/train.py +++ b/train.py @@ -253,7 +253,7 @@ def main(argv): """Train on examples and export the updated model weights.""" if FLAGS.train_many: with utils.logged_timer("Training"): - train_many(FLAGS.start_at, FLAGS.num_datasets) + train_many(FLAGS.window_start_at, FLAGS.num_datasets) else: tf_records = argv[1:] logging.info("Training on %s records: %s to %s", From cdfe3918c2db4b5741da044fb36e7fc3614fa0db Mon Sep 17 00:00:00 2001 From: jacksona Date: Tue, 26 Mar 2019 14:20:40 -0700 Subject: [PATCH 16/16] Update flags correctly --- train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index c6f21f8a1..e6300e2f5 100644 --- a/train.py +++ b/train.py @@ -58,10 +58,10 @@ 'Whether to run train repeatedly, automatically incrementing the window') flags.DEFINE_integer('window_start_at', 10000000, - 'The game number to start the window at (when used with `many`)') + 'Used with `train_many`. The game number where the window begins') -flags.DEFINE_integer('num_repeats', 3, - 'Used with `many`. The number of times to increment the window and re-train.') +flags.DEFINE_integer('num_datasets', 3, + 'Used with `train_many`. The number of times to increment the window and re-train.') flags.register_multi_flags_validator(