
Commit 91988af

pass linter checks
1 parent 73e3ea6 commit 91988af

File tree: 13 files changed (+1091, -1046 lines)

algoperf/checkpoint_utils.py

Lines changed: 38 additions & 35 deletions

@@ -31,49 +31,52 @@

A formatting-only hunk: the linter pass reindents the BoolHandler type handler and adds the blank lines expected around top-level definitions; the removed lines are the same statements at the old indentation. The resulting code:

  int,
]


class BoolHandler(NumpyHandler):
  """
  An implementation of TypeHandler for np.bool_ that inherits from NumpyHandler.
  It works by treating the scalar as a 0-dimensional array.
  """

  def typestr(self) -> str:
    """Unique string identifier for this handler."""
    return 'np.bool_'

  async def serialize(
    self,
    values: Sequence[np.bool_],
    infos: Sequence,
    args: Optional[Sequence[ocp.SaveArgs]] = None,
  ):
    """
    Serializes a sequence of np.bool_ scalars by first converting them
    to 0-dim numpy arrays and then calling the parent NumpyHandler.
    """
    # Convert each scalar np.bool_ to a 0-dimensional np.ndarray
    array_values = [np.asarray(v, dtype=np.bool_) for v in values]
    # Use the parent class's robust serialization logic
    return await super().serialize(array_values, infos, args)

  async def deserialize(
    self,
    infos: Sequence,
    args: Optional[Sequence[ocp.RestoreArgs]] = None,
  ) -> Sequence[np.bool_]:
    """
    Deserializes into a sequence of np.bool_ scalars by calling the
    parent handler and then converting the resulting 0-dim arrays.
    """
    # Parent deserialize will return a sequence of 0-dimensional np.ndarray
    results = await super().deserialize(infos, args)

    # Convert each 0-d array back to an np.bool_ scalar using .item()
    scalar_results = [np.bool_(r.item()) for r in results]
    return scalar_results


ocp.type_handlers.register_type_handler(np.bool_, BoolHandler(), override=True)


def maybe_restore_checkpoint(
  framework: str,
  optimizer_state: spec.OptimizerState,
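
For reference (not part of this commit), a minimal sketch of what the registered handler enables: an np.bool_ leaf in a checkpointed pytree round-trips through Orbax as a scalar rather than failing or coming back as a 0-d array. The state dict, checkpoint path, and the legacy PyTreeCheckpointer API are assumptions.

import numpy as np
import orbax.checkpoint as ocp

# Hypothetical training state containing an np.bool_ leaf.
state = {'step': np.int64(1200), 'test_complete': np.bool_(True)}

checkpointer = ocp.PyTreeCheckpointer()
# save() exercises BoolHandler.serialize(); the target directory must not
# already exist.
checkpointer.save('/tmp/bool_handler_demo', state)
# restore() exercises BoolHandler.deserialize().
restored = checkpointer.restore('/tmp/bool_handler_demo')
print(type(restored['test_complete']))  # expected: <class 'numpy.bool_'>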

algoperf/pytorch_utils.py

Lines changed: 5 additions & 1 deletion

@@ -27,7 +27,9 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]:
   return use_pytorch_ddp, rank, device, n_gpus


-def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler, limit_tf_threads = True) -> None:
+def pytorch_init(
+  use_pytorch_ddp: bool, rank: int, profiler: Profiler, limit_tf_threads=True
+) -> None:
   # Make sure no GPU memory is preallocated to Jax.
   os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
   # Only use CPU for Jax to avoid memory issues.
@@ -47,8 +49,10 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler, limit_tf_
   profiler.set_local_rank(rank)
   # Only log once (for local rank == 0).
   if rank != 0:
+
     def logging_pass(*args):
       pass
+
     logging.info = logging_pass
   # Initialize the process group.
   dist.init_process_group('nccl')
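
A hypothetical call-site sketch for the reformatted signature; the profiler import path is an assumption, not shown in this diff:

from algoperf.profiler import PassThroughProfiler  # assumed location
from algoperf.pytorch_utils import pytorch_init, pytorch_setup

# pytorch_setup() returns (use_pytorch_ddp, rank, device, n_gpus), per the
# annotation in the hunk header above.
use_pytorch_ddp, rank, device, n_gpus = pytorch_setup()
pytorch_init(use_pytorch_ddp, rank, PassThroughProfiler(), limit_tf_threads=True)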

algoperf/workloads/lm/input_pipeline.py

Lines changed: 30 additions & 24 deletions

@@ -50,14 +50,15 @@ def batch_with_padding(
   return padded_batched_dataset


-def get_data_iter(data_rng: jax.random.PRNGKey,
+def get_data_iter(
+  data_rng: jax.random.PRNGKey,
   split: str,
   data_dir: str,
   batch_size: int,
-  num_batches: Optional[int] = None,):
-
+  num_batches: Optional[int] = None,
+):
   ds = get_lm_dataset(data_rng, split, data_dir, batch_size, num_batches)
-
+
   it = map(
     functools.partial(
       data_utils.shard_and_maybe_pad_np, global_batch_size=batch_size
@@ -67,6 +68,7 @@ def get_data_iter(data_rng: jax.random.PRNGKey,

   return iter(it)

+
 def get_lm_dataset(
   data_rng: jax.random.PRNGKey,
   split: str,
@@ -78,7 +80,7 @@ def get_lm_dataset(
   if split not in TFDS_SPLIT_NAME:
     raise NotImplementedError

-  shuffle_seed = jax.random.randint(data_rng, (), -2**31, 2**31-1)
+  shuffle_seed = jax.random.randint(data_rng, (), -(2**31), 2**31 - 1)

   data_dir = os.path.join(data_dir, TFDS_SPLIT_NAME[split])
   tokens_ds = tf.data.Dataset.load(data_dir)
@@ -98,19 +100,17 @@ def get_lm_dataset(
     num_parallel_calls=AUTOTUNE,
   )
   if split == 'train':
-    ds = sequences_ds.shuffle(
-      SHUFFLE_BUFFER_SIZE, seed=shuffle_seed
-    )
-    ds = ds.batch(
-      batch_size, drop_remainder=False
-    )
+    ds = sequences_ds.shuffle(SHUFFLE_BUFFER_SIZE, seed=shuffle_seed)
+    ds = ds.batch(batch_size, drop_remainder=False)
     ds = ds.take(num_batches) if num_batches is not None else ds
     ds = ds.repeat()
-    ds = ds.map(lambda x: {
-      'inputs': x['inputs'],
-      'targets': x['targets'],
-      'weights': None,
-    })
+    ds = ds.map(
+      lambda x: {
+        'inputs': x['inputs'],
+        'targets': x['targets'],
+        'weights': None,
+      }
+    )
     ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
   elif split == 'eval_train':
     ds = batch_with_padding(
@@ -123,10 +123,13 @@ def get_lm_dataset(
     )
     ds = ds.take(num_batches) if num_batches is not None else ds
     ds = ds.repeat()
-    ds = ds.map(lambda x: {'inputs': x['inputs'],
-      'targets': x['targets'],
-      'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)
-    })
+    ds = ds.map(
+      lambda x: {
+        'inputs': x['inputs'],
+        'targets': x['targets'],
+        'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0),
+      }
+    )
     ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
   elif split == 'validation':
     ds = batch_with_padding(
@@ -139,9 +142,12 @@ def get_lm_dataset(
     )
     ds = ds.take(num_batches) if num_batches is not None else ds
     ds = ds.repeat()
-    ds = ds.map(lambda x: {'inputs': x['inputs'],
-      'targets': x['targets'],
-      'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0)
-    })
+    ds = ds.map(
+      lambda x: {
+        'inputs': x['inputs'],
+        'targets': x['targets'],
+        'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0),
+      }
+    )
     ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
   return ds
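
A hypothetical driver sketch for the reformatted pipeline; the data directory and batch size are placeholders, and data_dir must contain the saved TFDS splits that get_lm_dataset expects:

import jax

from algoperf.workloads.lm import input_pipeline

rng = jax.random.PRNGKey(0)
train_iter = input_pipeline.get_data_iter(
  rng, split='train', data_dir='/data/lm', batch_size=8
)
# Each element is a sharded dict with 'inputs', 'targets', and 'weights'
# ('weights' is None for the train split).
batch = next(train_iter)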

algoperf/workloads/lm/lm_jax/models.py

Lines changed: 0 additions & 20 deletions
This file was deleted.
