Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resnet device #410

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Vision/classification/image/resnet50/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def parse_args(ignore_unknown_args=False):
parser = argparse.ArgumentParser(
description="OneFlow ResNet50 Arguments", allow_abbrev=False
)
parser.add_argument("--device", type=str, default="cuda", help="device: cpu, cuda...")
parser.add_argument(
"--save",
type=str,
Expand Down
2 changes: 2 additions & 0 deletions Vision/classification/image/resnet50/examples/train_eager.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ VAL_BATCH_SIZE=50
SRC_DIR=$(realpath $(dirname $0)/..)

python3 $SRC_DIR/train.py \
--device npu \
--label-smoothing 0 \
--ofrecord-path $OFRECORD_PATH \
--ofrecord-part-num $OFRECORD_PART_NUM \
--num-devices-per-node 1 \
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# set -aux

DEVICE_NUM_PER_NODE=8
MASTER_ADDR=127.0.0.1
NUM_NODES=1
NODE_RANK=0

export PYTHONUNBUFFERED=1
echo PYTHONUNBUFFERED=$PYTHONUNBUFFERED
export NCCL_LAUNCH_MODE=PARALLEL
echo NCCL_LAUNCH_MODE=$NCCL_LAUNCH_MODE
# export NCCL_DEBUG=INFO
# export ONEFLOW_DEBUG_MODE=True

CHECKPOINT_SAVE_PATH="./graph_distributed_fp32_checkpoints"
if [ ! -d "$CHECKPOINT_SAVE_PATH" ]; then
mkdir $CHECKPOINT_SAVE_PATH
fi

#OFRECORD_PATH=PATH_TO_IMAGENET_OFRECORD
OFRECORD_PATH="/data0/datasets/ImageNet/ofrecord"

OFRECORD_PART_NUM=256
LEARNING_RATE=0.768
MOM=0.875
EPOCH=50
TRAIN_BATCH_SIZE=96
VAL_BATCH_SIZE=50

# SRC_DIR=/path/to/models/resnet50
SRC_DIR=$(realpath $(dirname $0)/..)

python3 -m oneflow.distributed.launch \
--nproc_per_node $DEVICE_NUM_PER_NODE \
--nnodes $NUM_NODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
$SRC_DIR/train.py \
--device npu \
--label-smoothing 0 \
--print-interval 100 \
--save $CHECKPOINT_SAVE_PATH \
--ofrecord-path $OFRECORD_PATH \
--ofrecord-part-num $OFRECORD_PART_NUM \
--num-devices-per-node $DEVICE_NUM_PER_NODE \
--lr $LEARNING_RATE \
--momentum $MOM \
--num-epochs $EPOCH \
--train-batch-size $TRAIN_BATCH_SIZE \
--val-batch-size $VAL_BATCH_SIZE \
--scale-grad \
#--graph \
#--fuse-bn-relu \
#--fuse-bn-add-relu \
8 changes: 4 additions & 4 deletions Vision/classification/image/resnet50/examples/train_graph.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ if [ ! -d "$CHECKPOINT_SAVE_PATH" ]; then
mkdir $CHECKPOINT_SAVE_PATH
fi

OFRECORD_PATH="./mini-imagenet/ofrecord"
OFRECORD_PATH="/data0/datasets/ImageNet/ofrecord"

if [ ! -d "$OFRECORD_PATH" ]; then
wget https://oneflow-public.oss-cn-beijing.aliyuncs.com/online_document/dataset/imagenet/mini-imagenet.zip
Expand All @@ -26,6 +26,8 @@ VAL_BATCH_SIZE=50
SRC_DIR=$(realpath $(dirname $0)/..)

python3 $SRC_DIR/train.py \
--device npu \
--label-smoothing 0 \
--ofrecord-path $OFRECORD_PATH \
--ofrecord-part-num $OFRECORD_PART_NUM \
--num-devices-per-node 1 \
Expand All @@ -35,9 +37,7 @@ python3 $SRC_DIR/train.py \
--warmup-epochs 0 \
--train-batch-size $TRAIN_BATCH_SIZE \
--val-batch-size $VAL_BATCH_SIZE \
--save $CHECKPOINT_SAVE_PATH \
--samples-per-epoch 50 \
--val-samples-per-epoch 50 \
--use-gpu-decode \
--scale-grad \
--graph \
--skip-eval \
15 changes: 8 additions & 7 deletions Vision/classification/image/resnet50/examples/train_graph_distributed_fp32.sh
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# set -aux

DEVICE_NUM_PER_NODE=8
DEVICE_NUM_PER_NODE=1
MASTER_ADDR=127.0.0.1
NUM_NODES=1
NODE_RANK=0
Expand All @@ -17,7 +17,7 @@ if [ ! -d "$CHECKPOINT_SAVE_PATH" ]; then
mkdir $CHECKPOINT_SAVE_PATH
fi

OFRECORD_PATH=PATH_TO_IMAGENET_OFRECORD
OFRECORD_PATH="/data0/datasets/ImageNet/ofrecord"

OFRECORD_PART_NUM=256
LEARNING_RATE=0.768
Expand All @@ -35,7 +35,9 @@ python3 -m oneflow.distributed.launch \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
$SRC_DIR/train.py \
--save $CHECKPOINT_SAVE_PATH \
--device npu \
--label-smoothing 0 \
--print-interval 100 \
--ofrecord-path $OFRECORD_PATH \
--ofrecord-part-num $OFRECORD_PART_NUM \
--num-devices-per-node $DEVICE_NUM_PER_NODE \
Expand All @@ -44,8 +46,7 @@ python3 -m oneflow.distributed.launch \
--num-epochs $EPOCH \
--train-batch-size $TRAIN_BATCH_SIZE \
--val-batch-size $VAL_BATCH_SIZE \
--use-gpu-decode \
--scale-grad \
--graph \
--fuse-bn-relu \
--fuse-bn-add-relu \
--skip-eval \
# --scale-grad \

10 changes: 6 additions & 4 deletions Vision/classification/image/resnet50/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ def __init__(
self.cross_entropy = cross_entropy
self.data_loader = data_loader
self.add_optimizer(optimizer, lr_sch=lr_scheduler)
self.device = args.device

def build(self):
image, label = self.data_loader()
image = image.to("cuda")
label = label.to("cuda")
image = image.to(self.device)
label = label.to(self.device)
logits = self.model(image)
loss = self.cross_entropy(logits, label)
if self.return_pred_and_label:
Expand All @@ -79,11 +80,12 @@ def __init__(self, model, data_loader):

self.data_loader = data_loader
self.model = model
self.device = args.device

def build(self):
image, label = self.data_loader()
image = image.to("cuda")
label = label.to("cuda")
image = image.to(self.device)
label = label.to(self.device)
logits = self.model(image)
pred = logits.softmax()
return pred, label
4 changes: 2 additions & 2 deletions Vision/classification/image/resnet50/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def main(args):
print("***** Model Init *****")
model = resnet50()
model.load_state_dict(flow.load(args.model_path))
model = model.to("cuda")
model = model.to(args.device)
model.eval()
end_t = time.perf_counter()
print(f"***** Model Init Finish, time escapled {end_t - start_t:.6f} s *****")
Expand All @@ -65,7 +65,7 @@ def main(args):

start_t = end_t
image = load_image(args.image_path)
image = flow.Tensor(image, device=flow.device("cuda"))
image = flow.Tensor(image, device=flow.device(args.device))
if args.graph:
pred = model_graph(image)
else:
Expand Down
16 changes: 11 additions & 5 deletions Vision/classification/image/resnet50/models/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ def make_data_loader(args, mode, is_global=False, synthetic=False):
placement=placement,
sbp=sbp,
channel_last=args.channel_last,
device=args.device,
)
return data_loader.to("cuda")
return data_loader.to(args.device)

ofrecord_data_loader = OFRecordDataLoader(
ofrecord_dir=args.ofrecord_path,
Expand All @@ -45,6 +46,7 @@ def make_data_loader(args, mode, is_global=False, synthetic=False):
placement=placement,
sbp=sbp,
use_gpu_decode=args.use_gpu_decode,
device=args.device,
)
return ofrecord_data_loader

Expand All @@ -62,6 +64,7 @@ def __init__(
placement=None,
sbp=None,
use_gpu_decode=False,
device="cuda",
):
super().__init__()

Expand All @@ -71,6 +74,7 @@ def __init__(
self.total_batch_size = total_batch_size
self.dataset_size = dataset_size
self.mode = mode
self.device = device

random_shuffle = True if mode == "train" else False
shuffle_after_epoch = True if mode == "train" else False
Expand Down Expand Up @@ -159,11 +163,12 @@ def forward(self):
else:
image_raw_bytes = self.image_decoder(record)
image = self.resize(image_raw_bytes)[0]
image = image.to("cuda")

label = self.label_decoder(record)
flip_code = self.flip()
flip_code = flip_code.to("cuda")
if self.use_gpu_decode:
# todo NPU: image will down grade to cpu
flip_code = flip_code.to(self.device)
image = self.crop_mirror_norm(image, flip_code)
else:
record = self.ofrecord_reader()
Expand All @@ -184,6 +189,7 @@ def __init__(
placement=None,
sbp=None,
channel_last=False,
device="cuda",
):
super().__init__()

Expand Down Expand Up @@ -220,10 +226,10 @@ def __init__(
)
else:
self.image = flow.randint(
0, high=256, size=self.image_shape, dtype=flow.float32, device="cuda"
0, high=256, size=self.image_shape, dtype=flow.float32, device=device,
)
self.label = flow.randint(
0, high=self.num_classes, size=self.label_shape, device="cuda",
0, high=self.num_classes, size=self.label_shape, device=device,
).to(dtype=flow.int32)

def forward(self):
Expand Down
16 changes: 9 additions & 7 deletions Vision/classification/image/resnet50/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import time

import oneflow as flow
import oneflow_npu
from oneflow.nn.parallel import DistributedDataParallel as ddp

from config import get_args
Expand All @@ -26,6 +27,7 @@
class Trainer(object):
def __init__(self):
args = get_args()
self.device = args.device
for k, v in args.__dict__.items():
setattr(self, k, v)

Expand Down Expand Up @@ -89,12 +91,12 @@ def init_model(self):
start_t = time.perf_counter()

if self.is_global:
placement = flow.env.all_device_placement("cuda")
placement = flow.env.all_device_placement(self.device)
self.model = self.model.to_global(
placement=placement, sbp=flow.sbp.broadcast
)
else:
self.model = self.model.to("cuda")
self.model = self.model.to(self.device)

if self.load_path is None:
self.legacy_init_parameters()
Expand Down Expand Up @@ -276,7 +278,7 @@ def train_eager(self):
param.grad /= self.world_size
else:
loss.backward()
loss = loss / self.world_size
#loss = loss / self.world_size

self.optimizer.step()
self.optimizer.zero_grad()
Expand Down Expand Up @@ -311,8 +313,8 @@ def eval(self):

def forward(self):
image, label = self.train_data_loader()
image = image.to("cuda")
label = label.to("cuda")
image = image.to(self.device)
label = label.to(self.device)
logits = self.model(image)
loss = self.cross_entropy(logits, label)
if self.metric_train_acc:
Expand All @@ -323,8 +325,8 @@ def forward(self):

def inference(self):
image, label = self.val_data_loader()
image = image.to("cuda")
label = label.to("cuda")
image = image.to(self.device)
label = label.to(self.device)
with flow.no_grad():
logits = self.model(image)
pred = logits.softmax()
Expand Down