23 changes: 23 additions & 0 deletions algoperf/spec.py
@@ -6,11 +6,34 @@
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import torch
import torch.nn.functional as F
from absl import logging
from torch import nn


class DTYPE(enum.Enum):
FLOAT32 = 0
FLOAT16 = 1
BFLOAT16 = 2


# Mapping from DTYPE enum to JAX dtypes
JAX_DTYPE_MAP = {
DTYPE.FLOAT32: jnp.float32,
DTYPE.FLOAT16: jnp.float16,
DTYPE.BFLOAT16: jnp.bfloat16,
}

# Mapping from DTYPE enum to PyTorch dtypes
PYTORCH_DTYPE_MAP = {
DTYPE.FLOAT32: torch.float32,
DTYPE.FLOAT16: torch.float16,
DTYPE.BFLOAT16: torch.bfloat16,
}


class LossType(enum.Enum):
SOFTMAX_CROSS_ENTROPY = 0
SIGMOID_CROSS_ENTROPY = 1
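For reference, a minimal sketch of how a workload can resolve the new framework-agnostic dtypes (the enum and both maps are part of this change; the surrounding lines are assumed usage):

from algoperf import spec

# Resolve the framework-agnostic DTYPE enum into concrete framework dtypes.
jax_compute = spec.JAX_DTYPE_MAP[spec.DTYPE.BFLOAT16]      # jnp.bfloat16
torch_param = spec.PYTORCH_DTYPE_MAP[spec.DTYPE.FLOAT32]   # torch.float32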
2 changes: 0 additions & 2 deletions algoperf/workloads/cifar/cifar_jax/input_pipeline.py
@@ -11,7 +11,6 @@
import jax
import tensorflow as tf
import tensorflow_datasets as tfds
from flax import jax_utils

from algoperf import spec
from algoperf.data_utils import shard_and_maybe_pad_np
@@ -186,5 +185,4 @@ def create_input_iter(
),
ds,
)
it = jax_utils.prefetch_to_device(it, 2)
return it
8 changes: 5 additions & 3 deletions algoperf/workloads/cifar/cifar_jax/models.py
@@ -31,7 +31,7 @@ def __call__(
update_batch_norm: bool = True,
use_running_average_bn: bool = None,
) -> spec.Tensor:
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
conv = functools.partial(nn.Conv, use_bias=False, param_dtype=self.dtype)

# Preserve default behavior for backwards compatibility
if use_running_average_bn is None:
@@ -41,7 +41,7 @@ def __call__(
use_running_average=use_running_average_bn,
momentum=0.9,
epsilon=1e-5,
dtype=self.dtype,
param_dtype=self.dtype,
)

x = conv(
@@ -66,7 +66,7 @@ def __call__(
x = nn.avg_pool(x, (4, 4), strides=(4, 4))
x = jnp.mean(x, axis=(1, 2))
x = nn.Dense(
self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype
self.num_classes,
kernel_init=nn.initializers.normal(),
param_dtype=self.dtype,
)(x)
return x

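Context for the dtype → param_dtype switch above: in Flax, param_dtype sets the dtype parameters are created in, while dtype (left unset here) sets the computation dtype. A standalone sketch of that behaviour, not part of this change:

import jax
import jax.numpy as jnp
from flax import linen as nn

# Parameters are initialized in float32 even when the inputs are bfloat16.
dense = nn.Dense(4, param_dtype=jnp.float32)
variables = dense.init(jax.random.PRNGKey(0), jnp.ones((1, 8), jnp.bfloat16))
print(jax.tree_util.tree_map(lambda p: p.dtype, variables))  # all float32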
29 changes: 25 additions & 4 deletions algoperf/workloads/cifar/cifar_jax/workload.py
@@ -5,6 +5,7 @@

import jax
import jax.numpy as jnp
import jmp
import optax
import tensorflow_datasets as tfds
from flax import linen as nn
@@ -18,6 +19,17 @@


class CifarWorkload(BaseCifarWorkload):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
compute_dtype = spec.JAX_DTYPE_MAP[self._compute_dtype]
param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype]
output_dtype = compute_dtype
self._mp_policy = jmp.Policy(
compute_dtype=compute_dtype,
param_dtype=param_dtype,
output_dtype=output_dtype,
)

def _build_cifar_dataset(
self,
data_rng: spec.RandomState,
@@ -80,7 +92,8 @@ def sync_batch_stats(
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
"""Dropout is unused."""
model_cls = getattr(models, 'ResNet18')
model = model_cls(num_classes=self._num_classes, dtype=jnp.float32)
param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype]
model = model_cls(num_classes=self._num_classes, dtype=param_dtype)
self._model = model
input_shape = (1, 32, 32, 3)
variables = jax.jit(model.init)(
@@ -89,7 +102,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
model_state, params = pop(variables, 'params')
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
model_state = jax_sharding_utils.replicate(params)
model_state = jax_sharding_utils.replicate(model_state)
params = jax_sharding_utils.replicate(params)
return params, model_state

@@ -110,24 +123,32 @@ def model_fn(
del mode
del rng
del dropout_rate
# Cast params and inputs to compute dtype
params, inputs = self._mp_policy.cast_to_compute(
(params, augmented_and_preprocessed_input_batch['inputs'])
)
variables = {'params': params, **model_state}
if update_batch_norm:
logits, new_model_state = self._model.apply(
variables,
augmented_and_preprocessed_input_batch['inputs'],
inputs,
update_batch_norm=update_batch_norm,
mutable=['batch_stats'],
use_running_average_bn=use_running_average_bn,
)
# Cast logits to output dtype
logits = self._mp_policy.cast_to_output(logits)
return logits, new_model_state
else:
logits = self._model.apply(
variables,
augmented_and_preprocessed_input_batch['inputs'],
inputs,
update_batch_norm=update_batch_norm,
mutable=False,
use_running_average_bn=use_running_average_bn,
)
# Cast logits to output dtype
logits = self._mp_policy.cast_to_output(logits)
return logits, model_state

# Does NOT apply regularization, which is left to the submitter to do in
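The casting pattern used in model_fn above, reduced to a standalone sketch (the policy fields mirror the workload defaults; the toy params and inputs are assumed for illustration):

import jax.numpy as jnp
import jmp

policy = jmp.Policy(
  compute_dtype=jnp.bfloat16,
  param_dtype=jnp.float32,
  output_dtype=jnp.bfloat16,
)
params = {'w': jnp.ones((8, 4), jnp.float32)}
inputs = jnp.ones((2, 8), jnp.float32)
# cast_to_compute casts every floating-point leaf of the tree to bfloat16.
params_c, inputs_c = policy.cast_to_compute((params, inputs))
logits = inputs_c @ params_c['w']
logits = policy.cast_to_output(logits)  # stays bfloat16 under this policy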
24 changes: 21 additions & 3 deletions algoperf/workloads/cifar/cifar_pytorch/models.py
@@ -29,11 +29,13 @@ def __init__(
width_per_group: int = 64,
replace_stride_with_dilation: Optional[List[bool]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
dtype: torch.dtype = torch.float32,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.dtype = dtype

self.inplanes = 64
self.dilation = 1
@@ -49,7 +51,13 @@ def __init__(
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(
3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
3,
self.inplanes,
kernel_size=3,
stride=1,
padding=1,
bias=False,
dtype=dtype,
)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
@@ -63,7 +71,7 @@ def __init__(
self.layer4 = self._make_layer(
block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
)
self.fc = nn.Linear(512 * block.expansion, num_classes)
self.fc = nn.Linear(512 * block.expansion, num_classes, dtype=dtype)
self.reset_parameters()

def reset_parameters(self) -> None:
@@ -105,7 +113,15 @@ def _make_layer(
downsample = torch.nn.Sequential(
collections.OrderedDict(
[
('conv', conv1x1(self.inplanes, planes * block.expansion, stride)),
(
'conv',
conv1x1(
self.inplanes,
planes * block.expansion,
stride,
dtype=self.dtype,
),
),
('bn', norm_layer(planes * block.expansion)),
]
)
@@ -122,6 +138,7 @@ def _make_layer(
self.base_width,
previous_dilation,
norm_layer,
dtype=self.dtype,
)
)
self.inplanes = planes * block.expansion
@@ -134,6 +151,7 @@ def _make_layer(
base_width=self.base_width,
dilation=self.dilation,
norm_layer=norm_layer,
dtype=self.dtype,
)
)

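The dtype= arguments threaded through the model above rely on standard PyTorch factory-kwarg behaviour: a layer's parameters are created in the requested dtype. A quick standalone check, not part of this change:

import torch
from torch import nn

conv = nn.Conv2d(3, 64, kernel_size=3, bias=False, dtype=torch.float32)
fc = nn.Linear(512, 10, dtype=torch.float32)
print(conv.weight.dtype, fc.weight.dtype)  # torch.float32 torch.float32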
9 changes: 7 additions & 2 deletions algoperf/workloads/cifar/cifar_pytorch/workload.py
@@ -25,6 +25,8 @@ def __init__(self, *args, **kwargs) -> None:
# Is set in submission_runner.py for workloads with PyTorch evaluation
# data loaders via the `eval_num_workers` property.
self._eval_num_workers = None
self._param_dtype_pt = spec.PYTORCH_DTYPE_MAP[self._param_dtype]
self._compute_dtype_pt = spec.PYTORCH_DTYPE_MAP[self._compute_dtype]

@property
def eval_num_workers(self) -> int:
@@ -128,7 +130,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
return self._model, None

torch.random.manual_seed(rng[0])
self._model = resnet18(num_classes=self._num_classes)
self._model = resnet18(
num_classes=self._num_classes, dtype=self._param_dtype_pt
)
self._param_shapes = param_utils.pytorch_param_shapes(self._model)
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
self._model.to(DEVICE)
@@ -175,7 +179,8 @@ def model_fn(
spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
logits_batch = model(augmented_and_preprocessed_input_batch['inputs'])
with torch.autocast(device_type='cuda', dtype=self._compute_dtype_pt):
logits_batch = model(augmented_and_preprocessed_input_batch['inputs'])
return logits_batch, None

# Does NOT apply regularization, which is left to the submitter to do in
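A standalone sketch of the autocast pattern in model_fn above: parameters stay in the param dtype while ops inside the context run in the compute dtype (assumes a CUDA device, matching the device_type used here):

import torch

# Parameters are created in float32; the forward pass runs in bfloat16.
model = torch.nn.Linear(8, 4, dtype=torch.float32).to('cuda')
x = torch.randn(2, 8, device='cuda')
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
  y = model(x)
print(model.weight.dtype, y.dtype)  # torch.float32 torch.bfloat16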
2 changes: 2 additions & 0 deletions algoperf/workloads/cifar/workload.py
@@ -16,6 +16,8 @@

class BaseCifarWorkload(spec.Workload):
_num_classes: int = 10
_compute_dtype: spec.DTYPE = spec.DTYPE.BFLOAT16
_param_dtype: spec.DTYPE = spec.DTYPE.FLOAT32

@property
def target_metric_name(self) -> str:
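Given the class-level defaults above, a hypothetical subclass could opt out of mixed precision by overriding both attributes; sketch only, not part of this change:

from algoperf import spec
from algoperf.workloads.cifar.workload import BaseCifarWorkload

class Float32CifarWorkload(BaseCifarWorkload):
  # Run both computation and parameters in full precision.
  _compute_dtype: spec.DTYPE = spec.DTYPE.FLOAT32
  _param_dtype: spec.DTYPE = spec.DTYPE.FLOAT32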