diff --git a/Vision/classification/image/resnet50/config.py b/Vision/classification/image/resnet50/config.py index 63f3e25e2..129c8968c 100644 --- a/Vision/classification/image/resnet50/config.py +++ b/Vision/classification/image/resnet50/config.py @@ -26,6 +26,7 @@ def parse_args(ignore_unknown_args=False): parser = argparse.ArgumentParser( description="OneFlow ResNet50 Arguments", allow_abbrev=False ) + parser.add_argument("--device", type=str, default="cuda", help="device: cpu, cuda...") parser.add_argument( "--save", type=str, diff --git a/Vision/classification/image/resnet50/examples/train_eager.sh b/Vision/classification/image/resnet50/examples/train_eager.sh index 46461d208..e75b485a0 100644 --- a/Vision/classification/image/resnet50/examples/train_eager.sh +++ b/Vision/classification/image/resnet50/examples/train_eager.sh @@ -26,6 +26,8 @@ VAL_BATCH_SIZE=50 SRC_DIR=$(realpath $(dirname $0)/..) python3 $SRC_DIR/train.py \ + --device npu \ + --label-smoothing 0 \ --ofrecord-path $OFRECORD_PATH \ --ofrecord-part-num $OFRECORD_PART_NUM \ --num-devices-per-node 1 \ diff --git a/Vision/classification/image/resnet50/examples/train_eager_distributed_fp32.sh b/Vision/classification/image/resnet50/examples/train_eager_distributed_fp32.sh new file mode 100644 index 000000000..cfbd8c09c --- /dev/null +++ b/Vision/classification/image/resnet50/examples/train_eager_distributed_fp32.sh @@ -0,0 +1,54 @@ +# set -aux + +DEVICE_NUM_PER_NODE=8 +MASTER_ADDR=127.0.0.1 +NUM_NODES=1 +NODE_RANK=0 + +export PYTHONUNBUFFERED=1 +echo PYTHONUNBUFFERED=$PYTHONUNBUFFERED +export NCCL_LAUNCH_MODE=PARALLEL +echo NCCL_LAUNCH_MODE=$NCCL_LAUNCH_MODE +# export NCCL_DEBUG=INFO +# export ONEFLOW_DEBUG_MODE=True + +CHECKPOINT_SAVE_PATH="./graph_distributed_fp32_checkpoints" +if [ ! -d "$CHECKPOINT_SAVE_PATH" ]; then + mkdir $CHECKPOINT_SAVE_PATH +fi + +#OFRECORD_PATH=PATH_TO_IMAGENET_OFRECORD +OFRECORD_PATH="/data0/datasets/ImageNet/ofrecord" + +OFRECORD_PART_NUM=256 +LEARNING_RATE=0.768 +MOM=0.875 +EPOCH=50 +TRAIN_BATCH_SIZE=96 +VAL_BATCH_SIZE=50 + +# SRC_DIR=/path/to/models/resnet50 +SRC_DIR=$(realpath $(dirname $0)/..) + +python3 -m oneflow.distributed.launch \ + --nproc_per_node $DEVICE_NUM_PER_NODE \ + --nnodes $NUM_NODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + $SRC_DIR/train.py \ + --device npu \ + --label-smoothing 0 \ + --print-interval 100 \ + --save $CHECKPOINT_SAVE_PATH \ + --ofrecord-path $OFRECORD_PATH \ + --ofrecord-part-num $OFRECORD_PART_NUM \ + --num-devices-per-node $DEVICE_NUM_PER_NODE \ + --lr $LEARNING_RATE \ + --momentum $MOM \ + --num-epochs $EPOCH \ + --train-batch-size $TRAIN_BATCH_SIZE \ + --val-batch-size $VAL_BATCH_SIZE \ + --scale-grad \ + #--graph \ + #--fuse-bn-relu \ + #--fuse-bn-add-relu \ diff --git a/Vision/classification/image/resnet50/examples/train_graph.sh b/Vision/classification/image/resnet50/examples/train_graph.sh old mode 100644 new mode 100755 index 3e267e0bf..7636391a7 --- a/Vision/classification/image/resnet50/examples/train_graph.sh +++ b/Vision/classification/image/resnet50/examples/train_graph.sh @@ -8,7 +8,7 @@ if [ ! -d "$CHECKPOINT_SAVE_PATH" ]; then mkdir $CHECKPOINT_SAVE_PATH fi -OFRECORD_PATH="./mini-imagenet/ofrecord" +OFRECORD_PATH="/data0/datasets/ImageNet/ofrecord" if [ ! -d "$OFRECORD_PATH" ]; then wget https://oneflow-public.oss-cn-beijing.aliyuncs.com/online_document/dataset/imagenet/mini-imagenet.zip @@ -26,6 +26,8 @@ VAL_BATCH_SIZE=50 SRC_DIR=$(realpath $(dirname $0)/..) python3 $SRC_DIR/train.py \ + --device npu \ + --label-smoothing 0 \ --ofrecord-path $OFRECORD_PATH \ --ofrecord-part-num $OFRECORD_PART_NUM \ --num-devices-per-node 1 \ @@ -35,9 +37,7 @@ python3 $SRC_DIR/train.py \ --warmup-epochs 0 \ --train-batch-size $TRAIN_BATCH_SIZE \ --val-batch-size $VAL_BATCH_SIZE \ - --save $CHECKPOINT_SAVE_PATH \ --samples-per-epoch 50 \ --val-samples-per-epoch 50 \ - --use-gpu-decode \ - --scale-grad \ --graph \ + --skip-eval \ diff --git a/Vision/classification/image/resnet50/examples/train_graph_distributed_fp32.sh b/Vision/classification/image/resnet50/examples/train_graph_distributed_fp32.sh old mode 100644 new mode 100755 index 038b1c812..27c748aa0 --- a/Vision/classification/image/resnet50/examples/train_graph_distributed_fp32.sh +++ b/Vision/classification/image/resnet50/examples/train_graph_distributed_fp32.sh @@ -1,6 +1,6 @@ # set -aux -DEVICE_NUM_PER_NODE=8 +DEVICE_NUM_PER_NODE=1 MASTER_ADDR=127.0.0.1 NUM_NODES=1 NODE_RANK=0 @@ -17,7 +17,7 @@ if [ ! -d "$CHECKPOINT_SAVE_PATH" ]; then mkdir $CHECKPOINT_SAVE_PATH fi -OFRECORD_PATH=PATH_TO_IMAGENET_OFRECORD +OFRECORD_PATH="/data0/datasets/ImageNet/ofrecord" OFRECORD_PART_NUM=256 LEARNING_RATE=0.768 @@ -35,7 +35,9 @@ python3 -m oneflow.distributed.launch \ --node_rank $NODE_RANK \ --master_addr $MASTER_ADDR \ $SRC_DIR/train.py \ - --save $CHECKPOINT_SAVE_PATH \ + --device npu \ + --label-smoothing 0 \ + --print-interval 100 \ --ofrecord-path $OFRECORD_PATH \ --ofrecord-part-num $OFRECORD_PART_NUM \ --num-devices-per-node $DEVICE_NUM_PER_NODE \ @@ -44,8 +46,7 @@ python3 -m oneflow.distributed.launch \ --num-epochs $EPOCH \ --train-batch-size $TRAIN_BATCH_SIZE \ --val-batch-size $VAL_BATCH_SIZE \ - --use-gpu-decode \ - --scale-grad \ --graph \ - --fuse-bn-relu \ - --fuse-bn-add-relu \ + --skip-eval \ + # --scale-grad \ + diff --git a/Vision/classification/image/resnet50/graph.py b/Vision/classification/image/resnet50/graph.py index dcad741ba..58ab63689 100644 --- a/Vision/classification/image/resnet50/graph.py +++ b/Vision/classification/image/resnet50/graph.py @@ -51,11 +51,12 @@ def __init__( self.cross_entropy = cross_entropy self.data_loader = data_loader self.add_optimizer(optimizer, lr_sch=lr_scheduler) + self.device = args.device def build(self): image, label = self.data_loader() - image = image.to("cuda") - label = label.to("cuda") + image = image.to(self.device) + label = label.to(self.device) logits = self.model(image) loss = self.cross_entropy(logits, label) if self.return_pred_and_label: @@ -79,11 +80,12 @@ def __init__(self, model, data_loader): self.data_loader = data_loader self.model = model + self.device = args.device def build(self): image, label = self.data_loader() - image = image.to("cuda") - label = label.to("cuda") + image = image.to(self.device) + label = label.to(self.device) logits = self.model(image) pred = logits.softmax() return pred, label diff --git a/Vision/classification/image/resnet50/infer.py b/Vision/classification/image/resnet50/infer.py index 85f19ed6a..8837ec39a 100644 --- a/Vision/classification/image/resnet50/infer.py +++ b/Vision/classification/image/resnet50/infer.py @@ -55,7 +55,7 @@ def main(args): print("***** Model Init *****") model = resnet50() model.load_state_dict(flow.load(args.model_path)) - model = model.to("cuda") + model = model.to(args.device) model.eval() end_t = time.perf_counter() print(f"***** Model Init Finish, time escapled {end_t - start_t:.6f} s *****") @@ -65,7 +65,7 @@ def main(args): start_t = end_t image = load_image(args.image_path) - image = flow.Tensor(image, device=flow.device("cuda")) + image = flow.Tensor(image, device=flow.device(args.device)) if args.graph: pred = model_graph(image) else: diff --git a/Vision/classification/image/resnet50/models/data.py b/Vision/classification/image/resnet50/models/data.py index ee8da362f..2f3cbefa9 100644 --- a/Vision/classification/image/resnet50/models/data.py +++ b/Vision/classification/image/resnet50/models/data.py @@ -31,8 +31,9 @@ def make_data_loader(args, mode, is_global=False, synthetic=False): placement=placement, sbp=sbp, channel_last=args.channel_last, + device=args.device, ) - return data_loader.to("cuda") + return data_loader.to(args.device) ofrecord_data_loader = OFRecordDataLoader( ofrecord_dir=args.ofrecord_path, @@ -45,6 +46,7 @@ def make_data_loader(args, mode, is_global=False, synthetic=False): placement=placement, sbp=sbp, use_gpu_decode=args.use_gpu_decode, + device=args.device, ) return ofrecord_data_loader @@ -62,6 +64,7 @@ def __init__( placement=None, sbp=None, use_gpu_decode=False, + device="cuda", ): super().__init__() @@ -71,6 +74,7 @@ def __init__( self.total_batch_size = total_batch_size self.dataset_size = dataset_size self.mode = mode + self.device = device random_shuffle = True if mode == "train" else False shuffle_after_epoch = True if mode == "train" else False @@ -159,11 +163,12 @@ def forward(self): else: image_raw_bytes = self.image_decoder(record) image = self.resize(image_raw_bytes)[0] - image = image.to("cuda") label = self.label_decoder(record) flip_code = self.flip() - flip_code = flip_code.to("cuda") + if self.use_gpu_decode: + # todo NPU: image will down grade to cpu + flip_code = flip_code.to(self.device) image = self.crop_mirror_norm(image, flip_code) else: record = self.ofrecord_reader() @@ -184,6 +189,7 @@ def __init__( placement=None, sbp=None, channel_last=False, + device="cuda", ): super().__init__() @@ -220,10 +226,10 @@ def __init__( ) else: self.image = flow.randint( - 0, high=256, size=self.image_shape, dtype=flow.float32, device="cuda" + 0, high=256, size=self.image_shape, dtype=flow.float32, device=device, ) self.label = flow.randint( - 0, high=self.num_classes, size=self.label_shape, device="cuda", + 0, high=self.num_classes, size=self.label_shape, device=device, ).to(dtype=flow.int32) def forward(self): diff --git a/Vision/classification/image/resnet50/train.py b/Vision/classification/image/resnet50/train.py index c1ba49ba4..43d54617d 100644 --- a/Vision/classification/image/resnet50/train.py +++ b/Vision/classification/image/resnet50/train.py @@ -9,6 +9,7 @@ import time import oneflow as flow +import oneflow_npu from oneflow.nn.parallel import DistributedDataParallel as ddp from config import get_args @@ -26,6 +27,7 @@ class Trainer(object): def __init__(self): args = get_args() + self.device = args.device for k, v in args.__dict__.items(): setattr(self, k, v) @@ -89,12 +91,12 @@ def init_model(self): start_t = time.perf_counter() if self.is_global: - placement = flow.env.all_device_placement("cuda") + placement = flow.env.all_device_placement(self.device) self.model = self.model.to_global( placement=placement, sbp=flow.sbp.broadcast ) else: - self.model = self.model.to("cuda") + self.model = self.model.to(self.device) if self.load_path is None: self.legacy_init_parameters() @@ -276,7 +278,7 @@ def train_eager(self): param.grad /= self.world_size else: loss.backward() - loss = loss / self.world_size + #loss = loss / self.world_size self.optimizer.step() self.optimizer.zero_grad() @@ -311,8 +313,8 @@ def eval(self): def forward(self): image, label = self.train_data_loader() - image = image.to("cuda") - label = label.to("cuda") + image = image.to(self.device) + label = label.to(self.device) logits = self.model(image) loss = self.cross_entropy(logits, label) if self.metric_train_acc: @@ -323,8 +325,8 @@ def forward(self): def inference(self): image, label = self.val_data_loader() - image = image.to("cuda") - label = label.to("cuda") + image = image.to(self.device) + label = label.to(self.device) with flow.no_grad(): logits = self.model(image) pred = logits.softmax()