diff --git a/ddtrace/bootstrap/preload.py b/ddtrace/bootstrap/preload.py
index e16821dc756..a725656a469 100644
--- a/ddtrace/bootstrap/preload.py
+++ b/ddtrace/bootstrap/preload.py
@@ -11,6 +11,7 @@
 from ddtrace.settings.profiling import config as profiling_config  # noqa:F401
 from ddtrace.internal.logger import get_logger  # noqa:F401
 from ddtrace.internal.module import ModuleWatchdog  # noqa:F401
+from ddtrace.internal.products import manager  # noqa:F401
 from ddtrace.internal.runtime.runtime_metrics import RuntimeWorker  # noqa:F401
 from ddtrace.internal.tracemethods import _install_trace_methods  # noqa:F401
 from ddtrace.internal.utils.formats import asbool  # noqa:F401
@@ -43,6 +44,15 @@ def register_post_preload(func: t.Callable) -> None:
 
 log = get_logger(__name__)
 
+# Run the product manager protocol
+manager.run_protocol()
+
+# Post preload operations
+register_post_preload(manager.post_preload_products)
+
+
+# TODO: Migrate the following product logic to the new product plugin interface
+
 # DEV: We want to start the crashtracker as early as possible
 if crashtracker_config.enabled:
     log.debug("crashtracking enabled via environment variable")
