diff --git a/.automation_scripts/parse_xml_results.py b/.automation_scripts/parse_xml_results.py
new file mode 100644
index 000000000000..7db2e1ce9233
--- /dev/null
+++ b/.automation_scripts/parse_xml_results.py
@@ -0,0 +1,178 @@
+""" The Python PyTorch testing script.
+##
+# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+"""
+
+import xml.etree.ElementTree as ET
+from pathlib import Path
+from typing import Any, Dict, Tuple
+
+# Backends list
+BACKENDS_LIST = [
+ "dist-gloo",
+ "dist-nccl"
+]
+
+TARGET_WORKFLOW = "--rerun-disabled-tests"
+
+def get_job_id(report: Path) -> int:
+ # [Job id in artifacts]
+ # Retrieve the job id from the report path. In our GHA workflows, we append
+ # the job id to the end of the report name, so `report` looks like:
+ # unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml
+ # and we want to get `5596745227` out of it.
+ try:
+ return int(report.parts[0].rpartition("_")[2])
+ except ValueError:
+ return -1
+
+def is_rerun_disabled_tests(root: ET.ElementTree) -> bool:
+ """
+ Check if the test report is coming from rerun_disabled_tests workflow
+ """
+ skipped = root.find(".//*skipped")
+ # Need an explicit None check here; `if not skipped` misbehaves because an Element with no children is falsy
+ if skipped is None:
+ return False
+
+ message = skipped.attrib.get("message", "")
+ return TARGET_WORKFLOW in message or "num_red" in message
+
+def parse_xml_report(
+ tag: str,
+ report: Path,
+ workflow_id: int,
+ workflow_run_attempt: int,
+ work_flow_name: str
+) -> Dict[Tuple[str], Dict[str, Any]]:
+ """Convert a test report xml file into a JSON-serializable list of test cases."""
+ print(f"Parsing {tag}s for test report: {report}")
+
+ job_id = get_job_id(report)
+ print(f"Found job id: {job_id}")
+
+ test_cases: Dict[Tuple[str], Dict[str, Any]] = {}
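+ # Keys are (invoking_file, classname, name, work_flow_name) for 'testcase' entries
+ # and (invoking_file, invoking_xml, work_flow_name) for 'testsuite' entries; values
+ # are the flattened XML attributes produced by process_xml_element below.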
+
+ root = ET.parse(report)
+ # TODO: unlike unittest, pytest-flakefinder used by rerun disabled tests for test_ops
+ # includes skipped messages multiple times (50 times by default). This slows down
+ # this script too much (O(n)) because it tries to gather all the stats. This should
+ # be fixed later in the way we use pytest-flakefinder. A zipped test report from rerun
+ # disabled tests is only a few MB, but balloons to a much bigger XML file (from a dozen
+ # to a few hundred MB) after extraction
+ if is_rerun_disabled_tests(root):
+ return test_cases
+
+ for test_case in root.iter(tag):
+ case = process_xml_element(test_case)
+ if tag == 'testcase':
+ case["workflow_id"] = workflow_id
+ case["workflow_run_attempt"] = workflow_run_attempt
+ case["job_id"] = job_id
+ case["work_flow_name"] = work_flow_name
+
+ # [invoking file]
+ # The name of the file that the test is located in is not necessarily
+ # the same as the name of the file that invoked the test.
+ # For example, `test_jit.py` calls into multiple other test files (e.g.
+ # jit/test_dce.py). For sharding/test selection purposes, we want to
+ # record the file that invoked the test.
+ #
+ # To do this, we leverage an implementation detail of how we write out
+ # tests (https://bit.ly/3ajEV1M), which is that reports are created
+ # under a folder with the same name as the invoking file.
+ case_name = report.parent.name
+ for backend in BACKENDS_LIST:
+ if backend in report.parts:
+ case_name = case_name + "_" + backend
+ break
+ case["invoking_file"] = case_name
+ test_cases[(case["invoking_file"], case["classname"], case["name"], case["work_flow_name"])] = case
+ elif tag == 'testsuite':
+ case["work_flow_name"] = work_flow_name
+ case["invoking_xml"] = report.name
+ case["running_time_xml"] = case["time"]
+ case_name = report.parent.name
+ for backend in BACKENDS_LIST:
+ if backend in report.parts:
+ case_name = case_name + "_" + backend
+ break
+ case["invoking_file"] = case_name
+
+ test_cases[(case["invoking_file"], case["invoking_xml"], case["work_flow_name"])] = case
+
+ return test_cases
+
+def process_xml_element(element: ET.Element) -> Dict[str, Any]:
+ """Convert a test suite element into a JSON-serializable dict."""
+ ret: Dict[str, Any] = {}
+
+ # Convert attributes directly into dict elements.
+ # e.g.
+ # <testcase name="test_foo" classname="test_bar"></testcase>
+ # becomes:
+ # {"name": "test_foo", "classname": "test_bar"}
+ ret.update(element.attrib)
+
+ # The XML format encodes all values as strings. Convert to ints/floats if
+ # possible to make aggregation possible in Rockset.
+ for k, v in ret.items():
+ try:
+ ret[k] = int(v)
+ except ValueError:
+ pass
+ try:
+ ret[k] = float(v)
+ except ValueError:
+ pass
+
+ # Convert inner and outer text into special dict elements.
+ # e.g.
+ # <testcase>my_inner_text</testcase> my_tail
+ # becomes:
+ # {"text": "my_inner_text", "tail": " my_tail"}
+ if element.text and element.text.strip():
+ ret["text"] = element.text
+ if element.tail and element.tail.strip():
+ ret["tail"] = element.tail
+
+ # Convert child elements recursively, placing them at a key:
+ # e.g.
+ # <testcase>
+ #   <foo>hello</foo>
+ #   <foo>world</foo>
+ #   <bar>another</bar>
+ # </testcase>
+ # becomes
+ # {
+ # "foo": [{"text": "hello"}, {"text": "world"}],
+ # "bar": {"text": "another"}
+ # }
+ for child in element:
+ if child.tag not in ret:
+ ret[child.tag] = process_xml_element(child)
+ else:
+ # If there are multiple tags with the same name, they should be
+ # coalesced into a list.
+ if not isinstance(ret[child.tag], list):
+ ret[child.tag] = [ret[child.tag]]
+ ret[child.tag].append(process_xml_element(child))
+ return ret
\ No newline at end of file
diff --git a/.automation_scripts/run_pytorch_unit_tests.py b/.automation_scripts/run_pytorch_unit_tests.py
new file mode 100644
index 000000000000..514afd19624c
--- /dev/null
+++ b/.automation_scripts/run_pytorch_unit_tests.py
@@ -0,0 +1,518 @@
+#!/usr/bin/env python3
+
+""" The Python PyTorch testing script.
+##
+# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+"""
+
+import argparse
+import os
+import shutil
+import subprocess
+from subprocess import STDOUT, CalledProcessError
+
+from collections import namedtuple
+from datetime import datetime
+from pathlib import Path
+from parse_xml_results import (
+ parse_xml_report
+)
+from pprint import pprint
+from typing import Any, Dict, List
+
+# unit test status list
+UT_STATUS_LIST = [
+ "PASSED",
+ "MISSED",
+ "SKIPPED",
+ "FAILED",
+ "XFAILED",
+ "ERROR"
+]
+
+DEFAULT_CORE_TESTS = [
+ "test_nn",
+ "test_torch",
+ "test_cuda",
+ "test_ops",
+ "test_unary_ufuncs",
+ "test_autograd",
+ "inductor/test_torchinductor"
+]
+
+DISTRIBUTED_CORE_TESTS = [
+ "distributed/test_c10d_common",
+ "distributed/test_c10d_nccl",
+ "distributed/test_distributed_spawn"
+]
+
+CONSOLIDATED_LOG_FILE_NAME="pytorch_unit_tests.log"
+
+def parse_xml_reports_as_dict(workflow_run_id, workflow_run_attempt, tag, workflow_name, path="."):
+ test_cases = {}
+ items_list = os.listdir(path)
+ for dir in items_list:
+ new_dir = path + '/' + dir + '/'
+ if os.path.isdir(new_dir):
+ for xml_report in Path(new_dir).glob("**/*.xml"):
+ test_cases.update(
+ parse_xml_report(
+ tag,
+ xml_report,
+ workflow_run_id,
+ workflow_run_attempt,
+ workflow_name
+ )
+ )
+ return test_cases
+
+def get_test_status(test_case):
+ # In order of priority: S=skipped, F=failure, E=error, P=pass
+ if "skipped" in test_case and test_case["skipped"]:
+ type_message = test_case["skipped"]
+ if 'type' in type_message and type_message['type'] == "pytest.xfail":
+ return "XFAILED"
+ else:
+ return "SKIPPED"
+ elif "failure" in test_case and test_case["failure"]:
+ return "FAILED"
+ elif "error" in test_case and test_case["error"]:
+ return "ERROR"
+ else:
+ return "PASSED"
+
+def get_test_message(test_case, status=None):
+ if status == "SKIPPED":
+ return test_case["skipped"] if "skipped" in test_case else ""
+ elif status == "FAILED":
+ return test_case["failure"] if "failure" in test_case else ""
+ elif status == "ERROR":
+ return test_case["error"] if "error" in test_case else ""
+ else:
+ if "skipped" in test_case:
+ return test_case["skipped"]
+ elif "failure" in test_case:
+ return test_case["failure"]
+ elif "error" in test_case:
+ return test_case["error"]
+ else:
+ return ""
+
+def get_test_file_running_time(test_suite):
+ if 'time' in test_suite:
+ return test_suite["time"]
+ return 0
+
+def get_test_running_time(test_case):
+ if 'time' in test_case:
+ return test_case["time"]
+ return ""
+
+def summarize_xml_files(path, workflow_name):
+ # statistics
+ TOTAL_TEST_NUM = 0
+ TOTAL_PASSED_NUM = 0
+ TOTAL_SKIPPED_NUM = 0
+ TOTAL_XFAIL_NUM = 0
+ TOTAL_FAILED_NUM = 0
+ TOTAL_ERROR_NUM = 0
+ TOTAL_EXECUTION_TIME = 0
+
+ #parse the xml files
+ test_cases = parse_xml_reports_as_dict(-1, -1, 'testcase', workflow_name, path)
+ test_suites = parse_xml_reports_as_dict(-1, -1, 'testsuite', workflow_name, path)
+ test_file_and_status = namedtuple("test_file_and_status", ["file_name", "status"])
+ # results dict
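+ # res maps test_file_and_status(file_name, status) -> dict:
+ #   status in PASSED/SKIPPED/XFAILED : {"file::class::test": running time}
+ #   status in FAILED/ERROR           : {"file::class::test": failure/error message}
+ #   status == "STATISTICS"           : per-file counters plus accumulated EXECUTION_TIME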
+ res = {}
+ res_item_list = [ "PASSED", "SKIPPED", "XFAILED", "FAILED", "ERROR" ]
+ test_file_items = set()
+ for (k,v) in list(test_suites.items()):
+ file_name = k[0]
+ if file_name not in test_file_items:
+ test_file_items.add(file_name)
+ # initialization
+ for item in res_item_list:
+ temp_item = test_file_and_status(file_name, item)
+ res[temp_item] = {}
+ temp_item_statistics = test_file_and_status(file_name, "STATISTICS")
+ res[temp_item_statistics] = {'TOTAL': 0, 'PASSED': 0, 'SKIPPED': 0, 'XFAILED': 0, 'FAILED': 0, 'ERROR': 0, 'EXECUTION_TIME': 0}
+ test_running_time = get_test_file_running_time(v)
+ res[temp_item_statistics]["EXECUTION_TIME"] += test_running_time
+ TOTAL_EXECUTION_TIME += test_running_time
+ else:
+ test_tuple_key_statistics = test_file_and_status(file_name, "STATISTICS")
+ test_running_time = get_test_file_running_time(v)
+ res[test_tuple_key_statistics]["EXECUTION_TIME"] += test_running_time
+ TOTAL_EXECUTION_TIME += test_running_time
+
+ for (k,v) in list(test_cases.items()):
+ file_name = k[0]
+ class_name = k[1]
+ test_name = k[2]
+ combined_name = file_name + "::" + class_name + "::" + test_name
+ test_status = get_test_status(v)
+ test_running_time = get_test_running_time(v)
+ test_message = get_test_message(v, test_status)
+ test_info_value = ""
+ test_tuple_key_status = test_file_and_status(file_name, test_status)
+ test_tuple_key_statistics = test_file_and_status(file_name, "STATISTICS")
+ TOTAL_TEST_NUM += 1
+ res[test_tuple_key_statistics]["TOTAL"] += 1
+ if test_status == "PASSED":
+ test_info_value = str(test_running_time)
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["PASSED"] += 1
+ TOTAL_PASSED_NUM += 1
+ elif test_status == "SKIPPED":
+ test_info_value = str(test_running_time)
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["SKIPPED"] += 1
+ TOTAL_SKIPPED_NUM += 1
+ elif test_status == "XFAILED":
+ test_info_value = str(test_running_time)
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["XFAILED"] += 1
+ TOTAL_XFAIL_NUM += 1
+ elif test_status == "FAILED":
+ test_info_value = test_message
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["FAILED"] += 1
+ TOTAL_FAILED_NUM += 1
+ elif test_status == "ERROR":
+ test_info_value = test_message
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["ERROR"] += 1
+ TOTAL_ERROR_NUM += 1
+
+ # generate statistics_dict
+ statistics_dict = {}
+ statistics_dict["TOTAL"] = TOTAL_TEST_NUM
+ statistics_dict["PASSED"] = TOTAL_PASSED_NUM
+ statistics_dict["SKIPPED"] = TOTAL_SKIPPED_NUM
+ statistics_dict["XFAILED"] = TOTAL_XFAIL_NUM
+ statistics_dict["FAILED"] = TOTAL_FAILED_NUM
+ statistics_dict["ERROR"] = TOTAL_ERROR_NUM
+ statistics_dict["EXECUTION_TIME"] = TOTAL_EXECUTION_TIME
+ aggregate_item = workflow_name + "_aggregate"
+ total_item = test_file_and_status(aggregate_item, "STATISTICS")
+ res[total_item] = statistics_dict
+
+ return res
+
+def run_command_and_capture_output(cmd):
+ try:
+ print(f"Running command '{cmd}'")
+ with open(CONSOLIDATED_LOG_FILE_PATH, "a+") as output_file:
+ print("========================================", file=output_file, flush=True)
+ print(f"[RUN_PYTORCH_UNIT_TESTS] Running command '{cmd}'", file=output_file, flush=True) # send to consolidated file as well
+ print("========================================", file=output_file, flush=True)
+ # check=True so that a non-zero exit code raises CalledProcessError and is reported below
+ subprocess.run(cmd, shell=True, stdout=output_file, stderr=STDOUT, text=True, check=True)
+ except CalledProcessError as e:
+ print(f"ERROR: Cmd {cmd} failed with return code: {e.returncode}!")
+
+def run_entire_tests(workflow_name, test_shell_path, overall_logs_path_current_run, test_reports_src):
+ if os.path.exists(test_reports_src):
+ shutil.rmtree(test_reports_src)
+
+ os.mkdir(test_reports_src)
+ copied_logs_path = ""
+ if workflow_name == "default":
+ os.environ['TEST_CONFIG'] = 'default'
+ copied_logs_path = overall_logs_path_current_run + "default_xml_results_entire_tests/"
+ elif workflow_name == "distributed":
+ os.environ['TEST_CONFIG'] = 'distributed'
+ copied_logs_path = overall_logs_path_current_run + "distributed_xml_results_entire_tests/"
+ elif workflow_name == "inductor":
+ os.environ['TEST_CONFIG'] = 'inductor'
+ copied_logs_path = overall_logs_path_current_run + "inductor_xml_results_entire_tests/"
+ # use test.sh for tests execution
+ run_command_and_capture_output(test_shell_path)
+ copied_logs_path_destination = shutil.copytree(test_reports_src, copied_logs_path)
+ entire_results_dict = summarize_xml_files(copied_logs_path_destination, workflow_name)
+ return entire_results_dict
+
+def run_priority_tests(workflow_name, test_run_test_path, overall_logs_path_current_run, test_reports_src):
+ if os.path.exists(test_reports_src):
+ shutil.rmtree(test_reports_src)
+
+ os.mkdir(test_reports_src)
+ copied_logs_path = ""
+ if workflow_name == "default":
+ os.environ['TEST_CONFIG'] = 'default'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0'
+ copied_logs_path = overall_logs_path_current_run + "default_xml_results_priority_tests/"
+ # use run_test.py for tests execution
+ default_priority_test_suites = " ".join(DEFAULT_CORE_TESTS)
+ command = "python3 " + test_run_test_path + " --include " + default_priority_test_suites + " --exclude-jit-executor --exclude-distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ elif workflow_name == "distributed":
+ os.environ['TEST_CONFIG'] = 'distributed'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0,1'
+ copied_logs_path = overall_logs_path_current_run + "distributed_xml_results_priority_tests/"
+ # use run_test.py for tests execution
+ distributed_priority_test_suites = " ".join(DISTRIBUTED_CORE_TESTS)
+ command = "python3 " + test_run_test_path + " --include " + distributed_priority_test_suites + " --distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ copied_logs_path_destination = shutil.copytree(test_reports_src, copied_logs_path)
+ priority_results_dict = summarize_xml_files(copied_logs_path_destination, workflow_name)
+
+ return priority_results_dict
+
+def run_selected_tests(workflow_name, test_run_test_path, overall_logs_path_current_run, test_reports_src, selected_list):
+ if os.path.exists(test_reports_src):
+ shutil.rmtree(test_reports_src)
+
+ os.mkdir(test_reports_src)
+ copied_logs_path = ""
+ if workflow_name == "default":
+ os.environ['TEST_CONFIG'] = 'default'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0'
+ copied_logs_path = overall_logs_path_current_run + "default_xml_results_selected_tests/"
+ # use run_test.py for tests execution
+ default_selected_test_suites = " ".join(selected_list)
+ command = "python3 " + test_run_test_path + " --include " + default_selected_test_suites + " --exclude-jit-executor --exclude-distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ elif workflow_name == "distributed":
+ os.environ['TEST_CONFIG'] = 'distributed'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0,1'
+ copied_logs_path = overall_logs_path_current_run + "distributed_xml_results_selected_tests/"
+ # use run_test.py for tests execution
+ distributed_selected_test_suites = " ".join(selected_list)
+ command = "python3 " + test_run_test_path + " --include " + distributed_selected_test_suites + " --distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ elif workflow_name == "inductor":
+ os.environ['TEST_CONFIG'] = 'inductor'
+ copied_logs_path = overall_logs_path_current_run + "inductor_xml_results_selected_tests/"
+ inductor_selected_test_suites = ""
+ non_inductor_selected_test_suites = ""
+ for item in selected_list:
+ if "inductor/" in item:
+ inductor_selected_test_suites += item
+ inductor_selected_test_suites += " "
+ else:
+ non_inductor_selected_test_suites += item
+ non_inductor_selected_test_suites += " "
+ if inductor_selected_test_suites != "":
+ inductor_selected_test_suites = inductor_selected_test_suites[:-1]
+ command = "python3 " + test_run_test_path + " --include " + inductor_selected_test_suites + " --verbose"
+ run_command_and_capture_output(command)
+ if non_inductor_selected_test_suites != "":
+ non_inductor_selected_test_suites = non_inductor_selected_test_suites[:-1]
+ command = "python3 " + test_run_test_path + " --inductor --include " + non_inductor_selected_test_suites + " --verbose"
+ run_command_and_capture_output(command)
+ copied_logs_path_destination = shutil.copytree(test_reports_src, copied_logs_path)
+ selected_results_dict = summarize_xml_files(copied_logs_path_destination, workflow_name)
+
+ return selected_results_dict
+
+def run_test_and_summarize_results(
+ pytorch_root_dir: str,
+ priority_tests: bool,
+ test_config: List[str],
+ default_list: List[str],
+ distributed_list: List[str],
+ inductor_list: List[str],
+ skip_rerun: bool) -> Dict[str, Any]:
+
+ # copy current environment variables
+ _environ = dict(os.environ)
+
+ # modify path
+ test_shell_path = pytorch_root_dir + "/.ci/pytorch/test.sh"
+ test_run_test_path = pytorch_root_dir + "/test/run_test.py"
+ repo_test_log_folder_path = pytorch_root_dir + "/.automation_logs/"
+ test_reports_src = pytorch_root_dir + "/test/test-reports/"
+ run_test_python_file = pytorch_root_dir + "/test/run_test.py"
+
+ # change directory to pytorch root
+ os.chdir(pytorch_root_dir)
+
+ # all test results dict
+ res_all_tests_dict = {}
+
+ # patterns
+ search_text = "--reruns=2"
+ replace_text = "--reruns=0"
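+ # With --skip_rerun, run_test.py is patched in place below (--reruns=2 -> --reruns=0)
+ # to disable flaky-test reruns, and restored after the run completes.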
+
+ # create logs folder
+ if not os.path.exists(repo_test_log_folder_path):
+ os.mkdir(repo_test_log_folder_path)
+
+ # Set common environment variables for all scenarios
+ os.environ['CI'] = '1'
+ os.environ['PYTORCH_TEST_WITH_ROCM'] = '1'
+ os.environ['HSA_FORCE_FINE_GRAIN_PCIE'] = '1'
+ os.environ['PYTORCH_TESTING_DEVICE_ONLY_FOR'] = 'cuda'
+ os.environ['CONTINUE_THROUGH_ERROR'] = 'True'
+ if skip_rerun:
+ # modify run_test.py in-place
+ with open(run_test_python_file, 'r') as file:
+ data = file.read()
+ data = data.replace(search_text, replace_text)
+ with open(run_test_python_file, 'w') as file:
+ file.write(data)
+
+ # Time stamp
+ current_datetime = datetime.now().strftime("%Y%m%d_%H-%M-%S")
+ print("Current date & time : ", current_datetime)
+ # performed as Job ID
+ str_current_datetime = str(current_datetime)
+ overall_logs_path_current_run = repo_test_log_folder_path + str_current_datetime + "/"
+ os.mkdir(overall_logs_path_current_run)
+
+ global CONSOLIDATED_LOG_FILE_PATH
+ CONSOLIDATED_LOG_FILE_PATH = overall_logs_path_current_run + CONSOLIDATED_LOG_FILE_NAME
+
+ # Check multi gpu availability if distributed tests are enabled
+ if ("distributed" in test_config) or len(distributed_list) != 0:
+ check_num_gpus_for_distributed()
+
+ # Install test requirements
+ command = "pip3 install -r requirements.txt && pip3 install -r .ci/docker/requirements-ci.txt"
+ run_command_and_capture_output(command)
+
+ # Run entire tests for each workflow
+ if not priority_tests and not default_list and not distributed_list and not inductor_list:
+ # run entire tests for default, distributed and inductor workflows → use test.sh
+ if not test_config:
+ check_num_gpus_for_distributed()
+ # default test process
+ res_default_all = run_entire_tests("default", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_all
+ # distributed test process
+ res_distributed_all = run_entire_tests("distributed", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_all
+ # inductor test process
+ res_inductor_all = run_entire_tests("inductor", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["inductor"] = res_inductor_all
+ else:
+ workflow_list = []
+ for item in test_config:
+ workflow_list.append(item)
+ if "default" in workflow_list:
+ res_default_all = run_entire_tests("default", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_all
+ if "distributed" in workflow_list:
+ res_distributed_all = run_entire_tests("distributed", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_all
+ if "inductor" in workflow_list:
+ res_inductor_all = run_entire_tests("inductor", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["inductor"] = res_inductor_all
+ # Run priority test for each workflow
+ elif priority_tests and not default_list and not distributed_list and not inductor_list:
+ if not test_config:
+ check_num_gpus_for_distributed()
+ # default test process
+ res_default_priority = run_priority_tests("default", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_priority
+ # distributed test process
+ res_distributed_priority = run_priority_tests("distributed", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_priority
+ # will not run inductor priority tests
+ print("Inductor priority tests cannot run since no core tests are defined for the inductor workflow.")
+ else:
+ workflow_list = []
+ for item in test_config:
+ workflow_list.append(item)
+ if "default" in workflow_list:
+ res_default_priority = run_priority_tests("default", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_priority
+ if "distributed" in workflow_list:
+ res_distributed_priority = run_priority_tests("distributed", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_priority
+ if "inductor" in workflow_list:
+ print("Inductor priority tests cannot run since no core tests are defined for the inductor workflow.")
+ # Run specified tests for each workflow
+ elif (default_list or distributed_list or inductor_list) and not test_config and not priority_tests:
+ if default_list:
+ default_workflow_list = []
+ for item in default_list:
+ default_workflow_list.append(item)
+ res_default_selected = run_selected_tests("default", test_run_test_path, overall_logs_path_current_run, test_reports_src, default_workflow_list)
+ res_all_tests_dict["default"] = res_default_selected
+ if distributed_list:
+ distributed_workflow_list = []
+ for item in distributed_list:
+ distributed_workflow_list.append(item)
+ res_distributed_selected = run_selected_tests("distributed", test_run_test_path, overall_logs_path_current_run, test_reports_src, distributed_workflow_list)
+ res_all_tests_dict["distributed"] = res_distributed_selected
+ if inductor_list:
+ inductor_workflow_list = []
+ for item in inductor_list:
+ inductor_workflow_list.append(item)
+ res_inductor_selected = run_selected_tests("inductor", test_run_test_path, overall_logs_path_current_run, test_reports_src, inductor_workflow_list)
+ res_all_tests_dict["inductor"] = res_inductor_selected
+ else:
+ raise Exception("Invalid test configurations!")
+
+ # restore environment variables
+ os.environ.clear()
+ os.environ.update(_environ)
+
+ # restore files
+ if skip_rerun:
+ # modify run_test.py in-place
+ with open(run_test_python_file, 'r') as file:
+ data = file.read()
+ data = data.replace(replace_text, search_text)
+ with open(run_test_python_file, 'w') as file:
+ file.write(data)
+
+ return res_all_tests_dict
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Run PyTorch unit tests and generate xml results summary', formatter_class=argparse.RawTextHelpFormatter)
+ parser.add_argument('--test_config', nargs='+', default=[], type=str, help="space-separated list of test workflows to be executed eg. 'default distributed'")
+ parser.add_argument('--priority_tests', action='store_true', help="run priority tests only")
+ parser.add_argument('--default_list', nargs='+', default=[], help="space-separated list of 'default' config test suites/files to be executed eg. 'test_weak test_dlpack'")
+ parser.add_argument('--distributed_list', nargs='+', default=[], help="space-separated list of 'distributed' config test suites/files to be executed eg. 'distributed/test_c10d_common distributed/test_c10d_nccl'")
+ parser.add_argument('--inductor_list', nargs='+', default=[], help="space-separated list of 'inductor' config test suites/files to be executed eg. 'inductor/test_torchinductor test_ops'")
+ parser.add_argument('--pytorch_root', default='.', type=str, help="PyTorch root directory")
+ parser.add_argument('--skip_rerun', action='store_true', help="skip rerun process")
+ parser.add_argument('--example_output', type=str, help="{'workflow_name': {\n"
+ " test_file_and_status(file_name='workflow_aggregate', status='STATISTICS'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='ERROR'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='FAILED'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='PASSED'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='SKIPPED'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='STATISTICS'): {} \n"
+ "}}\n")
+ parser.add_argument('--example_usages', type=str, help="RUN ALL TESTS: python3 run_pytorch_unit_tests.py \n"
+ "RUN PRIORITY TESTS: python3 run_pytorch_unit_tests.py --test_config distributed --priority_tests \n"
+ "RUN SELECTED TESTS: python3 run_pytorch_unit_tests.py --default_list test_weak test_dlpack --inductor_list inductor/test_torchinductor")
+ return parser.parse_args()
+
+def check_num_gpus_for_distributed():
+ p = subprocess.run(r"rocminfo | grep -cE 'Name:\s+gfx'", shell=True, capture_output=True, text=True)
+ num_gpus_visible = int(p.stdout)
+ assert num_gpus_visible > 1, "Number of visible GPUs should be >1 to run distributed unit tests"
+
+def main():
+ args = parse_args()
+ all_tests_results = run_test_and_summarize_results(args.pytorch_root, args.priority_tests, args.test_config, args.default_list, args.distributed_list, args.inductor_list, args.skip_rerun)
+ pprint(dict(all_tests_results))
+
+if __name__ == "__main__":
+ main()
diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt
index 71d3ef714fbe..99ec5b4aa341 100644
--- a/.ci/docker/ci_commit_pins/triton.txt
+++ b/.ci/docker/ci_commit_pins/triton.txt
@@ -1 +1 @@
-bbb06c0334a6772b92d24bde54956e675c8c6604
+d08c31a24d622b4bf767a6645135b7b3d0d886f4
diff --git a/.ci/docker/common/install_triton.sh b/.ci/docker/common/install_triton.sh
index f48140952c3a..8e714bcb6cd3 100755
--- a/.ci/docker/common/install_triton.sh
+++ b/.ci/docker/common/install_triton.sh
@@ -21,7 +21,7 @@ elif [ -n "${TRITON_CPU}" ]; then
TRITON_REPO="https://github.com/triton-lang/triton-cpu"
TRITON_TEXT_FILE="triton-cpu"
else
- TRITON_REPO="https://github.com/triton-lang/triton"
+ TRITON_REPO="https://github.com/ROCm/triton"
TRITON_TEXT_FILE="triton"
fi
diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt
index 583136d7df2f..248ee8409036 100644
--- a/.ci/docker/requirements-ci.txt
+++ b/.ci/docker/requirements-ci.txt
@@ -113,9 +113,8 @@ ninja==1.11.1.3
#test that import: run_test.py, test_cpp_extensions_aot.py,test_determination.py
numba==0.49.0 ; python_version < "3.9" and platform_machine != "s390x"
-numba==0.55.2 ; python_version == "3.9" and platform_machine != "s390x"
-numba==0.55.2 ; python_version == "3.10" and platform_machine != "s390x"
-numba==0.60.0 ; python_version == "3.12" and platform_machine != "s390x"
+numba==0.60.0 ; python_version == "3.9" and platform_machine != "s390x"
+numba==0.61.2 ; python_version > "3.9" and platform_machine != "s390x"
#Description: Just-In-Time Compiler for Numerical Functions
#Pinned versions: 0.54.1, 0.49.0, <=0.49.1
#test that import: test_numba_integration.py
@@ -134,12 +133,10 @@ numba==0.60.0 ; python_version == "3.12" and platform_machine != "s390x"
#test_nn.py, test_namedtensor.py, test_linalg.py, test_jit_cuda_fuser.py,
#test_jit.py, test_indexing.py, test_datapipe.py, test_dataloader.py,
#test_binary_ufuncs.py
-numpy==1.22.4; python_version == "3.9" or python_version == "3.10"
-numpy==1.26.2; python_version == "3.11" or python_version == "3.12"
-numpy==2.1.2; python_version >= "3.13"
+numpy==2.0.2 ; python_version == "3.9"
+numpy==2.1.2 ; python_version > "3.9"
-pandas==2.0.3; python_version < "3.13"
-pandas==2.2.3; python_version >= "3.13"
+pandas==2.2.3
#onnxruntime
#Description: scoring engine for Open Neural Network Exchange (ONNX) models
@@ -169,10 +166,11 @@ pillow==11.0.0
#Pinned versions: 10.3.0
#test that import:
-protobuf==5.29.4
-#Description: Google's data interchange format
-#Pinned versions: 5.29.4
-#test that import: test_tensorboard.py, test/onnx/*
+protobuf==3.20.2 ; python_version <= "3.12"
+protobuf==4.25.1 ; python_version == "3.13"
+#Description: Google's data interchange format
+#Pinned versions: 3.20.1
+#test that import: test_tensorboard.py
psutil
#Description: information on running processes and system utilization
@@ -250,8 +248,8 @@ scikit-image==0.22.0 ; python_version >= "3.10"
#Pinned versions: 0.20.3
#test that import:
-scipy==1.10.1 ; python_version <= "3.11"
-scipy==1.14.1 ; python_version >= "3.12"
+scipy==1.13.1 ; python_version == "3.9"
+scipy==1.14.1 ; python_version > "3.9"
# Pin SciPy because of failing distribution tests (see #60347)
#Description: scientific python
#Pinned versions: 1.10.1
@@ -310,8 +308,7 @@ z3-solver==4.15.1.0 ; platform_machine != "s390x"
#Pinned versions:
#test that import:
-tensorboard==2.13.0 ; python_version < "3.13"
-tensorboard==2.18.0 ; python_version >= "3.13"
+tensorboard==2.18.0
#Description: Also included in .ci/docker/requirements-docs.txt
#Pinned versions:
#test that import: test_tensorboard
@@ -323,7 +320,8 @@ pywavelets==1.7.0 ; python_version >= "3.12"
#Pinned versions: 1.4.1
#test that import:
-lxml==5.3.0
+lxml==5.3.0 ; python_version <= "3.12"
+lxml==6.0.0 ; python_version == "3.13"
#Description: This is a requirement of unittest-xml-reporting
# Python-3.9 binaries
@@ -335,8 +333,9 @@ sympy==1.13.3
#Pinned versions:
#test that import:
-onnx==1.18.0
-#Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal
+onnx==1.16.1 ; python_version <= "3.12"
+onnx==1.18.0 ; python_version == "3.13"
+#Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal
#Pinned versions:
#test that import:
diff --git a/.circleci/scripts/binary_populate_env.sh b/.circleci/scripts/binary_populate_env.sh
index f5b949858d60..4cec3e2f6d72 100755
--- a/.circleci/scripts/binary_populate_env.sh
+++ b/.circleci/scripts/binary_populate_env.sh
@@ -5,7 +5,9 @@ export TZ=UTC
tagged_version() {
GIT_DIR="${workdir}/pytorch/.git"
GIT_DESCRIBE="git --git-dir ${GIT_DIR} describe --tags --match v[0-9]*.[0-9]*.[0-9]*"
- if [[ ! -d "${GIT_DIR}" ]]; then
+ if [[ -n "${CIRCLE_TAG:-}" ]]; then
+ echo "${CIRCLE_TAG}"
+ elif [[ ! -d "${GIT_DIR}" ]]; then
echo "Abort, abort! Git dir ${GIT_DIR} does not exists!"
kill $$
elif ${GIT_DESCRIBE} --exact >/dev/null; then
@@ -69,6 +71,8 @@ fi
export PYTORCH_BUILD_NUMBER=1
+# This part is done in the builder scripts, so comment out the duplicate code here
+: <<'BLOCK_COMMENT'
# Set triton version as part of PYTORCH_EXTRA_INSTALL_REQUIREMENTS
TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt)
@@ -117,6 +121,7 @@ if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_B
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | ${TRITON_REQUIREMENT}"
fi
fi
+BLOCK_COMMENT
USE_GLOO_WITH_OPENSSL="ON"
if [[ "$GPU_ARCH_TYPE" =~ .*aarch64.* ]]; then
diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py
index 11fa8404273d..f2851e331725 100644
--- a/.github/scripts/build_triton_wheel.py
+++ b/.github/scripts/build_triton_wheel.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
import os
+import re
import shutil
import sys
from pathlib import Path
@@ -50,6 +51,30 @@ def patch_init_py(
with open(path, "w") as f:
f.write(orig)
+def get_rocm_version() -> str:
+ rocm_path = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') or "/opt/rocm"
+ rocm_version = "0.0.0"
+ rocm_version_h = f"{rocm_path}/include/rocm-core/rocm_version.h"
+ if not os.path.isfile(rocm_version_h):
+ rocm_version_h = f"{rocm_path}/include/rocm_version.h"
+ # The file could be missing due to 1) ROCm version < 5.2, or 2) no ROCm install.
+ if os.path.isfile(rocm_version_h):
+ RE_MAJOR = re.compile(r"#define\s+ROCM_VERSION_MAJOR\s+(\d+)")
+ RE_MINOR = re.compile(r"#define\s+ROCM_VERSION_MINOR\s+(\d+)")
+ RE_PATCH = re.compile(r"#define\s+ROCM_VERSION_PATCH\s+(\d+)")
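+ # Scan the header line by line; the result is "MAJOR.MINOR.PATCH", e.g. "6.4.1",
+ # falling back to "0.0.0" when no version header is found.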
+ major, minor, patch = 0, 0, 0
+ for line in open(rocm_version_h):
+ match = RE_MAJOR.search(line)
+ if match:
+ major = int(match.group(1))
+ match = RE_MINOR.search(line)
+ if match:
+ minor = int(match.group(1))
+ match = RE_PATCH.search(line)
+ if match:
+ patch = int(match.group(1))
+ rocm_version = f"{major}.{minor}.{patch}"
+ return rocm_version
def build_triton(
*,
@@ -64,14 +89,24 @@ def build_triton(
if "MAX_JOBS" not in env:
max_jobs = os.cpu_count() or 1
env["MAX_JOBS"] = str(max_jobs)
-
+ if not release:
+ # Nightly binaries include the triton commit hash, i.e. 2.1.0+e6216047b8
+ # while release build should only include the version, i.e. 2.1.0
+ rocm_version = get_rocm_version()
+ version_suffix = f"+rocm{rocm_version}.git{commit_hash[:8]}"
+ version += version_suffix
with TemporaryDirectory() as tmpdir:
triton_basedir = Path(tmpdir) / "triton"
triton_pythondir = triton_basedir / "python"
triton_repo = "https://github.com/openai/triton"
if device == "rocm":
- triton_pkg_name = "pytorch-triton-rocm"
+ triton_repo = "https://github.com/ROCm/triton"
+ rocm_version = get_rocm_version() # e.g., "7.0.1"
+ if tuple(map(int, rocm_version.split("."))) > (7, 0, 0):
+ triton_pkg_name = "triton"
+ else:
+ triton_pkg_name = "pytorch-triton-rocm"
elif device == "xpu":
triton_pkg_name = "pytorch-triton-xpu"
triton_repo = "https://github.com/intel/intel-xpu-backend-for-triton"
@@ -89,6 +124,7 @@ def build_triton(
# change built wheel name and version
env["TRITON_WHEEL_NAME"] = triton_pkg_name
+ env["TRITON_WHEEL_VERSION_SUFFIX"] = version_suffix
if with_clang_ldd:
env["TRITON_BUILD_WITH_CLANG_LLD"] = "1"
diff --git a/aten/src/ATen/cuda/detail/OffsetCalculator.cuh b/aten/src/ATen/cuda/detail/OffsetCalculator.cuh
index 60e1a19c1aac..a65db3f2df12 100644
--- a/aten/src/ATen/cuda/detail/OffsetCalculator.cuh
+++ b/aten/src/ATen/cuda/detail/OffsetCalculator.cuh
@@ -45,6 +45,24 @@ struct OffsetCalculator {
C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
offset_type offsets;
+
+#if defined(USE_ROCM)
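+ // ROCm-only fast path: for 1-D and 2-D iterators, compute the offsets with an
+ // explicitly unrolled divmod chain and return early, skipping the generic loop below.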
+ if ((dims > 0) && (dims <= 2)) {
+ auto divmod = sizes_[0].divmod(linear_idx);
+ #pragma unroll
+ for (int arg = 0; arg < NARGS; arg++)
+ offsets[arg] = divmod.mod * strides_[0][arg];
+ if (dims >= 2) {
+ divmod = sizes_[1].divmod(divmod.div);
+ #pragma unroll
+ for (int arg = 0; arg < NARGS; arg++)
+ offsets[arg] += divmod.mod * strides_[1][arg];
+ }
+ // [...]
+ return offsets;
+ }
+#endif
+
#pragma unroll
for (int arg = 0; arg < NARGS; arg++) {
offsets[arg] = 0;
diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh
index 12ad84a15b18..ee28c5c1693f 100644
--- a/aten/src/ATen/native/cuda/CUDALoops.cuh
+++ b/aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -999,12 +999,41 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
dtypes[i] = iter.dtype(i);
}
auto offset_calc = ::make_offset_calculator(iter);
+#ifdef USE_ROCM
+ constexpr int grp_sz = 128;
+ launch_legacy_kernel_manual_unroll(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
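+ // When `unrl` is set, the lambda handles four elements per call, spaced
+ // grp_sz (=128) work-items apart, amortizing offset computation and stores.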
+ if (unrl) {
+ auto offsets0 = offset_calc.get(idx);
+ auto offsets1 = offset_calc.get(idx + grp_sz);
+ auto offsets2 = offset_calc.get(idx + grp_sz * 2);
+ auto offsets3 = offset_calc.get(idx + grp_sz * 3);
+ void* out0 = data[0] + offsets0[0];
+ void* out1 = data[0] + offsets1[0];
+ void* out2 = data[0] + offsets2[0];
+ void* out3 = data[0] + offsets3[0];
+ arg0_t result0 = invoke(f, &data[1], &offsets0[1], &dtypes[1], 1);
+ arg0_t result1 = invoke(f, &data[1], &offsets1[1], &dtypes[1], 1);
+ arg0_t result2 = invoke(f, &data[1], &offsets2[1], &dtypes[1], 1);
+ arg0_t result3 = invoke(f, &data[1], &offsets3[1], &dtypes[1], 1);
+ c10::cast_and_store(dtypes[0], out0, result0);
+ c10::cast_and_store(dtypes[0], out1, result1);
+ c10::cast_and_store(dtypes[0], out2, result2);
+ c10::cast_and_store(dtypes[0], out3, result3);
+ } else {
+ auto offsets = offset_calc.get(idx);
+ void* out = data[0] + offsets[0];
+ arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1);
+ c10::cast_and_store(dtypes[0], out, result);
+ }
+ });
+#else
launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) {
auto offsets = offset_calc.get(idx);
void* out = data[0] + offsets[0];
arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1);
c10::cast_and_store(dtypes[0], out, result);
});
+#endif
}
}
diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu
index 59b0426bab1f..62a07e1e28c8 100644
--- a/aten/src/ATen/native/cuda/Copy.cu
+++ b/aten/src/ATen/native/cuda/Copy.cu
@@ -42,6 +42,19 @@ void bfloat16_copy_kernel_cuda(TensorIteratorBase &iter) {
});
}
+#ifdef USE_ROCM
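+// ROCm-only specializations: copy BFloat16/Half sources to float with a direct
+// static_cast in the kernel, bypassing the generic dynamic-cast copy path; they
+// are dispatched from direct_copy_kernel_cuda below.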
+void bfloat16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
+ gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::BFloat16 value) {
+ return static_cast<float>(value);
+ });
+}
+void float16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
+ gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::Half value) {
+ return static_cast<float>(value);
+ });
+}
+#endif
+
void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
ScalarType dtype = iter.dtype(0);
ScalarType other_dtype = iter.dtype(1);
@@ -187,7 +200,17 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
} else {
float16_copy_kernel_cuda(iter);
}
- } else if (isBitsType(dtype)) {
+ }
+#ifdef USE_ROCM
+ else if ((iter.dtype(1) == kBFloat16 || iter.dtype(1) == kHalf) && dtype == kFloat) {
+ if (iter.dtype(1) == kBFloat16) {
+ bfloat16tofloat32_copy_kernel_cuda(iter);
+ } else {
+ float16tofloat32_copy_kernel_cuda(iter);
+ }
+ }
+#endif
+ else if (isBitsType(dtype)) {
TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting "
"bits types to different bits types. Source dtype is ", iter.dtype(1), "target dtype is ", dtype);
AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {
diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu
index 02feb55cb69d..e49fffc2effc 100644
--- a/aten/src/ATen/native/cuda/Indexing.cu
+++ b/aten/src/ATen/native/cuda/Indexing.cu
@@ -56,8 +56,7 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
#endif
}
-#ifdef USE_ROCM
-#define SKIP_SORTED_INDICES 32
+#if 0
template
__global__ void indexing_backward_kernel(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -142,7 +141,10 @@ __global__ void indexing_backward_kernel(
}
}
}
+#endif
+#ifdef USE_ROCM
+#define SKIP_SORTED_INDICES 32
template
__global__ void indexing_backward_kernel_stride_1(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -254,7 +256,8 @@ __global__ void indexing_backward_kernel_stride_1(
}
}
}
-#else
+#endif
+
template
__global__ void indexing_backward_kernel(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -333,6 +336,7 @@ __global__ void indexing_backward_kernel(
}
}
+#ifndef USE_ROCM
template
__global__ void indexing_backward_kernel_stride_1(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -784,7 +788,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<<>>(
+ indexing_backward_kernel<<>>(
sorted_indices.const_data_ptr(),
orig_indices.const_data_ptr(),
expandedValue.const_data_ptr(),
diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu
index 940680eb3682..81387bcceaf0 100644
--- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu
+++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu
@@ -141,7 +141,11 @@ WelfordDataLN cuWelfordOnlineSum(
if constexpr (!rms_norm){
U delta = val - curr_sum.mean;
U new_count = curr_sum.count + 1.f;
+#if defined(USE_ROCM) && defined(PYTORCH_LAYERNORM_FAST_RECIPROCAL)
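+ // __builtin_amdgcn_rcpf emits the hardware reciprocal approximation: faster than
+ // a full division at slightly reduced accuracy (opt-in via PYTORCH_LAYERNORM_FAST_RECIPROCAL).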
+ U new_mean = curr_sum.mean + delta * __builtin_amdgcn_rcpf(new_count);
+#else
U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster
+#endif
return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count};
} else{
return {0.f, curr_sum.sigma2 + val * val, 0};
@@ -159,7 +163,11 @@ WelfordDataLN cuWelfordCombine(
U count = dataA.count + dataB.count;
U mean, sigma2;
if (count > decltype(dataB.count){0}) {
+#if defined(USE_ROCM) && defined(PYTORCH_LAYERNORM_FAST_RECIPROCAL)
+ auto coef = __builtin_amdgcn_rcpf(count);
+#else
auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division
+#endif
auto nA = dataA.count * coef;
auto nB = dataB.count * coef;
mean = nA*dataA.mean + nB*dataB.mean;
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index ef5c2fd4e97d..daceebd8bc88 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1037,6 +1037,22 @@ if(USE_ROCM)
list(APPEND HIP_HIPCC_FLAGS -fdebug-info-for-profiling)
endif(CMAKE_BUILD_TYPE MATCHES Debug)
+ # Read the PYTORCH_LAYERNORM_FAST_RECIPROCAL environment variable (default: ON).
+ if(DEFINED ENV{PYTORCH_LAYERNORM_FAST_RECIPROCAL})
+ set(PYTORCH_LAYERNORM_FAST_RECIPROCAL_CMAKE $ENV{PYTORCH_LAYERNORM_FAST_RECIPROCAL})
+ else()
+ set(PYTORCH_LAYERNORM_FAST_RECIPROCAL_CMAKE ON)
+ endif()
+
+ set(PYTORCH_LAYERNORM_FAST_RECIPROCAL
+ ${PYTORCH_LAYERNORM_FAST_RECIPROCAL_CMAKE}
+ CACHE BOOL "Enable fast reciprocals within layer normalization." FORCE
+ )
+
+ if(PYTORCH_LAYERNORM_FAST_RECIPROCAL)
+ add_definitions(-DPYTORCH_LAYERNORM_FAST_RECIPROCAL)
+ endif()
+
# needed for compat with newer versions of hip-clang that introduced C++20 mangling rules
list(APPEND HIP_HIPCC_FLAGS -fclang-abi-compat=17)
diff --git a/related_commits b/related_commits
new file mode 100644
index 000000000000..b96cf18c181a
--- /dev/null
+++ b/related_commits
@@ -0,0 +1,10 @@
+ubuntu|pytorch|apex|release/1.9.0|07c3ee5347294b7a07a65c2c3596f1b14c7d3daa|https://github.com/ROCm/apex
+centos|pytorch|apex|release/1.9.0|07c3ee5347294b7a07a65c2c3596f1b14c7d3daa|https://github.com/ROCm/apex
+ubuntu|pytorch|torchvision|release/0.24|b919bd0c56abbb3c5ca056a3a458af9fd1cabf52|https://github.com/pytorch/vision
+centos|pytorch|torchvision|release/0.24|b919bd0c56abbb3c5ca056a3a458af9fd1cabf52|https://github.com/pytorch/vision
+ubuntu|pytorch|torchdata|release/0.11|377e64c1be69a9be6649d14c9e3664070323e464|https://github.com/pytorch/data
+centos|pytorch|torchdata|release/0.11|377e64c1be69a9be6649d14c9e3664070323e464|https://github.com/pytorch/data
+ubuntu|pytorch|torchaudio|release/2.9|e3c6ee2b6588b7cd27a84182de74bf12fe043831|https://github.com/pytorch/audio
+centos|pytorch|torchaudio|release/2.9|e3c6ee2b6588b7cd27a84182de74bf12fe043831|https://github.com/pytorch/audio
+ubuntu|pytorch|ao|main|a52a64aeb84fa6ff683ec2c7c42b97e27651a619|https://github.com/pytorch/ao
+centos|pytorch|ao|main|a52a64aeb84fa6ff683ec2c7c42b97e27651a619|https://github.com/pytorch/ao
diff --git a/requirements.txt b/requirements.txt
index fc4b53dfd49e..f6dc86a0aa46 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,6 +12,9 @@ hypothesis
jinja2
lintrunner ; platform_machine != "s390x" and platform_machine != "riscv64"
networkx>=2.5.1
+ninja
+numpy==2.0.2 ; python_version == "3.9"
+numpy==2.1.2 ; python_version > "3.9"
optree>=0.13.0
psutil
sympy>=1.13.3
diff --git a/setup.py b/setup.py
index 11ca48482a76..ae0097465da6 100644
--- a/setup.py
+++ b/setup.py
@@ -162,6 +162,10 @@
# USE_ROCM_CK_SDPA=1
# Enable building CK SDPA backend in ROCm platform
#
+# PYTORCH_LAYERNORM_FAST_RECIPROCAL
+#   If set, enables builtin fast-reciprocal (1/x) instructions in layer
+#   normalization kernels. Default: enabled.
+#
# Environment variables we respect (these environment variables are
# conventional and are often understood/set by other software.)
#
diff --git a/test/inductor/test_combo_kernels.py b/test/inductor/test_combo_kernels.py
index 90399546d26e..6523cddcec6d 100644
--- a/test/inductor/test_combo_kernels.py
+++ b/test/inductor/test_combo_kernels.py
@@ -296,23 +296,6 @@ def fn(a0, a1, a2, b0, b1, b2):
self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)
- @requires_cuda_and_triton
- def test_persistent_reduction_no_x_dim(self):
- def fn(x, y):
- return x.sum(1), y.sum(1)
-
- inps = (
- torch.rand(16, 256, device="cuda"),
- torch.rand(32, 256, device="cuda"),
- )
- torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
- torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
- out_eager = fn(*inps)
- out_compiled = torch.compile(fn)(*inps)
-
- self.assertEqual(out_eager, out_compiled)
- self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)
-
@instantiate_parametrized_tests
class ComboKernelDynamicShapesTests(TestCase):
diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py
index 41db6b18daba..6bde7a8c540a 100644
--- a/test/inductor/test_torchinductor_strided_blocks.py
+++ b/test/inductor/test_torchinductor_strided_blocks.py
@@ -816,6 +816,7 @@ def test_2d_reduction_odd_shapes(
# Check the code for multiple Rn_BLOCK's
self._assert_reduction_ndims(code, 2)
+
@parametrize(
"size,expected_num_block_pointers,expected_num_triton_kernels,expect_fallback",
[
diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py
index 1c31d5445f91..569d1bac8595 100644
--- a/test/test_binary_ufuncs.py
+++ b/test/test_binary_ufuncs.py
@@ -1480,8 +1480,8 @@ def to_np(value):
self.assertRaisesRegex(RuntimeError, regex, base.pow_, exponent)
elif torch.can_cast(torch.result_type(base, exponent), base.dtype):
actual2 = actual.pow_(exponent)
- self.assertEqual(actual, expected)
- self.assertEqual(actual2, expected)
+ self.assertEqual(actual, expected.to(actual))
+ self.assertEqual(actual2, expected.to(actual))
else:
self.assertRaisesRegex(
RuntimeError,
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 293bb2b7e701..7985a2cd9fe8 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -504,6 +504,9 @@ def test_out_of_memory_retry(self):
IS_JETSON, "oom reporting has issues on jetson igx due to partial nvml support"
)
def test_set_per_process_memory_fraction(self):
+ if torch.version.hip and ('gfx1101' in torch.cuda.get_device_properties(0).gcnArchName):
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
orig = torch.cuda.get_per_process_memory_fraction(0)
torch.cuda.reset_peak_memory_stats(0)
try:
diff --git a/third_party/fbgemm b/third_party/fbgemm
index 4b39c551efe1..3cefe0564a8c 160000
--- a/third_party/fbgemm
+++ b/third_party/fbgemm
@@ -1 +1 @@
-Subproject commit 4b39c551efe15e6bbade20565b0ceb2d8ce3352d
+Subproject commit 3cefe0564a8c3de514a152d40a2b4770f2ee5be0
diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py
index 417fac7b4f63..2189e44f9e24 100644
--- a/torch/_inductor/choices.py
+++ b/torch/_inductor/choices.py
@@ -232,6 +232,18 @@ def should_use_persistent_reduction(
features.reduction_numel, threshold
) # type: ignore[arg-types]
+ @staticmethod
+ def want_no_x_dim(features: SIMDKernelFeatures) -> bool:
+ """
+ Heuristic to decide if we should drop the X dimension from a persistent reduction kernel.
+ So the [XBLOCK, RBLOCK] block becomes a [RBLOCK] block and XBLOCK is forced to be always 1.
+ Strangely this is faster than a [1, RBLOCK] block in some cases.
+
+ ROCm branch change: Remove want_no_x_dim for persistent reduction.
+ Inductor benchmarks show no perf advantage, and removing it simplifies the autotune flow.
+ """
+ return False
+
@staticmethod
def reduction_split_factor(
device: torch.device,
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 175ea55ec3af..3848fc3355e4 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1306,7 +1306,7 @@ def tan(x):
@staticmethod
@maybe_upcast_float32()
def tanh(x):
- return f"libdevice.tanh({x})"
+ return f"libdevice.fast_tanhf({x})"
@staticmethod
@maybe_upcast_float32()
@@ -2030,12 +2030,11 @@ def should_use_persistent_reduction(self) -> bool:
)
def want_no_x_dim(self):
- return (
- self.persistent_reduction
- and len(self.numels) == self.num_reduction_dims + 1
- and self.fixed_config
- and self.fixed_config["XBLOCK"] == 1
- )
+ """
+ ROCm branch change: Remove want_no_x_dim for persistent reduction.
+ Inductor benchmarks show no perf advantage, and removing it simplifies the autotune flow.
+ """
+ return False
@property
def assert_function(self) -> str:
diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py
index dc2392119cc5..94a905e4211c 100644
--- a/torch/_inductor/codegen/triton_combo_kernel.py
+++ b/torch/_inductor/codegen/triton_combo_kernel.py
@@ -614,7 +614,7 @@ def jit_line(
if heuristics == "foreach":
heuristics_line = f"""
@triton_heuristics.foreach(
- num_warps={self.num_warps},
+ filename=__file__,
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
)
diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py
index ad7a0d56fc4b..26b3bcf5cc5c 100644
--- a/torch/_inductor/runtime/coordinate_descent_tuner.py
+++ b/torch/_inductor/runtime/coordinate_descent_tuner.py
@@ -3,6 +3,7 @@
import itertools
import logging
from typing import Callable, Optional, TYPE_CHECKING
+from functools import lru_cache
from .hints import TRITON_MAX_BLOCK
from .runtime_utils import red_text, triton_config_to_hashable
@@ -60,10 +61,16 @@ def get_config_max(self, prefix: str) -> int:
size_hint = self.size_hints.get(prefix) if self.size_hints is not None else None
return min(max_block, size_hint) if size_hint is not None else max_block
+ @lru_cache(maxsize=1)
def get_warpsmax(self):
- # Currently, CUDA has a maximum of 1024 threads, so 32 is the max
- # number of warps.
- return 1024 // 32
+ # CUDA/ROCm has a maximum of 1024 threads per block
+ from torch.cuda import current_device, get_device_properties, is_available
+
+ warp_size = (
+ get_device_properties(current_device()).warp_size if is_available() else 32
+ )
+
+ return 1024 // warp_size
def cache_benchmark_result(self, config, timing):
self.cached_benchmark_results[triton_config_to_hashable(config)] = timing
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 547fad522246..46832167622b 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -2870,6 +2870,10 @@ def _persistent_reduction_configs(
rnumel = get_total_reduction_numel(size_hints)
MAX_PERSISTENT_BLOCK_NUMEL = 4096
+ max_autotune_enabled = not disable_pointwise_autotuning(inductor_meta) or (
+ inductor_meta.get("max_autotune")
+ or inductor_meta.get("max_autotune_pointwise")
+ )
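+ # Tuning is enabled when pointwise autotuning is not disabled or when
+ # max-autotune(-pointwise) is requested; in that case the reduction-hint
+ # filtering below is skipped and tiny configs are added as extra candidates.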
if "y" not in size_hints:
configs = [
@@ -2899,18 +2903,27 @@ def _persistent_reduction_configs(
if "y" in size_hints:
pass
# TODO(jansel): we should be able to improve these heuristics
- elif reduction_hint == ReductionHint.INNER and rnumel >= 256:
- configs = configs[:1]
- elif reduction_hint == ReductionHint.OUTER:
- configs = configs[-1:]
+ if not max_autotune_enabled: # Don't filter if tuning enabled
+ if reduction_hint == ReductionHint.INNER and rnumel >= 256:
+ configs = configs[:1]
+ elif reduction_hint == ReductionHint.OUTER:
+ configs = configs[-1:]
+
+ tiny_configs = [
+ triton_config_reduction(
+ size_hints,
+ 2 * (256 // rnumel) if rnumel <= 256 else 1,
+ rnumel,
+ )
+ ]
+
+ if max_autotune_enabled:
+ for conf in tiny_configs:
+ if conf not in configs:
+ configs.append(conf)
elif reduction_hint == ReductionHint.OUTER_TINY:
- configs = [
- triton_config_reduction(
- size_hints,
- 2 * (256 // rnumel) if rnumel <= 256 else 1,
- rnumel,
- )
- ]
+ configs = tiny_configs
+
for c in configs:
# we don't need Rn_BLOCK for persistent reduction
for prefix in size_hints:
@@ -3102,20 +3115,29 @@ def user_autotune(
)
-def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
+def foreach(triton_meta, filename=None, inductor_meta=None):
"""
Compile a triton foreach kernel
"""
+ configs = []
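+ # Without pointwise autotuning, keep a single num_warps=8 config; otherwise
+ # sweep num_warps over 1, 2, 4 and 8 and let autotuning pick the best.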
+ if disable_pointwise_autotuning(inductor_meta) and not (
+ inductor_meta.get("max_autotune") or
+ inductor_meta.get("max_autotune_pointwise")
+ ):
+ configs.append(triton.Config({}, num_stages=1, num_warps=8))
+ else:
+ for warps in [1, 2, 4, 8]:
+ configs.append(triton.Config({}, num_stages=1, num_warps=warps))
+
return cached_autotune(
None,
- [triton.Config({}, num_stages=1, num_warps=num_warps)],
+ configs,
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,
filename=filename,
)
-
@dataclasses.dataclass
class GridExpr:
"""Generate code for grid size expressions in launcher"""
diff --git a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h
index 9728d27d4d79..0ac2c79d1e98 100644
--- a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h
+++ b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h
@@ -260,7 +260,7 @@ typedef __half half;
)";
#endif
-#if defined(USE_ROCM)
+#if defined(USE_ROCM) && ROCM_VERSION < 70000
constexpr auto bfloat16_support_literal =
R"(
#ifndef __align__
@@ -317,6 +317,75 @@ __device__ __nv_bfloat16 __float2bfloat16(const float a) {
return val;
}
+__device__ float __bfloat162float(const __nv_bfloat16 a) {
+ union
+ {
+ uint32_t int32;
+ float fp32;
+ } u = {uint32_t(a.__x) << 16};
+ return u.fp32;
+}
+#endif /* defined(__cplusplus) */
+)";
+#elif defined(USE_ROCM) && ROCM_VERSION >= 70000
+constexpr auto bfloat16_support_literal =
+ R"(
+#ifndef __align__
+#define __align__(x) __attribute__((aligned(x)))
+#endif
+
+typedef unsigned int uint32_t;
+
+typedef struct __align__(2) {
+ unsigned short x;
+}
+__nv_bfloat16_raw;
+
+#if defined(__cplusplus)
+struct __align__(2) __nv_bfloat16 {
+ __host__ __device__ __nv_bfloat16() {}
+
+ __host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) {
+ __x = hr.x;
+ return *this;
+ }
+
+ unsigned short __x;
+};
+
+__device__ unsigned short __internal_float2bfloat16(
+ const float f,
+ unsigned int& sign,
+ unsigned int& remainder) {
+ unsigned int x;
+
+ x = __float_as_uint(f);
+
+ if ((x & 0x7fffffffU) > 0x7f800000U) {
+ sign = 0U;
+ remainder = 0U;
+ return static_cast<unsigned short>(0x7fffU);
+ }
+ sign = x >> 31;
+ remainder = x << 16;
+ return static_cast<unsigned short>(x >> 16);
+}
+
+/* Definitions of intrinsics */
+__device__ __nv_bfloat16 __float2bfloat16(const float a) {
+ __nv_bfloat16 val;
+ __nv_bfloat16_raw r;
+ unsigned int sign;
+ unsigned int remainder;
+ r.x = __internal_float2bfloat16(a, sign, remainder);
+ if ((remainder > 0x80000000U) ||
+ ((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) {
+ r.x++;
+ }
+ val = r;
+ return val;
+}
+
__device__ float __bfloat162float(const __nv_bfloat16 a) {
union
{