
"sliding window" bigtable training mode #713

Open · wants to merge 17 commits into master
preprocessing.py: 34 additions & 0 deletions

@@ -262,6 +262,40 @@ def get_tpu_bt_input_tensors(games, games_nr, batch_size, num_repeats=1,
    return dataset


def get_many_tpu_bt_input_tensors(games, games_nr, batch_size,
                                  start_at, num_datasets,
                                  moves=2**21,
                                  window_size=500e3,
                                  window_increment=25000):
    """Builds `num_datasets` successive sliding-window datasets from Bigtable
    and concatenates them into a single dataset."""
    dataset = None
    for i in range(num_datasets):
        # TODO(amj): mix in calibration games with some math. (From a start_at
        # that is proportionally along compared to last_game_number? Comparing
        # timestamps?)
        ds = games.moves_from_games(start_at + (i * window_increment),
                                    start_at + (i * window_increment) + window_size,
                                    moves=moves,
                                    shuffle=True,
                                    column_family=bigtable_input.TFEXAMPLE,
                                    column='example')
        dataset = dataset.concatenate(ds) if dataset else ds
Contributor:

Regarding the general approach: if the training loop does multiple scans, I would expect to create a new dataset for each pass, rather than try to create a single enormous dataset, which I imagine would be harder to debug, inspect, etc.
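For illustration, a minimal sketch of this per-pass idea, reusing games.moves_from_games and bigtable_input.TFEXAMPLE from this diff; the helper name window_dataset and the surrounding loop variables (games, start_at, window_size, window_increment, num_datasets, moves) are hypothetical stand-ins, not part of the PR:

def window_dataset(games, start, window_size, moves):
    """Hypothetical helper: builds one dataset covering a single window."""
    return games.moves_from_games(start,
                                  start + window_size,
                                  moves=moves,
                                  shuffle=True,
                                  column_family=bigtable_input.TFEXAMPLE,
                                  column='example')

for i in range(num_datasets):
    # One modest dataset per pass, rather than a single concatenated one.
    ds = window_dataset(games, start_at + i * window_increment,
                        window_size, moves)
    # ... run the usual parse/batch/rotate pipeline on `ds` and train.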

Contributor Author:

Yes, but multiple calls to TPUEstimator.train will create new graphs :( I'm not sure what a good solution for lazily evaluating these Datasets would be. As it is, it takes a really long time to build the datasets before training even starts -- I suspect the concatenate is doing something bad, as things get slower and slower.
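
For reference, a sketch of the multiple-train-calls pattern being weighed here, assembled from pieces already in this PR (dual_net.get_estimator, preprocessing.get_many_tpu_bt_input_tensors with num_datasets=1, FLAGS.steps_to_train); games, games_nr, and the window variables are assumed to be in scope, and each estimator.train call rebuilds the graph, which is the cost noted above:

estimator = dual_net.get_estimator()
for i in range(num_datasets):
    window_start = start_at + i * window_increment

    def _input_fn(params, start=window_start):
        # One single-window dataset per training call; the default argument
        # pins this pass's window start.
        return preprocessing.get_many_tpu_bt_input_tensors(
            games, games_nr, params['batch_size'],
            start_at=start, num_datasets=1,
            moves=moves, window_size=window_size)

    # Each call builds a fresh graph -- the overhead mentioned above.
    estimator.train(_input_fn, steps=FLAGS.steps_to_train)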


    dataset = dataset.repeat(1)
    dataset = dataset.map(lambda row_name, s: s)
    dataset = dataset.batch(batch_size, drop_remainder=False)
    dataset = dataset.map(
        functools.partial(batch_parse_tf_example, batch_size))
    # Unbatch the dataset so we can rotate it
    dataset = dataset.apply(tf.contrib.data.unbatch())
    dataset = dataset.apply(tf.contrib.data.map_and_batch(
        _random_rotation_pure_tf,
        batch_size,
        drop_remainder=True))

    dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
    return dataset


def make_dataset_from_selfplay(data_extracts):
    """
    Returns an iterable of tf.Examples.
train.py: 51 additions & 5 deletions

@@ -54,6 +54,15 @@
flags.DEFINE_bool('freeze', False,
                  'Whether to freeze the graph at the end of training.')

flags.DEFINE_bool('train_many', False,
                  'Whether to run train repeatedly, automatically incrementing the window.')

flags.DEFINE_integer('window_start_at', 10000000,
                     'Used with `train_many`. The game number where the window begins.')

flags.DEFINE_integer('num_datasets', 3,
                     'Used with `train_many`. The number of times to increment the window and re-train.')


flags.register_multi_flags_validator(
    ['use_bt', 'use_tpu'],

@@ -139,6 +148,39 @@ def after_run(self, run_context, run_values):
        self.before_weights = None


def train_many(start_at=1000000, num_datasets=3, moves=2**24):
    """Trains on a series of sliding-window Bigtable datasets, skipping eval for now.
    (See preprocessing.get_many_tpu_bt_input_tensors.)
    """
    if not (FLAGS.use_tpu and FLAGS.use_bt):
        raise ValueError("Only tpu & bt mode supported")

    tf.logging.set_verbosity(tf.logging.INFO)
    estimator = dual_net.get_estimator()
    effective_batch_size = FLAGS.train_batch_size * FLAGS.num_tpu_cores

    def _input_fn(params):
        games = bigtable_input.GameQueue(
            FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table)
        games_nr = bigtable_input.GameQueue(
            FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table + '-nr')

        return preprocessing.get_many_tpu_bt_input_tensors(
            games, games_nr, params['batch_size'],
            moves=moves,
            window_size=FLAGS.window_size,
            start_at=start_at, num_datasets=num_datasets)

    hooks = []

    steps = num_datasets * FLAGS.steps_to_train
    logging.info("Training, steps = %s, batch = %s -> %s examples",
                 steps or '?', effective_batch_size,
                 (steps * effective_batch_size) if steps else '?')

    estimator.train(_input_fn, steps=steps, hooks=hooks)

def train(*tf_records: "Records to train on"):
    """Train on examples."""
    tf.logging.set_verbosity(tf.logging.INFO)

@@ -209,11 +251,15 @@ def _input_fn():

 def main(argv):
     """Train on examples and export the updated model weights."""
-    tf_records = argv[1:]
-    logging.info("Training on %s records: %s to %s",
-                 len(tf_records), tf_records[0], tf_records[-1])
-    with utils.logged_timer("Training"):
-        train(*tf_records)
+    if FLAGS.train_many:
+        with utils.logged_timer("Training"):
+            train_many(FLAGS.window_start_at, FLAGS.num_datasets)
+    else:
+        tf_records = argv[1:]
+        logging.info("Training on %s records: %s to %s",
+                     len(tf_records), tf_records[0], tf_records[-1])
+        with utils.logged_timer("Training"):
+            train(*tf_records)
     if FLAGS.export_path:
         dual_net.export_model(FLAGS.export_path)
     if FLAGS.freeze: