From 14b2d39deb4880bdebeb3a6270a3a0a8fdda67c9 Mon Sep 17 00:00:00 2001
From: James Wu
Date: Tue, 20 Aug 2024 09:03:58 -0700
Subject: [PATCH] Log PT2 chromium events to scuba (#2424)

Summary:
X-link: https://github.com/pytorch/pytorch/pull/133859

This diff implements a set of views for internal scuba viewing.

TODOs that I might punt to another diff:
- Saving cache stats via counters is definitely sus here, but there's not really a good way to track "fx graph cache hit for this compile phase" right now. Will think about this more.
- We should definitely log frame id, compile id, etc.
- We should definitely be logging configs, so that we can A/B test based on whether a config is turned on.
- I'm not sure what to do with compile_uuid yet, but it's useful when you want to look at samples for a single run. If we had mast job info this field wouldn't be needed, but it's nice to be able to drill down to a single run and get its chrome trace view or icicle view.

Reviewed By: ezyang

Differential Revision: D61392607
---
 .../dynamo/dynamobench/_dynamo/utils.py      | 103 +++++++++++++++---
 1 file changed, 89 insertions(+), 14 deletions(-)

diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
index 6b4ec04ec..5c5b34e30 100644
--- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
+++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
@@ -26,6 +26,7 @@
 import time
 import types
 import typing
+import uuid
 import warnings
 import weakref
 from contextlib import contextmanager
@@ -64,7 +65,7 @@
 from torch._dispatch.python import enable_python_dispatcher
 from torch._guards import TracingContext
 from torch._subclasses.meta_utils import is_sparse_compressed
-from torch._utils_internal import log_compilation_event
+from torch._utils_internal import log_chromium_event_internal, log_compilation_event
 from torch.fx._utils import _format_graph_code, lazy_format_graph_code
 from torch.nn.modules.lazy import LazyModuleMixin
 from torch.utils._triton import has_triton, has_triton_package
@@ -212,6 +213,16 @@ def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
     frame_phase_timing[key][phase_name] += time_spent
 
 
+def get_cache_stats() -> Dict[str, Any]:
+    """Get a bunch of metadata about cache hits and misses to use in chromium events"""
+    cache_stats = {
+        "fxgraph_cache_hit": counters["inductor"]["fxgraph_cache_hit"],
+        "fxgraph_cache_miss": counters["inductor"]["fxgraph_cache_miss"],
+        "fxgraph_cache_bypass": counters["inductor"]["fxgraph_cache_bypass"],
+    }
+    return cache_stats
+
+
 # dynamo_timed is a context manager
 # By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
 # where the key is the functions name.
@@ -245,6 +256,7 @@ def dynamo_timed(
     phase_name: Optional[str] = None,
     fwd_only: bool = True,
 ):
+    chromium_log: ChromiumEventLogger = get_chromium_event_logger()
     if key not in compilation_time_metrics:
         compilation_time_metrics[key] = []
 
@@ -254,13 +266,22 @@
     try:
         with torch.profiler.record_function(f"{key} (dynamo_timed)"):
             t0 = time.time()
-            ChromiumEventLogger.log_event_start(key, time.time_ns())
+            start = time.time_ns()
+            chromium_log.log_event_start(key, start, None)
             if phase_name:
-                ChromiumEventLogger.log_event_start(phase_name, time.time_ns())
+                chromium_log.log_event_start(phase_name, start)
             yield
+
             if phase_name:
-                ChromiumEventLogger.log_event_end(phase_name, time.time_ns())
-            ChromiumEventLogger.log_event_end(key, time.time_ns())
+                chromium_log.log_event_end(
+                    phase_name,
+                    time.time_ns(),
+                    {"cache_stats": get_cache_stats()},
+                    start,
+                )
+            chromium_log.log_event_end(
+                key, time.time_ns(), {"cache_stats": get_cache_stats()}, start
+            )
             time_spent = time.time() - t0
         compilation_time_metrics[key].append(time_spent)
     except Exception as e:
@@ -814,8 +835,17 @@ class ChromiumEventLogger:
     a specification of the Chromium Event JSON format.
     """
 
-    @staticmethod
+    def __init__(self):
+        self.stack = ["__start__"]
+        # Generate a unique id for this logger, which we can use in scuba to filter down
+        # to a single python run.
+        self.id_ = str(uuid.uuid4())
+
+        # TODO: log to init/id tlparse after I add support for it
+        log.info("ChromiumEventLogger initialized with id %s", self.id_)
+
     def log_event_start(
+        self,
         event_name: str,
         time_ns: int,
         metadata: Optional[Dict[str, Any]] = None,
@@ -826,18 +856,24 @@ def log_event_start(
         :param time_ns Timestamp in nanoseconds
         :param metadata: Any extra metadata associated with this event
         """
-        ChromiumEventLogger._log_timed_event(
+        event = self._log_timed_event(
             event_name,
             time_ns,
             "B",
             metadata,
         )
+        log_chromium_event_internal(event, self.stack, self.id_)
+        self.stack.append(event_name)
+
+    def reset(self) -> None:
+        self.stack = ["__start__"]
 
-    @staticmethod
     def log_event_end(
+        self,
         event_name: str,
         time_ns: int,
         metadata: Optional[Dict[str, Any]] = None,
+        start_time_ns: Optional[int] = None,
     ) -> None:
         """
         Logs the end of a single event. This function should only be
@@ -846,28 +882,53 @@ def log_event_end(
         :param time_ns: Timestamp in nanoseconds
         :param metadata: Any extra metadata associated with this event
         """
-        ChromiumEventLogger._log_timed_event(
+        # These stack health checks currently never happen,
+        # but they're written this way to future proof any weird event
+        # overlaps in the future.
+        if event_name not in self.stack:
+            # Something went wrong, we never called start on this event,
+            # or it was skipped due to overlapping events below
+            log.warning("ChromiumEventLogger: Start event not in stack, ignoring")
+            return
+
+        event = self._log_timed_event(
             event_name,
             time_ns,
             "E",
             metadata,
         )
 
-    @staticmethod
+        while event_name != self.stack[-1]:
+            # If the event isn't the most recent one to end, pop
+            # off the stack until it is.
+            # Since event_name in self.stack, this pop is always safe
+            log.warning(
+                "ChromiumEventLogger: Detected overlapping events, fixing stack"
+            )
+            self.stack.pop()
+
+        log_chromium_event_internal(event, self.stack, self.id_, start_time_ns)
+        # Finally pop the actual event off the stack
+        self.stack.pop()
+
     def _log_timed_event(
+        self,
         event_name: str,
         time_ns: int,
         phase: str,
         metadata: Optional[Dict[str, Any]] = None,
-    ) -> None:
+    ) -> Dict[str, Any]:
         """
         Logs a timed event in chromium format. See log_event_start, log_event_end, etc.
         """
         event = {
             "name": event_name,
-            "ts": time_ns / 1000,  # Chromium events are in ms
+            "ts": time_ns / 1000,  # Chromium events are in micro seconds
             "args": metadata,
             "ph": phase,
+            # These categories are needed in all chromium traces
+            "cat": "dynamo_timed",
+            "tid": 0,
             "pid": 0,  # pid should be specified on all logs, we don't personally care about the actual process id
         }
         torch._logging.trace_structured(
@@ -876,9 +937,10 @@ def _log_timed_event(
             suppress_context=False,
             expect_trace_id=False,  # Not every chromium event will have a trace_id
         )
+        return event
 
-    @staticmethod
     def log_instant_event(
+        self,
         event_name: str,
         time_ns: int,
         metadata: Optional[Dict[str, Any]] = None,
@@ -895,7 +957,10 @@ def log_instant_event(
             "ts": time_ns / 1000,
             "args": metadata,
             "ph": "i",
-            "pid": 0,  # pid should be specified on all logs, we don't personally care about the actual process id
+            # These categories are needed in all chromium traces
+            "cat": "dynamo_timed",
+            "tid": 0,
+            "pid": 0,
             "s": "p",  # We use "process" level instant events so they all appear on the same row in the trace.
         }
         torch._logging.trace_structured(
@@ -904,6 +969,16 @@ def log_instant_event(
             suppress_context=False,
             expect_trace_id=True,
         )
+        # Log an instant event with the same start and end time
+        log_chromium_event_internal(event, self.stack, self.id_)
+
+
+CHROMIUM_EVENT_LOG : Optional[ChromiumEventLogger] = None
+def get_chromium_event_logger() -> ChromiumEventLogger:
+    global CHROMIUM_EVENT_LOG
+    if CHROMIUM_EVENT_LOG is None:
+        CHROMIUM_EVENT_LOG = ChromiumEventLogger()
+    return CHROMIUM_EVENT_LOG
 
 
 @dataclasses.dataclass
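
Usage sketch (illustrative only, not part of the patch): the snippet below shows how a caller is expected to drive the new logger, mirroring what dynamo_timed now does. It assumes this change is also importable as torch._dynamo.utils (the in-tree twin of this benchmark copy, per the X-link above); the event name "example_phase" and the metadata dict are hypothetical placeholders.

    import time

    from torch._dynamo.utils import get_chromium_event_logger

    chromium_log = get_chromium_event_logger()  # lazily creates the process-wide singleton

    start = time.time_ns()
    chromium_log.log_event_start("example_phase", start, None)
    try:
        pass  # ... the work being timed ...
    finally:
        # Passing the start timestamp lets the internal logger compute a duration;
        # the metadata dict lands in the chromium event's "args" field.
        chromium_log.log_event_end(
            "example_phase", time.time_ns(), {"example": True}, start
        )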