Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Log PT2 chromium events to scuba #2424

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 94 additions & 14 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import time
import types
import typing
import uuid
import warnings
import weakref
from contextlib import contextmanager
Expand Down Expand Up @@ -64,7 +65,7 @@
from torch._dispatch.python import enable_python_dispatcher
from torch._guards import TracingContext
from torch._subclasses.meta_utils import is_sparse_compressed
from torch._utils_internal import log_compilation_event
from torch._utils_internal import log_chromium_event_internal, log_compilation_event
from torch.fx._utils import _format_graph_code, lazy_format_graph_code
from torch.nn.modules.lazy import LazyModuleMixin
from torch.utils._triton import has_triton, has_triton_package
Expand Down Expand Up @@ -212,6 +213,16 @@ def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
frame_phase_timing[key][phase_name] += time_spent


def get_cache_stats() -> Dict[str, Any]:
    """Collect FX graph cache counters to attach to chromium events.

    Reads the "fxgraph_cache_hit", "fxgraph_cache_miss" and
    "fxgraph_cache_bypass" entries from ``counters["inductor"]`` and returns
    them in a new dict keyed by those same names.
    """
    inductor_counters = counters["inductor"]
    stat_names = (
        "fxgraph_cache_hit",
        "fxgraph_cache_miss",
        "fxgraph_cache_bypass",
    )
    return {name: inductor_counters[name] for name in stat_names}


# dynamo_timed is a context manager
# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
# where the key is the functions name.
Expand Down Expand Up @@ -245,6 +256,7 @@ def dynamo_timed(
phase_name: Optional[str] = None,
fwd_only: bool = True,
):
chromium_log: ChromiumEventLogger = get_chromium_event_logger()
if key not in compilation_time_metrics:
compilation_time_metrics[key] = []

