Parallel runner in use #448

Merged · 3 commits · Sep 16, 2024
90 changes: 88 additions & 2 deletions src/modelbench/benchmark_runner.py
@@ -1,6 +1,10 @@
import dataclasses
import json
import pathlib
import random
import sys
import time
from abc import abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
@@ -23,6 +27,7 @@
    SUTCompletionAnnotations,
)
from modelgauge.sut import SUTResponse, SUTCompletion
from tqdm import tqdm

from modelbench.benchmarks import (
BenchmarkDefinition,
@@ -31,6 +36,70 @@
from modelbench.suts import ModelGaugeSut


class RunTracker:
    """
    A base class to encapsulate run tracking. Lets you limit update frequency to minimize output noise.
    To subclass, the minimum is implementing _on_update. If you want no output, just use the
    NullRunTracker.
    """

    def __init__(self, seconds_per_update: float = 1.0):
        super().__init__()
        self.seconds_per_update = seconds_per_update
        self.last_update = 0
        self.total_items = 0

    def start(self, total_items: int):
        self.total_items = total_items

    def update(self, finished_items: int):
        if self._now() > self.seconds_per_update + self.last_update:
            self._on_update(finished_items)
            self.last_update = self._now()

    def done(self):
        self._on_update(self.total_items)

    @abstractmethod
    def _on_update(self, finished_items: int):
        pass

    def _now(self):
        return time.time()


class NullRunTracker(RunTracker):

    def _on_update(self, finished_items: int):
        pass


class TqdmRunTracker(RunTracker):

    def start(self, total_items: int):
        super().start(total_items)
        self.pbar = tqdm(total=self.total_items, unit="items")
        self.previous_count = 0

    def _on_update(self, finished_items: int):
        self.pbar.update(finished_items - self.previous_count)
        self.previous_count = finished_items

    def done(self):
        super().done()
        self.pbar.close()


class JsonRunTracker(RunTracker):

    def start(self, total_items: int):
        super().start(total_items)
        self._on_update(0)

    def _on_update(self, finished_items: int):
        print(json.dumps({"progress": finished_items / self.total_items}), file=sys.stderr)


class ModelgaugeTestWrapper:
"""An attempt at cleaning up the test interface"""

@@ -110,6 +179,8 @@ def __init__(self, runner: "TestRunnerBase"):
        self.max_items = runner.max_items
        self.tests = []
        self._test_lookup = {}
        self.run_tracker = runner.run_tracker
        self.completed_item_count = 0

        # set up for result collection
        self.finished_items = defaultdict(lambda: defaultdict(lambda: list()))
@@ -126,6 +197,7 @@ def add_finished_item(self, item: "TestRunItem"):
            self.finished_items[item.sut.key][item.test.uid].append(item)
        else:
            self.failed_items[item.sut.key][item.test.uid].append(item)
        self.completed_item_count += 1

    def add_test_record(self, test_record: TestRecord):
        self.test_records[test_record.test_uid][test_record.sut_uid] = test_record
@@ -285,6 +357,7 @@ def __init__(self, test_run: TestRunBase):

    def handle_item(self, item) -> None:
        self.test_run.add_finished_item(item)
        self.test_run.run_tracker.update(self.test_run.completed_item_count)


class TestRunnerBase:
@@ -295,6 +368,7 @@ def __init__(self, data_dir: pathlib.Path):
        self.suts = []
        self.max_items = 10
        self.thread_count = 1
        self.run_tracker = NullRunTracker()

    def add_sut(self, sut: ModelGaugeSut):
        self.suts.append(sut)
@@ -342,6 +416,9 @@ def _build_pipeline(self, run):
        )
        return pipeline

    def _expected_item_count(self, the_run: TestRunBase, pipeline: Pipeline):
        return len(the_run.suts) * len(list(pipeline.source.new_item_iterable()))


class TestRunner(TestRunnerBase):

@@ -361,9 +438,12 @@ def run(self) -> TestRun:
        self._check_ready_to_run()
        test_run = TestRun(self)
        pipeline = self._build_pipeline(test_run)
        test_run.run_tracker.start(self._expected_item_count(test_run, pipeline))

        pipeline.run()

        self._calculate_test_results(test_run)
        test_run.run_tracker.done()
        return test_run


@@ -384,11 +464,12 @@ def run(self) -> BenchmarkRun:
        self._check_ready_to_run()
        benchmark_run = BenchmarkRun(self)
        pipeline = self._build_pipeline(benchmark_run)
        benchmark_run.run_tracker.start(self._expected_item_count(benchmark_run, pipeline))
        pipeline.run()

        self._calculate_test_results(benchmark_run)
        self._calculate_benchmark_scores(benchmark_run)

        benchmark_run.run_tracker.done()
        return benchmark_run

    def _calculate_benchmark_scores(self, benchmark_run):
@@ -398,7 +479,12 @@ def _calculate_benchmark_scores(self, benchmark_run):
                for hazard in benchmark_definition.hazards():
                    test_records = {}
                    for test in hazard.tests(benchmark_run.secrets):
                        test_records[test.uid] = benchmark_run.test_records[test.uid][sut.uid]
                        records = benchmark_run.test_records[test.uid][sut.uid]
                        assert records, f"No records found for {benchmark_definition} {sut} {hazard} {test.uid}"
                        test_records[test.uid] = records

                    assert test_records, f"No records found for {benchmark_definition} {sut} {hazard}"

                    hazard_scores.append(hazard.score(test_records))  # TODO: score needs way less
                benchmark_run.benchmark_scores[benchmark_definition][sut] = BenchmarkScore(
                    benchmark_definition, sut, hazard_scores, end_time=datetime.now()
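
For reference, a few sketches of how the new tracking hooks might be used; none of this code is part of the PR.

The RunTracker docstring says a subclass only needs to implement `_on_update`. A minimal sketch of a custom tracker (the `LoggingRunTracker` name and the logging setup are illustrative, not from this PR):

```python
import logging

from modelbench.benchmark_runner import RunTracker

logger = logging.getLogger(__name__)


class LoggingRunTracker(RunTracker):
    """Illustrative tracker that reports progress through the logging module."""

    def _on_update(self, finished_items: int):
        # total_items is set by RunTracker.start(); guard against division by zero.
        total = self.total_items or 1
        logger.info("progress: %d/%d (%.0f%%)", finished_items, total, 100 * finished_items / total)
```

Because `update()` already throttles to one call per `seconds_per_update` (1 second by default), the subclass needs no rate limiting of its own.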
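
On the runner side the tracker is a plain attribute (`TestRunnerBase.__init__` defaults it to `NullRunTracker()`), so a caller presumably swaps in one of the other trackers before calling `run()`. A hedged sketch, assuming the benchmark runner class is named `BenchmarkRunner` (only its `run()` method appears in this diff) and that it keeps the `data_dir` constructor from `TestRunnerBase`:

```python
import pathlib
import sys

from modelbench.benchmark_runner import BenchmarkRunner, JsonRunTracker, TqdmRunTracker

runner = BenchmarkRunner(pathlib.Path("./run"))  # constructor args assumed, not shown in this diff
runner.thread_count = 4

# Progress bar on an interactive terminal, machine-readable JSON lines on stderr otherwise.
runner.run_tracker = TqdmRunTracker() if sys.stdout.isatty() else JsonRunTracker()

# ... register SUTs and benchmarks as usual (e.g. runner.add_sut(...)), then:
# benchmark_run = runner.run()
```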
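
`JsonRunTracker` emits one JSON object per update on stderr, e.g. `{"progress": 0.25}`. A small sketch of how a wrapping process could consume that stream; the command line is a placeholder, since the CLI entry point is not part of this diff:

```python
import json
import subprocess

# Placeholder command; substitute the real modelbench invocation.
proc = subprocess.Popen(
    ["modelbench", "benchmark"],
    stderr=subprocess.PIPE,
    text=True,
)
for line in proc.stderr:
    try:
        fraction = json.loads(line)["progress"]
    except (json.JSONDecodeError, KeyError):
        continue  # ignore stderr lines that are not progress records
    print(f"{fraction:.0%} complete")
proc.wait()
```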