From 218016b340a1c710136029941f8a6135312074c5 Mon Sep 17 00:00:00 2001 From: "Zhong, Ruijie" Date: Wed, 23 Apr 2025 00:01:19 -0700 Subject: [PATCH 1/5] [Nightly] Enhance op microbench --- .github/scripts/microbench_summary.py | 227 ++++++++++++++++++ .github/scripts/microbench_summary.sh | 203 ---------------- .github/workflows/_linux_op_benchmark.yml | 4 +- test/microbench/embedding.py | 29 +++ test/microbench/embedding_bag.py | 51 ++++ test/microbench/indexing.diag.py | 27 +++ test/microbench/indexing.index.py | 62 +++++ test/microbench/indexing.index_add.py | 35 +++ test/microbench/indexing.index_copy.py | 44 ++++ test/microbench/indexing.index_fill.py | 44 ++++ test/microbench/indexing.index_put.py | 65 +++++ test/microbench/indexing.index_select.py | 30 +++ test/microbench/indexing.masked_fill.py | 31 +++ test/microbench/indexing.put.py | 29 +++ test/microbench/indexing.take.py | 28 +++ test/microbench/pooling.max_unpool2d.py | 72 ++++++ test/microbench/pooling.max_unpool3d.py | 72 ++++++ test/microbench/repeat_interleave.py | 48 ++++ test/microbench/scan.masked_select.py | 23 ++ test/microbench/scan.nonzero.py | 26 ++ test/microbench/scatter_gather.gather.py | 48 ++++ test/microbench/scatter_gather.scatter.py | 74 ++++++ test/microbench/scatter_gather.scatter_add.py | 74 ++++++ 23 files changed, 1141 insertions(+), 205 deletions(-) create mode 100644 .github/scripts/microbench_summary.py delete mode 100644 .github/scripts/microbench_summary.sh create mode 100644 test/microbench/embedding.py create mode 100644 test/microbench/embedding_bag.py create mode 100644 test/microbench/indexing.diag.py create mode 100644 test/microbench/indexing.index.py create mode 100644 test/microbench/indexing.index_add.py create mode 100644 test/microbench/indexing.index_copy.py create mode 100644 test/microbench/indexing.index_fill.py create mode 100644 test/microbench/indexing.index_put.py create mode 100644 test/microbench/indexing.index_select.py create mode 100644 test/microbench/indexing.masked_fill.py create mode 100644 test/microbench/indexing.put.py create mode 100644 test/microbench/indexing.take.py create mode 100644 test/microbench/pooling.max_unpool2d.py create mode 100644 test/microbench/pooling.max_unpool3d.py create mode 100644 test/microbench/repeat_interleave.py create mode 100644 test/microbench/scan.masked_select.py create mode 100644 test/microbench/scan.nonzero.py create mode 100644 test/microbench/scatter_gather.gather.py create mode 100644 test/microbench/scatter_gather.scatter.py create mode 100644 test/microbench/scatter_gather.scatter_add.py diff --git a/.github/scripts/microbench_summary.py b/.github/scripts/microbench_summary.py new file mode 100644 index 000000000..089918bd0 --- /dev/null +++ b/.github/scripts/microbench_summary.py @@ -0,0 +1,227 @@ +""" +Microbenchmark Summary Tool - Parses performance logs and generates CSV/Excel reports +# Usage +# Summary forward op time, forward_op_summary.csv is forward summary file +python microbench_summary.py path/to/profile's log g forward_op_summary.csv +# Summary backward op time, backward_op_summary.csv is backward summary file, True means summary backward, default is false. +python microbench_summary.py path/to/profile's log g backward_op_summary.csv --backward +""" + +import re +import pandas as pd +import glob +import os +import argparse +from pathlib import Path +from typing import Dict, List, Optional + +def main(): + parser = argparse.ArgumentParser( + description="Parse performance logs and generate summary reports", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("log_dir", help="Directory containing log files") + parser.add_argument("output_file", help="Output CSV file path") + parser.add_argument("--backward", action="store_true", + help="Process backward operations instead of forward") + args = parser.parse_args() + + try: + df = parse_logs(args.log_dir, args.backward) + if df.empty: + print("Warning: No valid data found in log files!") + return + + save_reports(df, args.output_file) + print(f"Successfully generated reports: {args.output_file} and {args.output_file.replace('.csv', '.xlsx')}") + except Exception as e: + print(f"Error: {str(e)}") + raise + +def parse_logs(log_dir: str, get_backward: bool = False) -> pd.DataFrame: + data = [] + columns = [ + "case_name", "datatype", "op_name", "shape", "channels_last", "dim", + "output_size", "P", "reduce", "kernel_size", "stride", "replacement", + "num_samples", "scale_factor", "mode", "padding_mode", "align_corners", + "shifts", "affine", "backward", "time(us)" + ] + + for log_file in glob.glob(os.path.join(log_dir, "*.log")): + try: + with open(log_file, 'r') as f: + content = f.read() + + case_name = Path(log_file).stem + base_op_name = case_name.split('.')[-1] + op_name, time_pattern = get_op_pattern(base_op_name, get_backward) + + if get_backward and base_op_name == "l1_loss": + process_l1_loss(content, case_name, data, columns) + continue + + time_matches = extract_times(content, time_pattern, get_backward) + shape_lines = re.findall(r"(shape\s*[:=].*?)(?=\n\S|$)", content) + + for i, (time, unit) in enumerate(time_matches[:len(shape_lines)]): + time_us = convert_to_us(float(time), unit) + params = extract_params(shape_lines[i]) + + if get_backward and params.get("backward", "False") == "False": + continue + + record = create_record(params, case_name, op_name, str(get_backward), time_us) + data.append([record.get(col, "") for col in columns]) + + except Exception as e: + print(f"Warning: Error processing {log_file} - {str(e)}") + continue + + return pd.DataFrame(data, columns=columns) if data else pd.DataFrame() + +def get_op_pattern(base_op_name: str, get_backward: bool) -> tuple: + op_name_map = { + 'forward': { + 'batch_norm': ('aten::batch_norm', 'aten::batch_norm'), + 'unique': ('unique2', 'unique2'), + 'fractional_max_pool2d': ('fractional_max_pool2d', r'\bfractional_max_pool2d\b'), + 'fractional_max_pool3d': ('fractional_max_pool3d', r'\bfractional_max_pool3d\b'), + 'adaptive_max_pool2d': ('adaptive_max_pool2d', r'\badaptive_max_pool2d\b'), + 'max_pool3d': ('max_pool3d_with_indices', 'max_pool3d_with_indices '), + 'max_pool2d': ('max_pool2d_with_indices', 'max_pool2d_with_indices '), + 'exponential': ('exponential_', r'\bexponential_\b'), + 'geometric': ('geometric_', r'\bgeometric_\b'), + 'uniform': ('uniform_', r'\buniform_\b'), + 'random': ('random_', r'\brandom_\b'), + 'log_normal': ('log_normal_', r'\blog_normal_\b'), + 'normal': ('normal_', r'\bnormal_\b'), + 'bernoulli': ('bernoulli_', r'\bbernoulli_\b'), + 'cauchy': ('cauchy_', r'\bcauchy_\b'), + 'dropout': ('dropout', r'\bdropout\b'), + 'layer_norm': ('layer_norm', r'\blayer_norm\b'), + 'ctc_loss': ('_ctc_loss', r'\b_ctc_loss\b'), + 'adaptive_avg_pool2d': ('adaptive_avg_pool2d', r'\badaptive_avg_pool2d\b'), + 'softmax': ('aten::softmax', 'aten::softmax'), + 'group_norm': ('aten::group_norm', 'aten::group_norm'), + }, + 'backward': { + 'batch_norm': ('batch_norm_backward', 'batch_norm_backward'), + 'fractional_max_pool2d': ('fractional_max_pool2d_backward', r'\bfractional_max_pool2d_backward\b'), + 'fractional_max_pool3d': ('fractional_max_pool3d_backward', r'\bfractional_max_pool3d_backward\b'), + 'adaptive_max_pool2d': ('adaptive_max_pool2d_backward', r'\badaptive_max_pool2d_backward\b'), + 'max_pool3d': ('max_pool3d_with_indices_backward', 'max_pool3d_with_indices_backward '), + 'max_pool2d': ('max_pool2d_with_indices_backward', 'max_pool2d_with_indices_backward '), + 'col2im': ('Col2ImBackward0', 'Col2ImBackward0 '), + 'im2col': ('Im2ColBackward0', 'Im2ColBackward0 '), + 'flip': ('FlipBackward0', 'FlipBackward0 '), + 'matmul': ('MmBackward0', 'MmBackward0 '), + 'roll': ('RollBackward0', 'RollBackward0 '), + 'softmax': ('softmax_backward_data', 'softmax_backward_data '), + 'remainder': ('RemainderBackward0', 'RemainderBackward0 '), + 'smooth_l1_loss': ('smooth_l1_loss_backward', 'smooth_l1_loss_backward'), + 'l1_loss': ('l1_loss', 'l1_loss'), + } + } + + mode = 'backward' if get_backward else 'forward' + + for op_pattern in op_name_map[mode]: + if op_pattern in base_op_name: + return op_name_map[mode][op_pattern] + + if get_backward: + return (f"{base_op_name}_backward", f"{base_op_name}_backward ") + else: + return (base_op_name, f"{base_op_name} ") + +def process_l1_loss(content: str, case_name: str, data: List, columns: List): + filtered_content = [line for line in content.split('\n') if "autograd::engine" not in line] + filtered_content = '\n'.join(filtered_content) + abs_times = re.findall(r"AbsBackward0(?:\s+\S+){8}\s+(\d+\.?\d*)([a-zA-Z]*)", filtered_content) + mean_times = re.findall(r"MeanBackward0(?:\s+\S+){8}\s+(\d+\.?\d*)([a-zA-Z]*)", filtered_content) + shape_lines = re.findall(r"(shape\s*[:=].*?)(?=\n\S|$)", content) + + for i, (time, unit) in enumerate(abs_times[:6]): + if i >= len(shape_lines): + break + time_us = convert_to_us(float(time), unit) + params = extract_params(shape_lines[i]) + record = create_record(params, case_name, "AbsBackward0", "True", time_us) + data.append([record.get(col, "") for col in columns]) + + for i, (time, unit) in enumerate(mean_times): + if (i + 6) >= len(shape_lines): + break + time_us = convert_to_us(float(time), unit) + params = extract_params(shape_lines[i + 6]) + record = create_record(params, case_name, "MeanBackward0", "True", time_us) + data.append([record.get(col, "") for col in columns]) + +def extract_times(content: str, pattern: str, get_backward: bool) -> List: + lines = content.split('\n') + results = [] + for line in lines: + if get_backward and any(x in pattern for x in ["Col2ImBackward0", "Im2ColBackward0", + "FlipBackward0", "MmBackward0", + "RollBackward0"]): + if "autograd::engine" in line: + continue + + match = re.search(fr"{pattern}.*?(?:\s+\S+){{8}}\s+(\d+\.?\d*)([a-zA-Z]*)", line) + if match: + results.append((match.group(1), match.group(2))) + + return results + +def create_record(params: Dict, case_name: str, op_name: str, + backward: str, time_us: float) -> Dict: + return { + "P": params.get("p", ""), + **params, + "case_name": case_name, + "op_name": op_name, + "backward": backward, + "time(us)": time_us + } + +def convert_to_us(value: float, unit: str) -> float: + unit = unit.lower() + if unit == "ms": + return value * 1000 + elif unit == "s": + return value * 1_000_000 + return value + +def extract_params(text: str) -> Dict: + params = {} + pairs = re.split(r'[;]', text.strip()) + + for pair in pairs: + if not any(delim in pair for delim in [':', '=']): + continue + + delim = ':' if ':' in pair else '=' + key, value = pair.split(delim, 1) + key = key.strip().lower() + value = value.strip() + + if key in ['p', 'P']: + key = 'p' + elif key in ['dims', 'dim']: + key = 'dim' + elif key in ['shape']: + key = 'shape' + + params[key] = value + + return params + +def save_reports(df: pd.DataFrame, csv_path: str): + os.makedirs(os.path.dirname(csv_path) or '.', exist_ok=True) + df.to_csv(csv_path, index=False, sep=';') + excel_path = csv_path.replace('.csv', '.xlsx') + df.to_excel(excel_path, index=False) + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/microbench_summary.sh b/.github/scripts/microbench_summary.sh deleted file mode 100644 index 637117ca6..000000000 --- a/.github/scripts/microbench_summary.sh +++ /dev/null @@ -1,203 +0,0 @@ -#! /bin/bash -# This script is for op perf summary, both for forward and backward op - -# usage -# Summary forward op time, forward_op_summary.csv is forward summary file -## bash microbench_summary.sh path/to/profile's log forward_op_summary.csv -# Summary backward op time, backward_op_summary.csv is backward summary file, True means summary backward, default is false. -## bash microbench_summary.sh path/to/profile's log backward_op_summary.csv True - -results_dir="$1" -output_file="$2" -Get_backward=${3:-False} -cd "$results_dir" || exit - -echo "case_name;datatype;op_name;shape;channels_last;dim;output_size;P;reduce;kernel_size;stride;replacement;num_samples;scale_factor;mode;padding_mode;align_corners;shifts;affine;backward;time(us)" >> "$output_file" - -function op_summary { - while IFS= read -r line1 && IFS= read -r line2 <&3; do - text=${line1} - IFS=';' read -ra pairs <<< "$(echo "$text" | tr -d '\n' | tr -s ' ')" - for pair in "${pairs[@]}"; do - IFS=':' read -r key value <<< "$pair" - key=$(echo "$key" | xargs) - value=$(echo "$value" | xargs) - if [[ shape = "$key" ]] ; then - shape=${value} - fi - if [[ datatype = "$key" ]] ; then - datatype=${value} - fi - if [[ dim = "$key" ]] || [[ dims = "$key" ]] ; then - dim=${value} - fi - if [[ output_size = "$key" ]] ; then - output_size=${value} - fi - if [[ channels_last = "$key" ]] ; then - channels_last=${value} - fi - if [[ backward = "$key" ]] ; then - backward=${value} - fi - if [[ reduce = "$key" ]] ; then - reduce=${value} - fi - if [[ kernel_size = "$key" ]] ; then - kernel_size=${value} - fi - if [[ P = "$key" ]] ; then - P=${value} - fi - if [[ stride = "$key" ]] ; then - stride=${value} - fi - if [[ replacement = "$key" ]] ; then - replacement=${value} - fi - if [[ num_samples = "$key" ]] ; then - num_samples=${value} - fi - if [[ scale_factor = "$key" ]] ; then - scale_factor=${value} - fi - if [[ mode = "$key" ]] ; then - mode=${value} - fi - if [[ padding_mode = "$key" ]] ; then - padding_mode=${value} - fi - if [[ align_corners = "$key" ]] ; then - align_corners=${value} - fi - if [[ affine = "$key" ]] ; then - affine=${value} - fi - if [[ shifts = "$key" ]] ; then - shifts=${value} - fi - done - number="" - if [[ $line2 =~ ^([0-9.]+)([a-zA-Z]+)$ ]] ; then - number="${BASH_REMATCH[1]}" - unit="${BASH_REMATCH[2]}" - fi - # Align the time units - if [[ $unit == "ms" ]] ;then - number=$(echo "scale=3; $number * 1000" | bc) - fi - if [[ $unit == "s" ]] ;then - number=$(echo "scale=3; $number * 1000000" | bc) - fi - if [[ $Get_backward == "True" ]] && [[ $backward == "False" ]]; then - echo "Only Forward" - else - echo "${i%.*};${datatype};${op_name};$shape;$channels_last;$dim;$output_size;$P;$reduce;$kernel_size;$stride;$replacement;$num_samples;$scale_factor;$mode;$padding_mode;$align_corners;$shifts;$affine;$backward;$number" >> "$output_file" - fi - done < <(echo "$texts") 3< <(echo "$times") -} - -filename=$(find -- *.log) - -for i in $filename -do - output_size="" - P="" - channels_last="" - dim="" - backward="" - reduce="" - kernel_size="" - affine="" - output_size="" - stride="" - replacement="" - num_samples="" - scale_factor="" - mode="" - padding_mode="" - align_corners="" - shifts="" - case_name="${i%.*}" - op_name=$(echo "$case_name" | awk -F. '{print $NF}') - if [[ $Get_backward == "False" ]] ; then - if [[ $op_name =~ batch_norm ]] ; then - op_name="aten::batch_norm" - times=$(grep -E "${op_name}" "${i}" | awk '{print $10}') - elif [[ $op_name =~ exponential ]] || [[ $op_name =~ geometric ]] || [[ $op_name =~ uniform ]] || [[ $op_name =~ random ]] || [[ $op_name =~ normal ]] || [[ $op_name =~ log_normal ]] || [[ $op_name =~ bernoulli ]] || [[ $op_name =~ cauchy ]] ;then - op_name=$op_name"_" - times=$(grep -E "${op_name}" "${i}" | awk '{print $10}') - elif [[ $op_name == unique ]] ; then - op_name="unique2" - times=$(grep -E "${op_name}" "${i}" | awk '{print $10}') - elif [[ $op_name == max_pool3d ]] || [[ $op_name == max_pool2d ]] ; then - op_name=$op_name"_with_indices" - times=$(grep -E "${op_name} " "${i}" | awk '{print $10}') - elif [[ $op_name == dropout ]] || [[ $op_name == layer_norm ]] ; then - times=$(grep -w "${op_name}" "${i}" | awk '{print $10}') - elif [[ $op_name == ctc_loss ]] ; then - op_name="_"$op_name - times=$(grep -w "${op_name}" "${i}" | awk '{print $10}') - elif [[ $op_name == adaptive_avg_pool2d ]] ; then - op_name="adaptive_avg_pool2d" - times=$(grep -w "${op_name} " "${i}" | awk '{print $10}') - elif [[ $op_name == softmax ]] ; then - op_name="aten::softmax" - times=$(grep -E "${op_name}" "${i}" | awk '{print $10}') - elif [[ $op_name == group_norm ]] ; then - op_name="aten::group_norm" - times=$(grep -E "${op_name}" "${i}" | awk '{print $10}') - else - times=$(grep -E "${op_name} " "${i}" | awk '{print $10}') - fi - else - if [[ $op_name =~ batch_norm ]] ; then - op_name="batch_norm_backward" - times=$(grep -E "${op_name}" "${i}" | awk '{print $10}') - elif [[ $op_name == max_pool3d ]] || [[ $op_name == max_pool2d ]] ; then - op_name=$op_name"_with_indices_backward" - times=$(grep -E "${op_name} " "${i}" | awk '{print $10}') - elif [[ $op_name == col2im ]] ; then - op_name="Col2ImBackward0" - times=$(grep -E "${op_name} " "${i}" | grep -v "autograd::engine" | awk '{print $10}') - elif [[ $op_name == im2col ]] ; then - op_name="Im2ColBackward0" - times=$(grep -E "${op_name} " "${i}" | grep -v "autograd::engine" | awk '{print $10}') - elif [[ $op_name == flip ]] ; then - op_name="FlipBackward0" - times=$(grep -E "${op_name} " "${i}" | grep -v "autograd::engine" | awk '{print $10}') - elif [[ $op_name == matmul ]] ; then - op_name="MmBackward0" - times=$(grep -E "${op_name} " "${i}" | grep -v "autograd::engine" | awk '{print $10}') - elif [[ $op_name == roll ]] ; then - op_name="RollBackward0" - times=$(grep -E "${op_name} " "${i}" | grep -v "autograd::engine" | awk '{print $10}') - elif [[ $op_name == softmax ]] ; then - op_name=$op_name"_backward_data" - times=$(grep -E "${op_name} " "${i}" | awk '{print $10}') - elif [[ $op_name == remainder ]] ; then - op_name="RemainderBackward0" - times=$(grep -E "${op_name} " "${i}" | awk '{print $10}') - elif [[ $op_name == l1_loss ]] ; then - op_name="l1_loss" - else - op_name=$op_name"_backward" - times=$(grep -E "${op_name} " "${i}" | awk '{print $10}') - fi - fi - - texts=$(grep -E "shape :|shape:" "$i") - number="" - if [[ $op_name == l1_loss ]] && [[ $Get_backward == "True" ]] ; then - op_name="AbsBackward0" - times=$(grep -E "${op_name} " "${i}" | grep -v "autograd" | awk '{print $10}' | head -n 6) - texts=$(grep -E "shape :|shape:" "$i" | head -n 6) - op_summary - op_name="MeanBackward0" - times=$(grep -E "${op_name} " "${i}" | grep -v "autograd" | awk '{print $10}') - texts=$(grep -E "shape :|shape:" "$i" | tail -n 6) - op_summary - else - op_summary - fi -done diff --git a/.github/workflows/_linux_op_benchmark.yml b/.github/workflows/_linux_op_benchmark.yml index 4e686aacb..051881617 100644 --- a/.github/workflows/_linux_op_benchmark.yml +++ b/.github/workflows/_linux_op_benchmark.yml @@ -117,9 +117,9 @@ jobs: python ${i%.*}.py > ${{ github.workspace }}/op_benchmark/${i%.*}.log done # Summary forward op time - bash ${{ github.workspace }}/.github/scripts/microbench_summary.sh ${{ github.workspace }}/op_benchmark ${{ github.workspace }}/op_benchmark/forward_op_summary.csv + python ${{ github.workspace }}/.github/scripts/microbench_summary.py ${{ github.workspace }}/op_benchmark ${{ github.workspace }}/op_benchmark/forward_op_summary.csv # Summary backward op time - bash ${{ github.workspace }}/.github/scripts/microbench_summary.sh ${{ github.workspace }}/op_benchmark ${{ github.workspace }}/op_benchmark/backward_op_summary.csv True + python ${{ github.workspace }}/.github/scripts/microbench_summary.py ${{ github.workspace }}/op_benchmark ${{ github.workspace }}/op_benchmark/backward_op_summary.csv --backward - name: Upload Inductor XPU OP benchmark Log if: always() uses: actions/upload-artifact@v4 diff --git a/test/microbench/embedding.py b/test/microbench/embedding.py new file mode 100644 index 000000000..8298d6923 --- /dev/null +++ b/test/microbench/embedding.py @@ -0,0 +1,29 @@ +import torch +from torch.profiler import profile, ProfilerActivity + +shape_list = [(1024, 8)] +device = "xpu" +backward = True +dict_len = 2500000 +vect_len = 128 + +for shape in shape_list: + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + emb = torch.nn.Embedding(dict_len, vect_len, dtype=dtype, device=device) + input = torch.randint(0, dict_len, (1024, 8), device=device) + grad = torch.randn(1024, 8, vect_len, dtype=dtype, device=device) + + # warm up + output = emb(input) + output.backward(grad) + + # go + print("shape:", (shape), "; datatype:", dtype, "; backward:", backward) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(20): + output = emb(input) + output.backward(grad) + print(prof.key_averages().table(sort_by="xpu_time_total", row_limit=100)) diff --git a/test/microbench/embedding_bag.py b/test/microbench/embedding_bag.py new file mode 100644 index 000000000..b941484eb --- /dev/null +++ b/test/microbench/embedding_bag.py @@ -0,0 +1,51 @@ +import torch +import random +from torch.profiler import profile, ProfilerActivity + +device = "xpu" +backward = True + +for dtype in [torch.bfloat16, torch.float16, torch.float32]: + for reduce in ['max','mean','sum']: + dict_len = 2500000 + vect_len = 128 + batch = 1024 + + emb = torch.nn.EmbeddingBag(dict_len, vect_len, mode=reduce, dtype=dtype, device=device) + input = torch.empty([batch], dtype=torch.long, device=device) + for i in range(batch): + input[i] = random.randint(0, dict_len - 1) + + bag = torch.empty([batch], dtype=torch.long, device=device) + for i in range(batch): + bag[i] = i + + if backward: + grad = torch.randn(batch, vect_len, dtype=dtype, device=device) + + # warm up + for i in range(5): + output = emb(input, bag) + if backward: + output.backward(grad) + + # go + print( + "shape:", + (batch), + "; datatype:", + dtype, + "; reduce:", + reduce, + "; backward:", + backward, + ) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(20): + output = emb(input, bag) + if backward: + output.backward(grad) + print(prof.key_averages().table(sort_by="xpu_time_total", row_limit=100)) diff --git a/test/microbench/indexing.diag.py b/test/microbench/indexing.diag.py new file mode 100644 index 000000000..4fc9121c1 --- /dev/null +++ b/test/microbench/indexing.diag.py @@ -0,0 +1,27 @@ +import torch +from torch.profiler import profile, ProfilerActivity + +shape_list = [(8192), (8192, 8192)] +device = "xpu" +backward = False + +cache_r = torch.randn((1024 * 1024 * 1024), device=device) +cache_w = torch.randn((1024 * 1024 * 1024), device=device) + +for shape in shape_list: + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + input = torch.randn(shape, dtype=dtype, device=device) + + # warm up + output = torch.diag(input) + + # go + print("shape:", (shape), "; datatype:", dtype, "; backward:", backward) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(10): + cache_r = cache_w * i + output = torch.diag(input) + print(prof.key_averages().table(sort_by="xpu_time_total")) diff --git a/test/microbench/indexing.index.py b/test/microbench/indexing.index.py new file mode 100644 index 000000000..deee8e17c --- /dev/null +++ b/test/microbench/indexing.index.py @@ -0,0 +1,62 @@ +import torch +from torch.profiler import profile, ProfilerActivity + +shape_list = [(4, 15000)] +device = "xpu" +backward = False + +for shape in shape_list: + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + for mode in ["with_nonzero", "without_nonzero"]: + d = torch.rand(shape, dtype=dtype, device=device) + e = torch.rand(shape, dtype=dtype, device=device) + + if mode == "with_nonzero": + # warm up + for i in range(100): + f = d < e + g = e[f] + + # go + print( + "shape:", + (shape), + "; datatype:", + dtype, + "; mode:", + mode, + "; backward:", + backward, + ) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(20): + f = d < e + g = e[f] + print(prof.key_averages().table(sort_by="xpu_time_total")) + else: + f = torch.linspace(0, 4-2, steps=int(4/2), device=device).to(torch.long) + # warm up + for i in range(100): + g = e[f] + + # go + print( + "shape:", + (shape), + "; dtype:", + dtype, + "; mode:", + mode, + "; backward:", + backward, + ) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(20): + g = e[f] + print(prof.key_averages().table(sort_by="xpu_time_total")) diff --git a/test/microbench/indexing.index_add.py b/test/microbench/indexing.index_add.py new file mode 100644 index 000000000..434c62d89 --- /dev/null +++ b/test/microbench/indexing.index_add.py @@ -0,0 +1,35 @@ +import torch +from torch.profiler import profile, ProfilerActivity + +shape_list = [(1024, 1024)] +device = "xpu" +backward = False +step = int(1024/2) +cache_r = torch.randn((8192 * 8192), device=device) +cache_w = torch.randn((8192 * 8192), device=device) + +for shape in shape_list: + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + for dim in [0, 1]: + input = torch.zeros(shape, dtype=dtype, device=device) + indices = torch.linspace(0, 1022, steps=step, device=device).to(torch.long) + y_0 = torch.ones((512, 1024), dtype=dtype, device=device) + y_1 = torch.randn((1024, 512), dtype=dtype, device=device) + + # warm up + for i in range(10): + output = input.index_add(0, indices, y_0) + + # go + print("shape:", (shape), "; datatype:", dtype, "; dim:", dim, "; backward:", backward) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(10): + cache_r = cache_w * i + if dim == 0: + output = input.index_add(dim, indices, y_0) + else: + output = input.index_add(dim, indices, y_1) + print(prof.key_averages().table(sort_by="xpu_time_total")) diff --git a/test/microbench/indexing.index_copy.py b/test/microbench/indexing.index_copy.py new file mode 100644 index 000000000..b00501a8f --- /dev/null +++ b/test/microbench/indexing.index_copy.py @@ -0,0 +1,44 @@ +import torch +from torch.profiler import profile, ProfilerActivity + +shape_list = [(1024, 1024)] +device = "xpu" +backward = False + +cache_r = torch.randn((1024 * 1024 * 1024), device=device) +cache_w = torch.randn((1024 * 1024 * 1024), device=device) + +for shape in shape_list: + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + for dim in [0, 1]: + input = torch.zeros(shape, dtype=dtype, device=device) + indices = torch.linspace(0, 1022, steps=512, device=device).to(torch.long) + y_0 = torch.ones((512, 1024), dtype=dtype, device=device) + y_1 = torch.randn((1024, 512), dtype=dtype, device=device) + + # warm up + for i in range(10): + output = input.index_copy(0, indices, y_0) + + # go + print( + "shape:", + (shape), + "; datatype:", + dtype, + "; dim:", + dim, + "; backward:", + backward, + ) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(10): + cache_r = cache_w * i + if dim == 0: + output = input.index_copy(dim, indices, y_0) + else: + output = input.index_copy(dim, indices, y_1) + print(prof.key_averages().table(sort_by="xpu_time_total")) diff --git a/test/microbench/indexing.index_fill.py b/test/microbench/indexing.index_fill.py new file mode 100644 index 000000000..e3b5df701 --- /dev/null +++ b/test/microbench/indexing.index_fill.py @@ -0,0 +1,44 @@ +import torch +from torch.profiler import profile, ProfilerActivity + +shape_list = [(1024, 1024)] +device = "xpu" +backward = False + +cache_r = torch.randn((1024 * 1024 * 1024), device=device) +cache_w = torch.randn((1024 * 1024 * 1024), device=device) + +for shape in shape_list: + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + for dim in [0, 1]: + input = torch.zeros(shape, dtype=dtype, device=device) + indices = torch.linspace(0, 1022, steps=512, device=device).to(torch.long) + y_0 = torch.ones((512, 1024), dtype=dtype, device=device) + y_1 = torch.randn((1024, 512), dtype=dtype, device=device) + + # warm up + for i in range(10): + output = input.index_fill(0, indices, 1) + + # go + print( + "shape:", + (shape), + "; datatype:", + dtype, + "; dim:", + dim, + "; backward:", + backward, + ) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(10): + cache_r = cache_w * i + if dim == 0: + output = input.index_fill(dim, indices, 1) + else: + output = input.index_fill(dim, indices, 2) + print(prof.key_averages().table(sort_by="xpu_time_total")) diff --git a/test/microbench/indexing.index_put.py b/test/microbench/indexing.index_put.py new file mode 100644 index 000000000..f3b66a9f8 --- /dev/null +++ b/test/microbench/indexing.index_put.py @@ -0,0 +1,65 @@ +import torch +from torch.profiler import profile, ProfilerActivity + +shape_list = [(4, 15000)] +device = "xpu" +backward = False + +for shape in shape_list: + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + for mode in ["with_nonzero", "without_nonzero"]: + d = torch.rand(4, 15000, dtype=dtype, device=device) + e = torch.rand(4, 15000, dtype=dtype, device=device) + f = d < e + g = e[f] + + if mode == "with_nonzero": + # warm up + for i in range(100): + d[f] = g + + # go + print( + "shape:", + (shape), + "; datatype:", + dtype, + "; mode:", + mode, + "; backward:", + backward, + ) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(20): + d[f] = g + print(prof.key_averages().table(sort_by="xpu_time_total")) + else: + f = f.nonzero() + index = [] + for i in range(f.dim()): + index.append(f.select(1, i)) + # warm up + for i in range(100): + d[index] = g + + # go + print( + "shape:", + (shape), + "; dtype:", + dtype, + "; mode:", + mode, + "; backward:", + backward, + ) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(20): + d[index] = g + print(prof.key_averages().table(sort_by="xpu_time_total")) diff --git a/test/microbench/indexing.index_select.py b/test/microbench/indexing.index_select.py new file mode 100644 index 000000000..fffd280c3 --- /dev/null +++ b/test/microbench/indexing.index_select.py @@ -0,0 +1,30 @@ +import torch +from torch.profiler import profile, ProfilerActivity + +shape_list = [(1024, 1024), (8192, 8192)] +device = "xpu" +backward = False +cache_r = torch.randn((1024 * 1024 * 1024), device=device) +cache_w = torch.randn((1024 * 1024 * 1024), device=device) + +for shape in shape_list: + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + main_size=shape[0] + step = int(main_size / 2) + input = torch.randn(shape, dtype=dtype, device=device) + indices = torch.linspace(0, shape[0]-2, steps=step, device=device).to(torch.long) + + # warm up + for i in range(10): + y_0 = torch.index_select(input, 0, indices) + + # go + print("shape:", (shape), "; datatype:", dtype, "; backward:", backward) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(20): + cache_r = cache_w * i + y_0 = torch.index_select(input, 0, indices) + print(prof.key_averages().table(sort_by="xpu_time_total")) diff --git a/test/microbench/indexing.masked_fill.py b/test/microbench/indexing.masked_fill.py new file mode 100644 index 000000000..a77703f9a --- /dev/null +++ b/test/microbench/indexing.masked_fill.py @@ -0,0 +1,31 @@ +import torch +from torch.profiler import profile, ProfilerActivity + +shape_list = [(8192, 8192)] +device = "xpu" +backward = False +cache_r = torch.randn((1024 * 1024 * 1024), device=device) +cache_w = torch.randn((1024 * 1024 * 1024), device=device) + +for shape in shape_list: + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + input = torch.zeros(shape, dtype=dtype, device=device) + masks_ = torch.zeros((8192), dtype=dtype, device=device) + indices = torch.linspace(0, 8190, steps=4096, device=device).to(torch.long) + masks_.index_fill_(0, indices, True) + masks = masks_.to(torch.bool) + + # warm up + for i in range(10): + y_1 = input.masked_fill(mask=masks, value=1) + + # go + print("shape:", (shape), "; datatype:", dtype, "; backward:", backward) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(20): + cache_r = cache_w * i + y_1 = input.masked_fill(mask=masks, value=1) + print(prof.key_averages().table(sort_by="xpu_time_total")) diff --git a/test/microbench/indexing.put.py b/test/microbench/indexing.put.py new file mode 100644 index 000000000..af0b43838 --- /dev/null +++ b/test/microbench/indexing.put.py @@ -0,0 +1,29 @@ +import torch +from torch.profiler import profile, ProfilerActivity + +shape_list = [(8192, 8192)] +device = "xpu" +backward = False +cache_r = torch.randn((1024 * 1024 * 1024), device=device) +cache_w = torch.randn((1024 * 1024 * 1024), device=device) + +for shape in shape_list: + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + input = torch.zeros(shape, dtype=dtype, device=device) + indices = torch.linspace(0, 8190*8190, steps=4096*4096, device=device).to(torch.long) + sources = torch.ones((4096, 4096), dtype=dtype, device=device) + + # warm up + for i in range(10): + input.put_(index=indices, source=sources) + + # go + print("shape:", (shape), "; datatype:", dtype, "; backward:", backward) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(20): + cache_r = cache_w * i + input.put_(index=indices, source=sources) + print(prof.key_averages().table(sort_by="xpu_time_total")) diff --git a/test/microbench/indexing.take.py b/test/microbench/indexing.take.py new file mode 100644 index 000000000..2f5e8fc76 --- /dev/null +++ b/test/microbench/indexing.take.py @@ -0,0 +1,28 @@ +import torch +from torch.profiler import profile, ProfilerActivity + +shape_list = [(8192, 8192)] +device = "xpu" +backward = False +cache_r = torch.randn((1024 * 1024 * 1024), device=device) +cache_w = torch.randn((1024 * 1024 * 1024), device=device) + +for shape in shape_list: + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + input = torch.randn(shape, dtype=dtype, device=device) + indices = torch.linspace(0, 8190*8190, steps=4096*4096, device=device).to(torch.long) + + # warm up + for i in range(10): + output = torch.take(input, indices) + + # go + print("shape:", (shape), "; datatype:", dtype, "; backward:", backward) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(20): + cache_r = cache_w * i + output = torch.take(input, indices) + print(prof.key_averages().table(sort_by="xpu_time_total")) diff --git a/test/microbench/pooling.max_unpool2d.py b/test/microbench/pooling.max_unpool2d.py new file mode 100644 index 000000000..a5f0af6bf --- /dev/null +++ b/test/microbench/pooling.max_unpool2d.py @@ -0,0 +1,72 @@ +import torch +from torch.profiler import profile, ProfilerActivity + +shape_list = [ + (4, 64, 128, 128), + (4, 65, 128, 128), + (8, 128, 128, 128), +] + + +def maxUnpool2d(shape, dtype, device, channels_last, backward): + N, C, H, W = int(shape[0]), int(shape[1]), int(shape[2]), int(shape[3]) + kernel_size = 2 + + pool = torch.nn.MaxPool2d(kernel_size, return_indices=True).to(device=device, dtype=dtype) + unpool = torch.nn.MaxUnpool2d(kernel_size).to(device=device, dtype=dtype) + torch.manual_seed(20) + + if channels_last: + input = torch.randn([N, C, H, W]).to(memory_format=torch.channels_last).to(device=device, dtype=dtype) + else: + input = torch.randn([N, C, H, W]).to(device=device, dtype=dtype) + output, indices = pool(input) + + if channels_last: + x_dpcpp = output.to(memory_format=torch.channels_last).to(device=device, dtype=dtype) + indices_dpcpp = indices.to(memory_format=torch.channels_last).to(device=device, dtype=torch.int64) + else: + x_dpcpp = output.to(device=device, dtype=dtype) + indices_dpcpp = indices.to(device=device, dtype=torch.int64) + + if backward: + x_dpcpp.requires_grad_(True) + if channels_last: + grad_dpcpp = torch.randn([N, C, H, W]).to(memory_format=torch.channels_last).to(device=device, dtype=dtype) + else: + grad_dpcpp = torch.randn([N, C, H, W]).to(device=device, dtype=dtype) + + y_dpcpp = unpool(x_dpcpp, indices_dpcpp, output_size=torch.Size([N, C, H, W])) + + if backward: + y_dpcpp.backward(grad_dpcpp) + +if __name__ == "__main__": + backward = True + device = "xpu" + for shape in shape_list: + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + for channels_last in [False, True]: + # warm up + maxUnpool2d(shape, dtype, device, channels_last, backward=backward) + + # go + print( + "shape:", + (shape[0], shape[1], shape[2], shape[3]), + "; datatype:", + dtype, + "; kernel_size:", + str(2), + "; channels_last:", + channels_last, + "; backward:", + backward, + ) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(20): + maxUnpool2d(shape, dtype, device, channels_last, backward=backward) + print(prof.key_averages().table(sort_by="xpu_time_total")) diff --git a/test/microbench/pooling.max_unpool3d.py b/test/microbench/pooling.max_unpool3d.py new file mode 100644 index 000000000..f809c02a5 --- /dev/null +++ b/test/microbench/pooling.max_unpool3d.py @@ -0,0 +1,72 @@ +import torch +from torch.profiler import profile, ProfilerActivity + +shape_list = [ + (2, 32, 64, 64, 64), + (4, 33, 64, 64, 64), + (16, 32, 32, 32, 32), +] + + +def maxUnpool3d(shape, dtype, device, channels_last, backward): + N, C, D, H, W = int(shape[0]), int(shape[1]), int(shape[2]), int(shape[3]), int(shape[4]) + kernel_size = 2 + + pool = torch.nn.MaxPool3d(kernel_size, return_indices=True).to(device=device, dtype=dtype) + unpool = torch.nn.MaxUnpool3d(kernel_size).to(device=device, dtype=dtype) + torch.manual_seed(20) + + if channels_last: + input = torch.randn([N, C, D, H, W]).to(memory_format=torch.channels_last_3d).to(device=device, dtype=torch.float32) + else: + input = torch.randn([N, C, D, H, W]).to(device=device, dtype=torch.float32) + output, indices = pool(input) + + if channels_last: + x_dpcpp = output.to(memory_format=torch.channels_last_3d).to(device=device, dtype=dtype) + indices_dpcpp = indices.to(memory_format=torch.channels_last_3d).to(device=device, dtype=torch.int64) + else: + x_dpcpp = output.to(device=device, dtype=dtype) + indices_dpcpp = indices.to(device=device, dtype=torch.int64) + + if backward: + x_dpcpp.requires_grad_(True) + if channels_last: + grad_dpcpp = torch.randn([N, C, D, H, W]).to(memory_format=torch.channels_last_3d).to(device=device, dtype=dtype) + else: + grad_dpcpp = torch.randn([N, C, D, H, W]).to(device=device, dtype=dtype) + + y_dpcpp = unpool(x_dpcpp, indices_dpcpp, output_size=torch.Size([N, C, D, H, W])) + + if backward: + y_dpcpp.backward(grad_dpcpp) + +if __name__ == "__main__": + backward = True + device = "xpu" + for shape in shape_list: + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + for channels_last in [False, True]: + # warm up + maxUnpool3d(shape, dtype, device, channels_last, backward=backward) + + # go + print( + "shape:", + (shape[0], shape[1], shape[2], shape[3], shape[4]), + "; datatype:", + dtype, + "; kernel_size:", + str(2), + "; channels_last:", + channels_last, + "; backward:", + backward, + ) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(20): + maxUnpool3d(shape, dtype, device, channels_last, backward=backward) + print(prof.key_averages().table(sort_by="xpu_time_total")) diff --git a/test/microbench/repeat_interleave.py b/test/microbench/repeat_interleave.py new file mode 100644 index 000000000..4ece0208b --- /dev/null +++ b/test/microbench/repeat_interleave.py @@ -0,0 +1,48 @@ +import torch +from torch.profiler import profile, ProfilerActivity + +shape_list = [ + (16, 8, 23), + (4, 2048, 2048), +] +device = "xpu" +backward = False + +for shape in shape_list: + for repeats in [8]: + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + for dim in [0, 2]: + input = torch.randn(shape, device=device, dtype=dtype) + + if backward: + input.requires_grad_(True) + + # warm up + for i in range(5): + output = torch.repeat_interleave(input, repeats, dim) + + if backward: + gy = torch.empty_like(output) + output.backward(gy) + # go + print( + "shape:", + shape, + "; datatype:", + dtype, + "; dim:", + dim, + "; backward:", + backward, + ) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(20): + output = torch.repeat_interleave(input, repeats, dim) + + if backward: + gy = torch.empty_like(output) + output.backward(gy) + print(prof.key_averages().table(sort_by="xpu_time_total")) diff --git a/test/microbench/scan.masked_select.py b/test/microbench/scan.masked_select.py new file mode 100644 index 000000000..dd4a2d540 --- /dev/null +++ b/test/microbench/scan.masked_select.py @@ -0,0 +1,23 @@ +import torch +from torch.profiler import profile, ProfilerActivity + +shape_list = [(8193, 8193)] +device = "xpu" +backward = False + +for shape in shape_list: + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + input = torch.randn(shape, dtype=dtype, device=device) + mask = input.ge(0.5) + # warm up + torch.masked_select(input, mask) + + # go + print("shape:", shape, "; datatype:", dtype, "; backward:", backward) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(20): + torch.masked_select(input, mask) + print(prof.key_averages().table(sort_by="xpu_time_total")) diff --git a/test/microbench/scan.nonzero.py b/test/microbench/scan.nonzero.py new file mode 100644 index 000000000..937a635bb --- /dev/null +++ b/test/microbench/scan.nonzero.py @@ -0,0 +1,26 @@ +import torch +from torch.profiler import profile, ProfilerActivity + +shape_list = [(2047, 2047, 10), (1, 4 * 15000)] +device = "xpu" +backward = False + +for shape in shape_list: + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + if shape == (2047, 2047, 10): + input = torch.randint(-2, 3, shape, dtype=dtype, device=device) + else: + input = torch.randn(shape, dtype=dtype, device=device) + + # warm up + torch.nonzero(input) + + # go + print("shape:", shape, "; datatype:", dtype, "; backward:", backward) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(20): + torch.nonzero(input) + print(prof.key_averages().table(sort_by="xpu_time_total")) diff --git a/test/microbench/scatter_gather.gather.py b/test/microbench/scatter_gather.gather.py new file mode 100644 index 000000000..45781fad3 --- /dev/null +++ b/test/microbench/scatter_gather.gather.py @@ -0,0 +1,48 @@ +import torch +from torch.profiler import profile, ProfilerActivity + +shape_list = [ + ((2048, 64, 4), (2048, 64, 1), 2), # LQCD shape + ((28, 4096, 9), (28, 4096, 1), 2), + ((512, 36, 4), (512, 36, 1), 2), + ((102400*6400, 4), (102400*6400, 1), 1), # big shape thin + ((102400, 4*6400), (25600, 4*6400), 0), # big shape fat + ((4*6400, 102400), (1*6400, 102400), 0), + ((10240, 8192), (10240, 2048), 1), # medium shape + ((8192, 10240), (2048, 2560), 1), + ((10240, 8192), (2560, 8192), 0), + ((8192, 10240), (2048, 10240), 0), +] + +device = "xpu" +backward = False + +g_xpu = torch.Generator(device=device) +g_xpu.manual_seed(25) +torch.manual_seed(25) +for shape in shape_list: + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + shapes = shape[0] + ishapes = shape[1] + dim = shape[2] + a = torch.randn(shapes, dtype=dtype, device=device) + index = torch.randint(1, shapes[dim], ishapes, device=device, generator=g_xpu) + print( + "shape:", + shapes, + "; kernel_size:", + ishapes, + "; datatype:", + dtype, + "; dim:", + dim, + "; backward:", + backward, + ) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(20): + torch.gather(a, dim, index) + print(prof.key_averages().table(sort_by="xpu_time_total")) diff --git a/test/microbench/scatter_gather.scatter.py b/test/microbench/scatter_gather.scatter.py new file mode 100644 index 000000000..801b481b6 --- /dev/null +++ b/test/microbench/scatter_gather.scatter.py @@ -0,0 +1,74 @@ +import torch +from torch.profiler import profile, ProfilerActivity + +shape_list = [ + # shape, dim + ((28, 4096, 9, 1), 2),# LQCD shape + ((512, 36, 4, 1), 2), + ((4, 4096, 4096), 0), # big shape + ((2048, 4, 4096), 0), + ((2048, 4096, 4), 0), + ((2048, 4096, 4096), 0), + ((4096, 8192, 8192), 0), + ((4097, 8193, 8193), 0), + ((4, 4096, 4096), 1), # big shape + ((2048, 4, 4096), 1), + ((2048, 4096, 4), 1), + ((2048, 4096, 4096), 1), + ((4096, 8192, 8192), 1), + ((4097, 8193, 8193), 1), +] + +device = "xpu" +backward = False + +g_xpu = torch.Generator(device=device) +g_xpu.manual_seed(25) +torch.manual_seed(25) + +def Scatter(shape, dtype, dim, device): + if dim == 2: + m, n, k1, k2 = shape[0][0], shape[0][1], shape[0][2], shape[0][3] + src = torch.ones((m, n, k1), dtype=dtype, device=device) + index = torch.randint(0, k2, (m, n, k1), generator=g_xpu, device=device) + zeros = torch.zeros(m, n, k2, dtype=dtype, device=device) + else: # dim=0 + if dim == 0: + m1, m2, n = shape[0][0], shape[0][1], shape[0][2] + src = torch.ones((m1, n), dtype=dtype, device=device) + index = torch.randint(0, m2, (m1, n), generator=g_xpu, device=device) + zeros = torch.zeros(m2, n, dtype=src.dtype, device=device) + else: # dim=1 + m, n1, n2 = shape[0][0], shape[0][1], shape[0][2] + src = torch.ones((m, n1), dtype=dtype, device=device) + index = torch.randint(0, n2, (m, n1), generator=g_xpu, device=device) + zeros = torch.zeros(m, n2, dtype=src.dtype, device=device) + + dst = zeros.scatter_(dim, index, src) + + +if __name__ == "__main__": + for shape in shape_list: + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + dim = shape[1] + # warm up + Scatter(shape, dtype, dim, device) + + # go + print( + "shape:", + shape[0], + "; datatype:", + dtype, + "; dim:", + dim, + "; backward:", + backward, + ) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(20): + Scatter(shape, dtype, dim, device) + print(prof.key_averages().table(sort_by="xpu_time_total")) diff --git a/test/microbench/scatter_gather.scatter_add.py b/test/microbench/scatter_gather.scatter_add.py new file mode 100644 index 000000000..1437ea1f0 --- /dev/null +++ b/test/microbench/scatter_gather.scatter_add.py @@ -0,0 +1,74 @@ +import torch +from torch.profiler import profile, ProfilerActivity + +shape_list = [ + # shape, dim + ((28, 4096, 9, 1), 2),# LQCD shape + ((512, 36, 4, 1), 2), + ((4, 4096, 4096), 0), # big shape + ((2048, 4, 4096), 0), + ((2048, 4096, 4), 0), + ((2048, 4096, 4096), 0), + ((4096, 8192, 8192), 0), + ((4097, 8193, 8193), 0), + ((4, 4096, 4096), 1), # big shape + ((2048, 4, 4096), 1), + ((2048, 4096, 4), 1), + ((2048, 4096, 4096), 1), + ((4096, 8192, 8192), 1), + ((4097, 8193, 8193), 1), +] + +device = "xpu" +backward = False + +g_xpu = torch.Generator(device=device) +g_xpu.manual_seed(25) +torch.manual_seed(25) + +def Scatter_add(shape, dtype, dim, device): + if dim == 2: + m, n, k1, k2 = shape[0][0], shape[0][1], shape[0][2], shape[0][3] + src = torch.ones((m, n, k1), dtype=dtype, device=device) + index = torch.randint(0, k2, (m, n, k1), generator=g_xpu, device=device) + zeros = torch.zeros(m, n, k2, dtype=dtype, device=device) + else: # dim=0 + if dim == 0: + m1, m2, n = shape[0][0], shape[0][1], shape[0][2] + src = torch.ones((m1, n), dtype=dtype, device=device) + index = torch.randint(0, m2, (m1, n), generator=g_xpu, device=device) + zeros = torch.zeros(m2, n, dtype=src.dtype, device=device) + else: # dim=1 + m, n1, n2 = shape[0][0], shape[0][1], shape[0][2] + src = torch.ones((m, n1), dtype=dtype, device=device) + index = torch.randint(0, n2, (m, n1), generator=g_xpu, device=device) + zeros = torch.zeros(m, n2, dtype=src.dtype, device=device) + + dst = zeros.scatter_add_(dim, index, src) + + +if __name__ == "__main__": + for shape in shape_list: + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + dim = shape[1] + # warm up + Scatter_add(shape, dtype, dim, device) + + # go + print( + "shape:", + shape[0], + "; datatype:", + dtype, + "; dim:", + dim, + "; backward:", + backward, + ) + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], + record_shapes=True, + ) as prof: + for i in range(20): + Scatter_add(shape, dtype, dim, device) + print(prof.key_averages().table(sort_by="xpu_time_total")) From c30fe27900b22be6d86b3641682e219a3dc13cac Mon Sep 17 00:00:00 2001 From: "Zhong, Ruijie" Date: Wed, 23 Apr 2025 00:31:24 -0700 Subject: [PATCH 2/5] align the lint check --- test/microbench/embedding_bag.py | 8 +++-- test/microbench/indexing.index.py | 2 +- test/microbench/indexing.index_add.py | 13 ++++++-- test/microbench/indexing.index_select.py | 6 ++-- test/microbench/indexing.put.py | 4 ++- test/microbench/indexing.take.py | 4 ++- test/microbench/pooling.max_unpool2d.py | 29 ++++++++++++++---- test/microbench/pooling.max_unpool3d.py | 37 ++++++++++++++++++----- test/microbench/scan.nonzero.py | 2 +- test/microbench/scatter_gather.gather.py | 14 ++++----- test/microbench/scatter_gather.scatter.py | 7 ++--- 11 files changed, 91 insertions(+), 35 deletions(-) diff --git a/test/microbench/embedding_bag.py b/test/microbench/embedding_bag.py index b941484eb..feb34e2f8 100644 --- a/test/microbench/embedding_bag.py +++ b/test/microbench/embedding_bag.py @@ -1,17 +1,19 @@ -import torch import random +import torch from torch.profiler import profile, ProfilerActivity device = "xpu" backward = True for dtype in [torch.bfloat16, torch.float16, torch.float32]: - for reduce in ['max','mean','sum']: + for reduce in ['max', 'mean', 'sum']: dict_len = 2500000 vect_len = 128 batch = 1024 - emb = torch.nn.EmbeddingBag(dict_len, vect_len, mode=reduce, dtype=dtype, device=device) + emb = torch.nn.EmbeddingBag( + dict_len, vect_len, mode=reduce, dtype=dtype, device=device + ) input = torch.empty([batch], dtype=torch.long, device=device) for i in range(batch): input[i] = random.randint(0, dict_len - 1) diff --git a/test/microbench/indexing.index.py b/test/microbench/indexing.index.py index deee8e17c..5ce3cf91c 100644 --- a/test/microbench/indexing.index.py +++ b/test/microbench/indexing.index.py @@ -37,7 +37,7 @@ g = e[f] print(prof.key_averages().table(sort_by="xpu_time_total")) else: - f = torch.linspace(0, 4-2, steps=int(4/2), device=device).to(torch.long) + f = torch.linspace(0, 4 - 2, steps=int(4 / 2), device=device).to(torch.long) # warm up for i in range(100): g = e[f] diff --git a/test/microbench/indexing.index_add.py b/test/microbench/indexing.index_add.py index 434c62d89..ad3bbc49e 100644 --- a/test/microbench/indexing.index_add.py +++ b/test/microbench/indexing.index_add.py @@ -4,7 +4,7 @@ shape_list = [(1024, 1024)] device = "xpu" backward = False -step = int(1024/2) +step = int(1024 / 2) cache_r = torch.randn((8192 * 8192), device=device) cache_w = torch.randn((8192 * 8192), device=device) @@ -21,7 +21,16 @@ output = input.index_add(0, indices, y_0) # go - print("shape:", (shape), "; datatype:", dtype, "; dim:", dim, "; backward:", backward) + print( + "shape:", + (shape), + "; datatype:", + dtype, + "; dim:", + dim, + "; backward:", + backward, + ) with profile( activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], record_shapes=True, diff --git a/test/microbench/indexing.index_select.py b/test/microbench/indexing.index_select.py index fffd280c3..3f5060f09 100644 --- a/test/microbench/indexing.index_select.py +++ b/test/microbench/indexing.index_select.py @@ -9,10 +9,12 @@ for shape in shape_list: for dtype in [torch.bfloat16, torch.float16, torch.float32]: - main_size=shape[0] + main_size = shape[0] step = int(main_size / 2) input = torch.randn(shape, dtype=dtype, device=device) - indices = torch.linspace(0, shape[0]-2, steps=step, device=device).to(torch.long) + indices = torch.linspace(0, shape[0] - 2, steps=step, device=device).to( + torch.long + ) # warm up for i in range(10): diff --git a/test/microbench/indexing.put.py b/test/microbench/indexing.put.py index af0b43838..455c282dc 100644 --- a/test/microbench/indexing.put.py +++ b/test/microbench/indexing.put.py @@ -10,7 +10,9 @@ for shape in shape_list: for dtype in [torch.bfloat16, torch.float16, torch.float32]: input = torch.zeros(shape, dtype=dtype, device=device) - indices = torch.linspace(0, 8190*8190, steps=4096*4096, device=device).to(torch.long) + indices = torch.linspace(0, 8190 * 8190, steps=4096 * 4096, device=device).to( + torch.long + ) sources = torch.ones((4096, 4096), dtype=dtype, device=device) # warm up diff --git a/test/microbench/indexing.take.py b/test/microbench/indexing.take.py index 2f5e8fc76..7ec30ed4d 100644 --- a/test/microbench/indexing.take.py +++ b/test/microbench/indexing.take.py @@ -10,7 +10,9 @@ for shape in shape_list: for dtype in [torch.bfloat16, torch.float16, torch.float32]: input = torch.randn(shape, dtype=dtype, device=device) - indices = torch.linspace(0, 8190*8190, steps=4096*4096, device=device).to(torch.long) + indices = torch.linspace(0, 8190 * 8190, steps=4096 * 4096, device=device).to( + torch.long + ) # warm up for i in range(10): diff --git a/test/microbench/pooling.max_unpool2d.py b/test/microbench/pooling.max_unpool2d.py index a5f0af6bf..fdf4c7a49 100644 --- a/test/microbench/pooling.max_unpool2d.py +++ b/test/microbench/pooling.max_unpool2d.py @@ -12,19 +12,29 @@ def maxUnpool2d(shape, dtype, device, channels_last, backward): N, C, H, W = int(shape[0]), int(shape[1]), int(shape[2]), int(shape[3]) kernel_size = 2 - pool = torch.nn.MaxPool2d(kernel_size, return_indices=True).to(device=device, dtype=dtype) + pool = torch.nn.MaxPool2d(kernel_size, return_indices=True).to( + device=device, dtype=dtype + ) unpool = torch.nn.MaxUnpool2d(kernel_size).to(device=device, dtype=dtype) torch.manual_seed(20) if channels_last: - input = torch.randn([N, C, H, W]).to(memory_format=torch.channels_last).to(device=device, dtype=dtype) + input = ( + torch.randn([N, C, H, W]) + .to(memory_format=torch.channels_last) + .to(device=device, dtype=dtype) + ) else: input = torch.randn([N, C, H, W]).to(device=device, dtype=dtype) output, indices = pool(input) if channels_last: - x_dpcpp = output.to(memory_format=torch.channels_last).to(device=device, dtype=dtype) - indices_dpcpp = indices.to(memory_format=torch.channels_last).to(device=device, dtype=torch.int64) + x_dpcpp = output.to(memory_format=torch.channels_last).to( + device=device, dtype=dtype + ) + indices_dpcpp = indices.to(memory_format=torch.channels_last).to( + device=device, dtype=torch.int64 + ) else: x_dpcpp = output.to(device=device, dtype=dtype) indices_dpcpp = indices.to(device=device, dtype=torch.int64) @@ -32,7 +42,11 @@ def maxUnpool2d(shape, dtype, device, channels_last, backward): if backward: x_dpcpp.requires_grad_(True) if channels_last: - grad_dpcpp = torch.randn([N, C, H, W]).to(memory_format=torch.channels_last).to(device=device, dtype=dtype) + grad_dpcpp = ( + torch.randn([N, C, H, W]) + .to(memory_format=torch.channels_last) + .to(device=device, dtype=dtype) + ) else: grad_dpcpp = torch.randn([N, C, H, W]).to(device=device, dtype=dtype) @@ -41,6 +55,7 @@ def maxUnpool2d(shape, dtype, device, channels_last, backward): if backward: y_dpcpp.backward(grad_dpcpp) + if __name__ == "__main__": backward = True device = "xpu" @@ -68,5 +83,7 @@ def maxUnpool2d(shape, dtype, device, channels_last, backward): record_shapes=True, ) as prof: for i in range(20): - maxUnpool2d(shape, dtype, device, channels_last, backward=backward) + maxUnpool2d( + shape, dtype, device, channels_last, backward=backward + ) print(prof.key_averages().table(sort_by="xpu_time_total")) diff --git a/test/microbench/pooling.max_unpool3d.py b/test/microbench/pooling.max_unpool3d.py index f809c02a5..3e2c12f73 100644 --- a/test/microbench/pooling.max_unpool3d.py +++ b/test/microbench/pooling.max_unpool3d.py @@ -9,22 +9,38 @@ def maxUnpool3d(shape, dtype, device, channels_last, backward): - N, C, D, H, W = int(shape[0]), int(shape[1]), int(shape[2]), int(shape[3]), int(shape[4]) + N, C, D, H, W = ( + int(shape[0]), + int(shape[1]), + int(shape[2]), + int(shape[3]), + int(shape[4]), + ) kernel_size = 2 - pool = torch.nn.MaxPool3d(kernel_size, return_indices=True).to(device=device, dtype=dtype) + pool = torch.nn.MaxPool3d(kernel_size, return_indices=True).to( + device=device, dtype=dtype + ) unpool = torch.nn.MaxUnpool3d(kernel_size).to(device=device, dtype=dtype) torch.manual_seed(20) if channels_last: - input = torch.randn([N, C, D, H, W]).to(memory_format=torch.channels_last_3d).to(device=device, dtype=torch.float32) + input = ( + torch.randn([N, C, D, H, W]) + .to(memory_format=torch.channels_last_3d) + .to(device=device, dtype=torch.float32) + ) else: input = torch.randn([N, C, D, H, W]).to(device=device, dtype=torch.float32) output, indices = pool(input) if channels_last: - x_dpcpp = output.to(memory_format=torch.channels_last_3d).to(device=device, dtype=dtype) - indices_dpcpp = indices.to(memory_format=torch.channels_last_3d).to(device=device, dtype=torch.int64) + x_dpcpp = output.to(memory_format=torch.channels_last_3d).to( + device=device, dtype=dtype + ) + indices_dpcpp = indices.to(memory_format=torch.channels_last_3d).to( + device=device, dtype=torch.int64 + ) else: x_dpcpp = output.to(device=device, dtype=dtype) indices_dpcpp = indices.to(device=device, dtype=torch.int64) @@ -32,7 +48,11 @@ def maxUnpool3d(shape, dtype, device, channels_last, backward): if backward: x_dpcpp.requires_grad_(True) if channels_last: - grad_dpcpp = torch.randn([N, C, D, H, W]).to(memory_format=torch.channels_last_3d).to(device=device, dtype=dtype) + grad_dpcpp = ( + torch.randn([N, C, D, H, W]) + .to(memory_format=torch.channels_last_3d) + .to(device=device, dtype=dtype) + ) else: grad_dpcpp = torch.randn([N, C, D, H, W]).to(device=device, dtype=dtype) @@ -41,6 +61,7 @@ def maxUnpool3d(shape, dtype, device, channels_last, backward): if backward: y_dpcpp.backward(grad_dpcpp) + if __name__ == "__main__": backward = True device = "xpu" @@ -68,5 +89,7 @@ def maxUnpool3d(shape, dtype, device, channels_last, backward): record_shapes=True, ) as prof: for i in range(20): - maxUnpool3d(shape, dtype, device, channels_last, backward=backward) + maxUnpool3d( + shape, dtype, device, channels_last, backward=backward + ) print(prof.key_averages().table(sort_by="xpu_time_total")) diff --git a/test/microbench/scan.nonzero.py b/test/microbench/scan.nonzero.py index 937a635bb..12aac922d 100644 --- a/test/microbench/scan.nonzero.py +++ b/test/microbench/scan.nonzero.py @@ -9,7 +9,7 @@ for dtype in [torch.bfloat16, torch.float16, torch.float32]: if shape == (2047, 2047, 10): input = torch.randint(-2, 3, shape, dtype=dtype, device=device) - else: + else: input = torch.randn(shape, dtype=dtype, device=device) # warm up diff --git a/test/microbench/scatter_gather.gather.py b/test/microbench/scatter_gather.gather.py index 45781fad3..469428001 100644 --- a/test/microbench/scatter_gather.gather.py +++ b/test/microbench/scatter_gather.gather.py @@ -2,13 +2,13 @@ from torch.profiler import profile, ProfilerActivity shape_list = [ - ((2048, 64, 4), (2048, 64, 1), 2), # LQCD shape + ((2048, 64, 4), (2048, 64, 1), 2), # LQCD shape ((28, 4096, 9), (28, 4096, 1), 2), ((512, 36, 4), (512, 36, 1), 2), - ((102400*6400, 4), (102400*6400, 1), 1), # big shape thin - ((102400, 4*6400), (25600, 4*6400), 0), # big shape fat - ((4*6400, 102400), (1*6400, 102400), 0), - ((10240, 8192), (10240, 2048), 1), # medium shape + ((102400 * 6400, 4), (102400 * 6400, 1), 1), # big shape thin + ((102400, 4 * 6400), (25600, 4 * 6400), 0), # big shape fat + ((4 * 6400, 102400), (1 * 6400, 102400), 0), + ((10240, 8192), (10240, 2048), 1), # medium shape ((8192, 10240), (2048, 2560), 1), ((10240, 8192), (2560, 8192), 0), ((8192, 10240), (2048, 10240), 0), @@ -33,9 +33,9 @@ "; kernel_size:", ishapes, "; datatype:", - dtype, + dtype, "; dim:", - dim, + dim, "; backward:", backward, ) diff --git a/test/microbench/scatter_gather.scatter.py b/test/microbench/scatter_gather.scatter.py index 801b481b6..ccd1c3194 100644 --- a/test/microbench/scatter_gather.scatter.py +++ b/test/microbench/scatter_gather.scatter.py @@ -2,16 +2,15 @@ from torch.profiler import profile, ProfilerActivity shape_list = [ - # shape, dim - ((28, 4096, 9, 1), 2),# LQCD shape + ((28, 4096, 9, 1), 2), # LQCD shape ((512, 36, 4, 1), 2), - ((4, 4096, 4096), 0), # big shape + ((4, 4096, 4096), 0), # big shape ((2048, 4, 4096), 0), ((2048, 4096, 4), 0), ((2048, 4096, 4096), 0), ((4096, 8192, 8192), 0), ((4097, 8193, 8193), 0), - ((4, 4096, 4096), 1), # big shape + ((4, 4096, 4096), 1), # big shape ((2048, 4, 4096), 1), ((2048, 4096, 4), 1), ((2048, 4096, 4096), 1), From 127121b3dbf4b8a2d2a1e13039a37a1dd5cfeb6f Mon Sep 17 00:00:00 2001 From: "Zhong, Ruijie" Date: Wed, 23 Apr 2025 00:37:28 -0700 Subject: [PATCH 3/5] align the lint check --- .github/scripts/microbench_summary.py | 30 ++++++++++++------- test/microbench/embedding_bag.py | 3 +- test/microbench/indexing.index.py | 4 ++- test/microbench/scatter_gather.scatter.py | 5 ++-- test/microbench/scatter_gather.scatter_add.py | 11 +++---- 5 files changed, 33 insertions(+), 20 deletions(-) diff --git a/.github/scripts/microbench_summary.py b/.github/scripts/microbench_summary.py index 089918bd0..99d456546 100644 --- a/.github/scripts/microbench_summary.py +++ b/.github/scripts/microbench_summary.py @@ -2,9 +2,9 @@ Microbenchmark Summary Tool - Parses performance logs and generates CSV/Excel reports # Usage # Summary forward op time, forward_op_summary.csv is forward summary file -python microbench_summary.py path/to/profile's log g forward_op_summary.csv +python microbench_summary.py path/to/profile's log forward_op_summary.csv # Summary backward op time, backward_op_summary.csv is backward summary file, True means summary backward, default is false. -python microbench_summary.py path/to/profile's log g backward_op_summary.csv --backward +python microbench_summary.py path/to/profile's log backward_op_summary.csv --backward """ import re @@ -13,7 +13,7 @@ import os import argparse from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List def main(): parser = argparse.ArgumentParser( @@ -22,7 +22,7 @@ def main(): ) parser.add_argument("log_dir", help="Directory containing log files") parser.add_argument("output_file", help="Output CSV file path") - parser.add_argument("--backward", action="store_true", + parser.add_argument("--backward", action="store_true", help="Process backward operations instead of forward") args = parser.parse_args() @@ -41,15 +41,15 @@ def main(): def parse_logs(log_dir: str, get_backward: bool = False) -> pd.DataFrame: data = [] columns = [ - "case_name", "datatype", "op_name", "shape", "channels_last", "dim", - "output_size", "P", "reduce", "kernel_size", "stride", "replacement", - "num_samples", "scale_factor", "mode", "padding_mode", "align_corners", + "case_name", "datatype", "op_name", "shape", "channels_last", "dim", + "output_size", "P", "reduce", "kernel_size", "stride", "replacement", + "num_samples", "scale_factor", "mode", "padding_mode", "align_corners", "shifts", "affine", "backward", "time(us)" ] for log_file in glob.glob(os.path.join(log_dir, "*.log")): try: - with open(log_file, 'r') as f: + with open(log_file) as f: content = f.read() case_name = Path(log_file).stem @@ -97,6 +97,12 @@ def get_op_pattern(base_op_name: str, get_backward: bool) -> tuple: 'normal': ('normal_', r'\bnormal_\b'), 'bernoulli': ('bernoulli_', r'\bbernoulli_\b'), 'cauchy': ('cauchy_', r'\bcauchy_\b'), + 'index_fill': ('index_fill_', r'\bindex_fill_\b'), + 'index_put': ('index_put_', r'\bindex_put_\b'), + 'put': ('put_', r'\bput_\b'), + 'masked_fill': ('masked_fill_', r'\bmasked_fill_\b'), + 'scatter_add': ('scatter_add_', r'\bscatter_add_\b'), + 'scatter': ('scatter_', r'\bscatter_\b'), 'dropout': ('dropout', r'\bdropout\b'), 'layer_norm': ('layer_norm', r'\blayer_norm\b'), 'ctc_loss': ('_ctc_loss', r'\b_ctc_loss\b'), @@ -109,6 +115,8 @@ def get_op_pattern(base_op_name: str, get_backward: bool) -> tuple: 'fractional_max_pool2d': ('fractional_max_pool2d_backward', r'\bfractional_max_pool2d_backward\b'), 'fractional_max_pool3d': ('fractional_max_pool3d_backward', r'\bfractional_max_pool3d_backward\b'), 'adaptive_max_pool2d': ('adaptive_max_pool2d_backward', r'\badaptive_max_pool2d_backward\b'), + 'max_unpool2d': ('MaxUnpool2DBackward0', 'MaxUnpool2DBackward0 '), + 'max_unpool3d': ('MaxUnpool3DBackward0', 'MaxUnpool3DBackward0 '), 'max_pool3d': ('max_pool3d_with_indices_backward', 'max_pool3d_with_indices_backward '), 'max_pool2d': ('max_pool2d_with_indices_backward', 'max_pool2d_with_indices_backward '), 'col2im': ('Col2ImBackward0', 'Col2ImBackward0 '), @@ -163,17 +171,17 @@ def extract_times(content: str, pattern: str, get_backward: bool) -> List: for line in lines: if get_backward and any(x in pattern for x in ["Col2ImBackward0", "Im2ColBackward0", "FlipBackward0", "MmBackward0", - "RollBackward0"]): + "RollBackward0", "MaxUnpool2DBackward0", "MaxUnpool3DBackward0"]): if "autograd::engine" in line: continue match = re.search(fr"{pattern}.*?(?:\s+\S+){{8}}\s+(\d+\.?\d*)([a-zA-Z]*)", line) if match: results.append((match.group(1), match.group(2))) - + return results -def create_record(params: Dict, case_name: str, op_name: str, +def create_record(params: Dict, case_name: str, op_name: str, backward: str, time_us: float) -> Dict: return { "P": params.get("p", ""), diff --git a/test/microbench/embedding_bag.py b/test/microbench/embedding_bag.py index feb34e2f8..30825bb7d 100644 --- a/test/microbench/embedding_bag.py +++ b/test/microbench/embedding_bag.py @@ -1,4 +1,5 @@ import random + import torch from torch.profiler import profile, ProfilerActivity @@ -6,7 +7,7 @@ backward = True for dtype in [torch.bfloat16, torch.float16, torch.float32]: - for reduce in ['max', 'mean', 'sum']: + for reduce in ["max", "mean", "sum"]: dict_len = 2500000 vect_len = 128 batch = 1024 diff --git a/test/microbench/indexing.index.py b/test/microbench/indexing.index.py index 5ce3cf91c..4b0c82bb0 100644 --- a/test/microbench/indexing.index.py +++ b/test/microbench/indexing.index.py @@ -37,7 +37,9 @@ g = e[f] print(prof.key_averages().table(sort_by="xpu_time_total")) else: - f = torch.linspace(0, 4 - 2, steps=int(4 / 2), device=device).to(torch.long) + f = torch.linspace(0, 4 - 2, steps=int(4 / 2), device=device).to( + torch.long + ) # warm up for i in range(100): g = e[f] diff --git a/test/microbench/scatter_gather.scatter.py b/test/microbench/scatter_gather.scatter.py index ccd1c3194..9aff261ac 100644 --- a/test/microbench/scatter_gather.scatter.py +++ b/test/microbench/scatter_gather.scatter.py @@ -25,19 +25,20 @@ g_xpu.manual_seed(25) torch.manual_seed(25) + def Scatter(shape, dtype, dim, device): if dim == 2: m, n, k1, k2 = shape[0][0], shape[0][1], shape[0][2], shape[0][3] src = torch.ones((m, n, k1), dtype=dtype, device=device) index = torch.randint(0, k2, (m, n, k1), generator=g_xpu, device=device) zeros = torch.zeros(m, n, k2, dtype=dtype, device=device) - else: # dim=0 + else: if dim == 0: m1, m2, n = shape[0][0], shape[0][1], shape[0][2] src = torch.ones((m1, n), dtype=dtype, device=device) index = torch.randint(0, m2, (m1, n), generator=g_xpu, device=device) zeros = torch.zeros(m2, n, dtype=src.dtype, device=device) - else: # dim=1 + else: m, n1, n2 = shape[0][0], shape[0][1], shape[0][2] src = torch.ones((m, n1), dtype=dtype, device=device) index = torch.randint(0, n2, (m, n1), generator=g_xpu, device=device) diff --git a/test/microbench/scatter_gather.scatter_add.py b/test/microbench/scatter_gather.scatter_add.py index 1437ea1f0..9f244f837 100644 --- a/test/microbench/scatter_gather.scatter_add.py +++ b/test/microbench/scatter_gather.scatter_add.py @@ -3,15 +3,15 @@ shape_list = [ # shape, dim - ((28, 4096, 9, 1), 2),# LQCD shape + ((28, 4096, 9, 1), 2), # LQCD shape ((512, 36, 4, 1), 2), - ((4, 4096, 4096), 0), # big shape + ((4, 4096, 4096), 0), # big shape ((2048, 4, 4096), 0), ((2048, 4096, 4), 0), ((2048, 4096, 4096), 0), ((4096, 8192, 8192), 0), ((4097, 8193, 8193), 0), - ((4, 4096, 4096), 1), # big shape + ((4, 4096, 4096), 1), # big shape ((2048, 4, 4096), 1), ((2048, 4096, 4), 1), ((2048, 4096, 4096), 1), @@ -26,19 +26,20 @@ g_xpu.manual_seed(25) torch.manual_seed(25) + def Scatter_add(shape, dtype, dim, device): if dim == 2: m, n, k1, k2 = shape[0][0], shape[0][1], shape[0][2], shape[0][3] src = torch.ones((m, n, k1), dtype=dtype, device=device) index = torch.randint(0, k2, (m, n, k1), generator=g_xpu, device=device) zeros = torch.zeros(m, n, k2, dtype=dtype, device=device) - else: # dim=0 + else: if dim == 0: m1, m2, n = shape[0][0], shape[0][1], shape[0][2] src = torch.ones((m1, n), dtype=dtype, device=device) index = torch.randint(0, m2, (m1, n), generator=g_xpu, device=device) zeros = torch.zeros(m2, n, dtype=src.dtype, device=device) - else: # dim=1 + else: m, n1, n2 = shape[0][0], shape[0][1], shape[0][2] src = torch.ones((m, n1), dtype=dtype, device=device) index = torch.randint(0, n2, (m, n1), generator=g_xpu, device=device) From e6aaa34b47213218c6b3a6aa369b981267bd2c17 Mon Sep 17 00:00:00 2001 From: "Zhong, Ruijie" Date: Wed, 23 Apr 2025 18:38:01 -0700 Subject: [PATCH 4/5] fix the packages install issue and remove unneccessary case --- .github/workflows/_linux_op_benchmark.yml | 1 + test/microbench/adaptive_avg_pool2d.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/_linux_op_benchmark.yml b/.github/workflows/_linux_op_benchmark.yml index 051881617..8a17dfd0a 100644 --- a/.github/workflows/_linux_op_benchmark.yml +++ b/.github/workflows/_linux_op_benchmark.yml @@ -116,6 +116,7 @@ jobs: do python ${i%.*}.py > ${{ github.workspace }}/op_benchmark/${i%.*}.log done + pip install pandas openpyxl # Summary forward op time python ${{ github.workspace }}/.github/scripts/microbench_summary.py ${{ github.workspace }}/op_benchmark ${{ github.workspace }}/op_benchmark/forward_op_summary.csv # Summary backward op time diff --git a/test/microbench/adaptive_avg_pool2d.py b/test/microbench/adaptive_avg_pool2d.py index 5a444879d..c928b0235 100644 --- a/test/microbench/adaptive_avg_pool2d.py +++ b/test/microbench/adaptive_avg_pool2d.py @@ -4,7 +4,6 @@ device = "xpu" shape_list = [ - (8, 512, 7, 7, (1, 1)), (8, 512, 32, 32, (7, 7)), (8, 256, 56, 56, (14, 14)), ] From 35971582d9c967b4fa03a29db4f6ac96e50cd803 Mon Sep 17 00:00:00 2001 From: "Zhong, Ruijie" Date: Thu, 24 Apr 2025 00:47:39 -0700 Subject: [PATCH 5/5] fix typo issue and para issue --- .github/scripts/microbench_summary.py | 2 ++ test/microbench/indexing.index.py | 2 +- test/microbench/indexing.index_put.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/scripts/microbench_summary.py b/.github/scripts/microbench_summary.py index 99d456546..d47b8365a 100644 --- a/.github/scripts/microbench_summary.py +++ b/.github/scripts/microbench_summary.py @@ -97,6 +97,8 @@ def get_op_pattern(base_op_name: str, get_backward: bool) -> tuple: 'normal': ('normal_', r'\bnormal_\b'), 'bernoulli': ('bernoulli_', r'\bbernoulli_\b'), 'cauchy': ('cauchy_', r'\bcauchy_\b'), + 'embedding_bag': ('_embedding_bag', r'\b_embedding_bag\b'), + 'nonzero': ('nonzero', r'\bnonzero\b'), 'index_fill': ('index_fill_', r'\bindex_fill_\b'), 'index_put': ('index_put_', r'\bindex_put_\b'), 'put': ('put_', r'\bput_\b'), diff --git a/test/microbench/indexing.index.py b/test/microbench/indexing.index.py index 4b0c82bb0..0f631f81b 100644 --- a/test/microbench/indexing.index.py +++ b/test/microbench/indexing.index.py @@ -48,7 +48,7 @@ print( "shape:", (shape), - "; dtype:", + "; datatype:", dtype, "; mode:", mode, diff --git a/test/microbench/indexing.index_put.py b/test/microbench/indexing.index_put.py index f3b66a9f8..039079a5c 100644 --- a/test/microbench/indexing.index_put.py +++ b/test/microbench/indexing.index_put.py @@ -49,7 +49,7 @@ print( "shape:", (shape), - "; dtype:", + "; datatype:", dtype, "; mode:", mode,