Expand All @@ -254,13 +266,22 @@ def dynamo_timed(
try:
with torch.profiler.record_function(f"{key} (dynamo_timed)"):
t0 = time.time()
ChromiumEventLogger.log_event_start(key, time.time_ns())
start = time.time_ns()
chromium_log.log_event_start(key, start, None)
if phase_name:
ChromiumEventLogger.log_event_start(phase_name, time.time_ns())
chromium_log.log_event_start(phase_name, start)
yield

if phase_name:
ChromiumEventLogger.log_event_end(phase_name, time.time_ns())
ChromiumEventLogger.log_event_end(key, time.time_ns())
chromium_log.log_event_end(
phase_name,
time.time_ns(),
{"cache_stats": get_cache_stats()},
start,
)
chromium_log.log_event_end(
key, time.time_ns(), {"cache_stats": get_cache_stats()}, start
)
time_spent = time.time() - t0
compilation_time_metrics[key].append(time_spent)
except Exception as e:
Expand Down Expand Up @@ -814,8 +835,17 @@ class ChromiumEventLogger:
a specification of the Chromium Event JSON format.
"""

@staticmethod
def __init__(self):
    """Create a logger with a fresh unique id and an event stack holding only "__start__"."""
    # A unique id per logger lets us filter scuba down to a single python run.
    self.id_ = str(uuid.uuid4())
    # Names of the events that have been started but not yet ended;
    # "__start__" is the bottom-of-stack sentinel.
    self.stack = ["__start__"]

    # TODO: log to init/id tlparse after I add support for it
    log.info("ChromiumEventLogger initialized with id %s", self.id_)

def log_event_start(
self,
event_name: str,
time_ns: int,
metadata: Optional[Dict[str, Any]] = None,
Expand All @@ -826,18 +856,27 @@ def log_event_start(
:param time_ns Timestamp in nanoseconds
:param metadata: Any extra metadata associated with this event
"""
ChromiumEventLogger._log_timed_event(
event = self._log_timed_event(
event_name,
time_ns,
"B",
metadata,
)
log_chromium_event_internal(event, self.stack, self.id_)
self.stack.append(event_name)

def reset(self) -> None:
    """Restore the event stack to its initial state: just the "__start__" entry.

    We do this on every compile in case a previous compile crashed or
    restarted before its events were popped off the stack.
    """
    # Slice assignment mutates the existing list in place, so any other
    # reference to the same list object observes the reset as well.
    self.stack[:] = ["__start__"]

@staticmethod
def log_event_end(
self,
event_name: str,
time_ns: int,
metadata: Optional[Dict[str, Any]] = None,
start_time_ns: Optional[int] = None,
) -> None:
"""
Logs the end of a single event. This function should only be
Expand All @@ -846,28 +885,53 @@ def log_event_end(
:param time_ns: Timestamp in nanoseconds
:param metadata: Any extra metadata associated with this event
"""
ChromiumEventLogger._log_timed_event(
# These stack health checks should currently never trigger,
# but they are written this way to future-proof against any weird
# event overlaps.
if event_name not in self.stack:
# Something went wrong, we never called start on this event,
# or it was skipped due to overlapping events below
log.warning("ChromiumEventLogger: Start event not in stack, ignoring")
return

event = self._log_timed_event(
event_name,
time_ns,
"E",
metadata,
)

@staticmethod
while event_name != self.stack[-1]:
# If the event isn't the most recent one to end, pop
# off the stack until it is.
# Since event_name in self.stack, this pop is always safe
log.warning(
"ChromiumEventLogger: Detected overlapping events, fixing stack"
)
self.stack.pop()

log_chromium_event_internal(event, self.stack, self.id_, start_time_ns)
# Finally pop the actual event off the stack
self.stack.pop()

def _log_timed_event(
self,
event_name: str,
time_ns: int,
phase: str,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
) -> Dict[str, Any]:
"""
Logs a timed event in chromium format. See log_event_start, log_event_end, etc.
"""
event = {
"name": event_name,
"ts": time_ns / 1000, # Chromium events are in ms
"ts": time_ns / 1000, # Chromium events are in micro seconds
"args": metadata,
"ph": phase,
# These categories are needed in all chromium traces
"cat": "dynamo_timed",
"tid": 0,
"pid": 0, # pid should be specified on all logs, we don't personally care about the actual process id
}
torch._logging.trace_structured(
Expand All @@ -876,9 +940,10 @@ def _log_timed_event(
suppress_context=False,
expect_trace_id=False, # Not every chromium event will have a trace_id
)
return event

@staticmethod
def log_instant_event(
self,
event_name: str,
time_ns: int,
metadata: Optional[Dict[str, Any]] = None,
Expand All @@ -895,7 +960,10 @@ def log_instant_event(
"ts": time_ns / 1000,
"args": metadata,
"ph": "i",
"pid": 0, # pid should be specified on all logs, we don't personally care about the actual process id
# These categories are needed in all chromium traces
"cat": "dynamo_timed",
"tid": 0,
"pid": 0,
"s": "p", # We use "process" level instant events so they all appear on the same row in the trace.
}
torch._logging.trace_structured(
Expand All @@ -904,6 +972,18 @@ def log_instant_event(
suppress_context=False,
expect_trace_id=True,
)
# Log an instant event with the same start and end time
log_chromium_event_internal(event, self.stack, self.id_)


CHROMIUM_EVENT_LOG: Optional[ChromiumEventLogger] = None


def get_chromium_event_logger() -> ChromiumEventLogger:
    """Return the module-wide ChromiumEventLogger, creating it on first use.

    The instance is stored in the CHROMIUM_EVENT_LOG global, so later calls
    return the same object (and therefore the same logger id and stack).
    """
    global CHROMIUM_EVENT_LOG
    logger = CHROMIUM_EVENT_LOG
    if logger is None:
        # First call: build the logger and remember it for later callers.
        logger = CHROMIUM_EVENT_LOG = ChromiumEventLogger()
    return logger


@dataclasses.dataclass
Expand Down
Loading