diff --git a/hta/common/trace_parser.py b/hta/common/trace_parser.py index 204a954f..c96d1f61 100644 --- a/hta/common/trace_parser.py +++ b/hta/common/trace_parser.py @@ -369,6 +369,19 @@ def _parse_trace_dataframe_ijson( else: df = _parse_trace_events_ijson(trace_file_path) + if df["ts"].dtype == np.dtype("float64"): + logger.warning( + f"Rounding down ns resolution events due to issue with events overlapping." + f" ts dtype = {df['ts'].dtype}, dur dtype = {df['dur'].dtype}." + f"Please see https://github.com/pytorch/pytorch/pull/122425" + ) + # Don't floor directly, first find the end + df["end"] = df["ts"] + df["dur"] + + df["ts"] = df[~df["ts"].isnull()]["ts"].apply(lambda x: math.ceil(x)) + df["end"] = df[~df["end"].isnull()]["end"].apply(lambda x: math.floor(x)) + df["dur"] = df["end"] - df["ts"] + # assign an index to each event df.reset_index(inplace=True) df["index"] = pd.to_numeric(df["index"], downcast="integer")