Skip to content

Commit

Permalink
Revert "[logging] A few fixes/updates to record_compilation_metrics (p…
Browse files Browse the repository at this point in the history
…ytorch#143332)"

This reverts commit a9c753b.

Reverted pytorch#143332 on behalf of https://github.com/malfet due to Surprisingly failure is caused by this PR ([comment](pytorch#143332 (comment)))
  • Loading branch information
pytorchmergebot committed Dec 21, 2024
1 parent bf7009d commit ad7ab5e
Showing 1 changed file with 57 additions and 77 deletions.
134 changes: 57 additions & 77 deletions torch/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,77 +922,6 @@ class CompilationMetrics:
tensorify_float_success: Optional[bool] = None
tensorify_float_failure: Optional[Set[str]] = None

@classmethod
def create(cls, metrics: Dict[str, Any]):
"""
Factory method to create a CompilationMetrics from a dict of fields.
Includes the logic to add legacy fields and any pre-processing, e.g.,
we transform some fields to comma-separated strings for scuba logging.
"""

def us_to_s(metric: Optional[int]) -> Optional[float]:
return metric / 1e6 if metric is not None else None

def us_to_ms(metric: Optional[int]) -> Optional[int]:
return metric // 1000 if metric is not None else None

def collection_to_str(metric: Optional[Any]) -> Optional[str]:
def safe_str(item: Any) -> str:
try:
return str(item)
except Exception:
return "<unknown>"

if metric is None:
return None

if not isinstance(metric, (set, list)):
return "<unknown>"

return ",".join(safe_str(item) for item in sorted(metric))

# TODO: The following are legacy fields, populated from the fields that replace
# them. Remove these when we decide we can really deprecate them.
legacy_metrics = {
"start_time": us_to_s(metrics.get("start_time_us")),
"entire_frame_compile_time_s": us_to_s(
metrics.get("dynamo_cumulative_compile_time_us")
),
"backend_compile_time_s": us_to_s(
metrics.get("aot_autograd_cumulative_compile_time_us")
),
"inductor_compile_time_s": us_to_s(
metrics.get("inductor_cumulative_compile_time_us")
),
"code_gen_time_s": us_to_s(
metrics.get("inductor_code_gen_cumulative_compile_time_us")
),
"remote_cache_time_saved_s": us_to_s(
metrics.get("distributed_ephemeral_timeout_us")
),
"remote_fx_graph_cache_get_time_ms": us_to_ms(
metrics.get("remote_fx_graph_cache_get_time_us")
),
"remote_fx_graph_cache_put_time_ms": us_to_ms(
metrics.get("remote_fx_graph_cache_put_time_us")
),
"structured_logging_overhead_s": us_to_s(
metrics.get("structured_logging_overhead_us")
),
}

all_metrics = {**legacy_metrics, **metrics}

# Pre-processing:
all_metrics["inductor_fx_remote_cache_hit_keys"] = collection_to_str(
all_metrics.get("inductor_fx_remote_cache_hit_keys")
)
all_metrics["inductor_fx_remote_cache_miss_keys"] = collection_to_str(
all_metrics.get("inductor_fx_remote_cache_miss_keys")
)

return cls(**all_metrics)


DEFAULT_COMPILATION_METRICS_LIMIT = 64

Expand Down Expand Up @@ -1088,6 +1017,33 @@ def record_compilation_metrics(
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
):
def us_to_s(field):
metric = metrics.get(field, None)
return metric / 1e6 if metric is not None else None

def us_to_ms(field):
metric = metrics.get(field, None)
return metric // 1000 if metric is not None else None

def _convert_collection_to_str(field: str) -> Optional[str]:
def safe_str(item: Any) -> str:
try:
return str(item)
except Exception:
return str(None)

metric = metrics.get(field, None)
if metric is None:
return None

# Remove this field (list/set) from metrics to avoid clashes
del metrics[field]
if not isinstance(metric, set) and not isinstance(metric, list):
return None
return ",".join(safe_str(item) for item in metric)

structured_logging_overhead_s = torch._logging.get_structured_logging_overhead()

if torch._inductor.utils.should_use_remote_fx_graph_cache():
try:
from torch._inductor.fb.remote_cache import (
Expand All @@ -1105,8 +1061,8 @@ def record_compilation_metrics(
inductor_fx_remote_cache_backend_type = None
remote_cache_version = None

# Populate the compile_id from the metrics context if it's set. Otherwise,
# look for it in the current compile context.
# Populate the compile_id from the metrics context if it's set. Otherwise
# look for it in the compile context.
compile_id = metrics.get("compile_id")
if not compile_id:
compile_id = torch._guards.CompileContext.current_compile_id()
Expand All @@ -1118,17 +1074,41 @@ def record_compilation_metrics(
"duration_us": (end_time_ns - start_time_ns) // 1000,
"fail_type": exc_type.__qualname__ if exc_type else None,
"fail_reason": str(exc_value) if exc_value else None,
"structured_logging_overhead_us": to_int_us(
torch._logging.get_structured_logging_overhead()
),
"structured_logging_overhead_us": to_int_us(structured_logging_overhead_s),
"inductor_config": _scrubbed_inductor_config_for_logging(),
"cuda_version": torch.version.cuda,
"triton_version": triton.__version__ if has_triton() else "",
"inductor_fx_remote_cache_hit_keys": _convert_collection_to_str(
"inductor_fx_remote_cache_hit_keys"
),
"inductor_fx_remote_cache_miss_keys": _convert_collection_to_str(
"inductor_fx_remote_cache_miss_keys"
),
"remote_cache_version": remote_cache_version,
"inductor_fx_remote_cache_backend_type": inductor_fx_remote_cache_backend_type,
}

compilation_metrics = CompilationMetrics.create({**metrics, **common_metrics})
# TODO: The following are legacy fields, populated from the fields that replace
# them. Remove these when we decide we can really deprecate them.
legacy_metrics = {
"start_time": start_time_ns / 1e9,
"entire_frame_compile_time_s": us_to_s("dynamo_cumulative_compile_time_us"),
"backend_compile_time_s": us_to_s("aot_autograd_cumulative_compile_time_us"),
"inductor_compile_time_s": us_to_s("inductor_cumulative_compile_time_us"),
"code_gen_time_s": us_to_s("inductor_code_gen_cumulative_compile_time_us"),
"remote_cache_time_saved_s": us_to_s("distributed_ephemeral_timeout_us"),
"remote_fx_graph_cache_get_time_ms": us_to_ms(
"remote_fx_graph_cache_get_time_us"
),
"remote_fx_graph_cache_put_time_ms": us_to_ms(
"remote_fx_graph_cache_put_time_us"
),
"structured_logging_overhead_s": structured_logging_overhead_s,
}

compilation_metrics = CompilationMetrics(
**{**legacy_metrics, **common_metrics, **metrics}
)
_compilation_metrics.append(compilation_metrics)

name = "compilation_metrics"
Expand Down

0 comments on commit ad7ab5e

Please sign in to comment.