diff --git a/ddtrace/internal/products.py b/ddtrace/internal/products.py
new file mode 100644
index 00000000000..24e98ac8ad9
--- /dev/null
+++ b/ddtrace/internal/products.py
@@ -0,0 +1,211 @@
+import atexit
+from collections import defaultdict
+from collections import deque
+import sys
+import typing as t
+
+from ddtrace.internal import forksafe
+from ddtrace.internal.logger import get_logger
+from ddtrace.internal.uwsgi import check_uwsgi
+from ddtrace.internal.uwsgi import uWSGIConfigError
+from ddtrace.internal.uwsgi import uWSGIMasterProcess
+
+
+log = get_logger(__name__)
+
+if sys.version_info < (3, 10):
+    from importlib_metadata import entry_points
+else:
+    from importlib.metadata import entry_points
+
+try:
+    from typing import Protocol  # noqa:F401
+except ImportError:
+    from typing_extensions import Protocol  # type: ignore[assignment]
+
+
+class Product(Protocol):
+    """Structural interface that every product plugin is expected to expose."""
+
+    requires: t.List[str]  # Names of the products this one depends on
+
+    def post_preload(self) -> None:
+        ...
+
+    def start(self) -> None:
+        ...
+
+    def restart(self, join: bool = False) -> None:
+        ...
+
+    def stop(self, join: bool = False) -> None:
+        ...
+
+    def at_exit(self, join: bool = False) -> None:
+        ...
+
+
+class ProductManager:
+    """Discover product plugins and drive their lifecycle in dependency order.
+
+    Products are loaded from the ``ddtrace.products`` entry point group. They are
+    started, restarted and post-preloaded in the topological order given by their
+    ``requires`` lists, and stopped/exited in the reverse order.
+    """
+
+    # DEV: This mapping is a class attribute, so it is shared by every instance
+    # that does not assign its own ``__products__``.
+    __products__: t.Dict[str, Product] = {}  # All discovered products
+
+    def __init__(self) -> None:
+        self._products: t.Optional[t.List[t.Tuple[str, Product]]] = None  # Topologically sorted products
+
+        for product_plugin in entry_points(group="ddtrace.products"):
+            name = product_plugin.name
+            log.debug("Discovered product plugin '%s'", name)
+
+            # Load the product protocol object
+            try:
+                product: Product = product_plugin.load()
+            except Exception:
+                log.exception("Failed to load product plugin '%s'", name)
+                continue
+
+            log.debug("Product plugin '%s' loaded successfully", name)
+
+            self.__products__[name] = product
+
+    def _sort_products(self) -> t.List[t.Tuple[str, Product]]:
+        """Return the ``(name, product)`` pairs ordered so that requirements come first.
+
+        Products that cannot be ordered, either because they are part of a
+        dependency cycle or because they require a product that has not been
+        discovered, are logged and left out of the result.
+        """
+        # Data structures for topological sorting
+        q: t.Deque[str] = deque()  # Queue of products with no dependencies
+        g = defaultdict(list)  # Graph of dependencies
+        f = {}  # Remaining dependencies for each product
+
+        for name, product in self.__products__.items():
+            product_requires = getattr(product, "requires", [])
+            if not product_requires:
+                q.append(name)
+            else:
+                f[name] = list(product_requires)
+                for r in product_requires:
+                    g[r].append(name)
+
+        # Determine the product (topological) ordering
+        ordering = []
+        while q:
+            n = q.popleft()
+            ordering.append(n)
+            for p in g[n]:
+                f[p].remove(n)
+                if not f[p]:
+                    q.append(p)
+                    del f[p]
+
+        if f:
+            # A requirement that names an undiscovered product can never be
+            # satisfied; report it separately so that it is not mistaken for a
+            # dependency cycle.
+            missing = sorted({r for rs in f.values() for r in rs if r not in self.__products__})
+            if missing:
+                log.error("Required products not found: %s.", missing)
+            log.error(
+                "Circular or unsatisfied dependencies among products detected. "
+                "These products won't be enabled: %s.",
+                list(f.keys()),
+            )
+
+        return [(name, self.__products__[name]) for name in ordering if name not in f]
+
+    @property
+    def products(self) -> t.List[t.Tuple[str, Product]]:
+        """The enabled products in dependency order, computed on first access."""
+        if self._products is None:
+            self._products = self._sort_products()
+        return self._products
+
+    def start_products(self) -> None:
+        """Start each product in dependency order, logging any failure."""
+        for name, product in self.products:
+            try:
+                product.start()
+                log.debug("Started product '%s'", name)
+            except Exception:
+                log.exception("Failed to start product '%s'", name)
+
+    def restart_products(self, join: bool = False) -> None:
+        """Restart each product in dependency order, logging any failure."""
+        for name, product in self.products:
+            try:
+                product.restart(join=join)
+                log.debug("Restarted product '%s'", name)
+            except Exception:
+                log.exception("Failed to restart product '%s'", name)
+
+    def stop_products(self, join: bool = False) -> None:
+        """Stop each product in reverse dependency order, logging any failure."""
+        for name, product in reversed(self.products):
+            try:
+                product.stop(join=join)
+                log.debug("Stopped product '%s'", name)
+            except Exception:
+                log.exception("Failed to stop product '%s'", name)
+
+    def exit_products(self, join: bool = False) -> None:
+        """Run each product's exit handler in reverse dependency order, logging any failure."""
+        for name, product in reversed(self.products):
+            try:
+                log.debug("Exiting product '%s'", name)
+                product.at_exit(join=join)
+            except Exception:
+                log.exception("Failed to exit product '%s'", name)
+
+    def post_preload_products(self) -> None:
+        """Run each product's post-preload step in dependency order, logging any failure."""
+        for name, product in self.products:
+            try:
+                product.post_preload()
+                log.debug("Post-preload product '%s' done", name)
+            except Exception:
+                log.exception("Failed to post_preload product '%s'", name)
+
+    def _do_products(self) -> None:
+        """Start the products and register their fork and exit handlers."""
+        # Start all products
+        self.start_products()
+
+        # Restart products on fork
+        forksafe.register(self.restart_products)
+
+        # Stop all products on exit
+        atexit.register(self.exit_products)
+
+    def run_protocol(self) -> None:
+        """Start the products now or, in a uWSGI master process, after the fork."""
+        # uWSGI support
+        try:
+            check_uwsgi(worker_callback=forksafe.ddtrace_after_in_child)
+        except uWSGIMasterProcess:
+            # We are in the uWSGI master process, we should handle products in the
+            # post-fork callback
+            @forksafe.register
+            def _() -> None:
+                self._do_products()
+                forksafe.unregister(_)
+
+        except uWSGIConfigError:
+            log.error("uWSGI configuration error", exc_info=True)
+        except Exception:
+            log.exception("Failed to check uWSGI configuration")
+
+        # Ordinary process
+        else:
+            self._do_products()
+
+
+manager = ProductManager()
diff --git a/tests/internal/test_products.py b/tests/internal/test_products.py
new file mode 100644
index 00000000000..09226e1c5ff
--- /dev/null
+++ b/tests/internal/test_products.py
@@ -0,0 +1,82 @@
+import os
+
+from ddtrace.internal.products import Product
+from ddtrace.internal.products import ProductManager
+
+
+class ProductManagerTest(ProductManager):
+    """Product manager fed with an explicit product mapping instead of entry points."""
+
+    def __init__(self, products) -> None:
+        self._products = None
+        self.__products__ = products
+
+
+class BaseProduct(Product):
+    """Product that records which lifecycle methods have been called."""
+
+    requires = []
+
+    def __init__(self) -> None:
+        self.started = self.restarted = self.stopped = self.exited = self.post_preloaded = False
+
+    def post_preload(self) -> None:
+        self.post_preloaded = True
+
+    def start(self) -> None:
+        self.started = True
+
+    def restart(self, join: bool = False) -> None:
+        self.restarted = True
+
+    def stop(self, join: bool = False) -> None:
+        self.stopped = True
+
+    def at_exit(self, join: bool = False) -> None:
+        self.exited = True
+
+
+def test_product_manager_cycles():
+    class A(BaseProduct):
+        requires = ["b"]
+
+    class B(BaseProduct):
+        requires = ["a"]
+
+    a = A()
+    b = B()
+    c = BaseProduct()
+
+    manager = ProductManagerTest({"a": a, "b": b, "c": c})
+    manager.run_protocol()
+
+    # a and b depend on each other, so they won't start
+    assert not a.started and not b.started
+
+    # c doesn't have any dependencies, so it will start
+    assert c.started
+
+
+def test_product_manager_start():
+    a = BaseProduct()
+    manager = ProductManagerTest({"a": a})
+    manager.run_protocol()
+    assert a.started
+
+
+def test_product_manager_restart():
+    a = BaseProduct()
+    manager = ProductManagerTest({"a": a})
+    manager.run_protocol()
+    assert a.started
+    assert not a.restarted
+
+    pid = os.fork()
+    if pid == 0:
+        # Report the outcome through the exit code: an assertion failing here
+        # would not reach the parent and would leave the child running the
+        # rest of the test session.
+        os._exit(0 if a.restarted else 1)
+
+    _, status = os.waitpid(pid, 0)
+    assert os.WIFEXITED(status) and os.WEXITSTATUS(status) == 0