
Commit

Add some tests
cg123 committed Dec 31, 2023
1 parent fcbe3c3 commit 5e6e453
Showing 6 changed files with 274 additions and 11 deletions.
2 changes: 1 addition & 1 deletion mergekit/common.py
@@ -203,7 +203,7 @@ class ImmutableMap(
 ):
     data: immutables.Map[T_K, T_V]

-    @pydantic.validator("data", pre=True)
+    @pydantic.field_validator("data", mode="before")
     def validate_data(cls, data):
         return immutables.Map(data)
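This one-line change is the Pydantic v1 to v2 validator migration: pydantic.validator(..., pre=True) becomes pydantic.field_validator(..., mode="before"), and the validator still coerces whatever mapping is passed in into an immutables.Map before field validation runs. Below is a minimal standalone sketch of the same pattern; the FrozenMapping model is hypothetical and for illustration only, not mergekit's actual ImmutableMap (which is generic and configured differently).

import immutables
import pydantic


class FrozenMapping(pydantic.BaseModel):
    # Hypothetical example model; arbitrary_types_allowed lets pydantic accept
    # the non-pydantic immutables.Map type as a field annotation.
    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

    data: immutables.Map

    @pydantic.field_validator("data", mode="before")
    @classmethod
    def _coerce(cls, value):
        # Runs before the field's type check, so plain dicts are accepted.
        return immutables.Map(value)


frozen = FrozenMapping(data={"a": 1})
print(isinstance(frozen.data, immutables.Map))  # True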
1 change: 1 addition & 0 deletions mergekit/graph.py
@@ -149,6 +149,7 @@ def run(self) -> Iterator[Tuple[Task, Any]]:
         # determine last usage of each value, so they can be evicted afterwards
         last_use_index = {}
         for idx, task in enumerate(self.schedule):
+            j = len(self.schedule)
             for j in range(len(self.schedule) - 1, idx, -1):
                 if task in self.dependencies[self.schedule[j]]:
                     break
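The added j = len(self.schedule) line gives the last-use search a sensible default: when a scheduled task has no later consumer, its value is treated as needed through the end of the schedule, rather than leaving j unbound on the first pass or carrying a stale value from a previous iteration. A simplified standalone sketch of the idea follows; the helper and the example schedule are illustrative, not mergekit's actual code, which presumably records j into last_use_index right after this loop.

def last_use_indices(schedule, dependencies):
    # For each task, find the index of the last task that consumes its result;
    # the result can be evicted once execution moves past that index.
    last_use_index = {}
    for idx, task in enumerate(schedule):
        j = len(schedule)  # default: nothing later uses it, keep until the end
        for j in range(len(schedule) - 1, idx, -1):
            if task in dependencies[schedule[j]]:
                break
        last_use_index[task] = j
    return last_use_index


schedule = ["load_a", "load_b", "merge"]
dependencies = {"load_a": set(), "load_b": set(), "merge": {"load_a", "load_b"}}
print(last_use_indices(schedule, dependencies))
# {'load_a': 2, 'load_b': 2, 'merge': 3}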
2 changes: 1 addition & 1 deletion mergekit/io/loader.py
@@ -14,7 +14,7 @@
 # along with this program. If not, see http://www.gnu.org/licenses/.

 from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional, Sequence
+from typing import Dict, Optional, Sequence

 import safetensors
 import torch
21 changes: 12 additions & 9 deletions pyproject.toml
@@ -6,11 +6,9 @@ build-backend = "setuptools.build_meta"
 name = "mergekit"
 description = "Tools for merging pre-trained large language models"
 readme = "README.md"
-license = {text = "LGPL-3.0-or-later"}
+license = { text = "LGPL-3.0-or-later" }
 version = "0.0.3.1"
-authors = [
-    {name = "Charles Goddard", email = "[email protected]"},
-]
+authors = [{ name = "Charles Goddard", email = "[email protected]" }]
 dependencies = [
     "torch>=2.0.0",
     "tqdm==4.66.1",
@@ -27,11 +25,8 @@ dependencies = [
 ]

 [project.optional-dependencies]
-dev = [
-    "black~=23.11.0",
-    "isort~=5.12.0",
-    "pre-commit~=3.5.0",
-]
+dev = ["black~=23.11.0", "isort~=5.12.0", "pre-commit~=3.5.0"]
+test = ["pytest~=7.4.3"]

 [project.urls]
 repository = "https://github.com/cg123/mergekit"
@@ -53,3 +48,11 @@ profile = "black"
 line-length = 88
 target-version = ['py37']
 include = '\.pyi?$'
+
+
+[tool.pytest.ini_options]
+minversion = "6.0"
+filterwarnings = [
+    "ignore::pydantic.PydanticDeprecatedSince20:huggingface_hub.*:",
+]
+testpaths = ["tests"]
208 changes: 208 additions & 0 deletions tests/test_graph.py
@@ -0,0 +1,208 @@
from typing import Any, Dict, Optional

import networkx
import pytest

from mergekit.common import ImmutableMap
from mergekit.graph import Executor, Task

EXECUTION_COUNTS: Dict[Task, int] = {}


class DummyTask(Task):
    result: Any
    dependencies: ImmutableMap[str, Task]
    name: str = "DummyTask"
    grouplabel: Optional[str] = None
    execution_count: int = 0

    def arguments(self):
        return self.dependencies

    def group_label(self) -> Optional[str]:
        return self.grouplabel

    def execute(self, **kwargs):
        EXECUTION_COUNTS[self] = EXECUTION_COUNTS.get(self, 0) + 1
        return self.result


def create_mock_task(name, result=None, dependencies=None, group_label=None):
    if dependencies is None:
        dependencies = {}
    return DummyTask(
        result=result,
        dependencies=ImmutableMap(data=dependencies),
        name=name,
        grouplabel=group_label,
    )


# Test cases for the Task implementation
class TestTaskClass:
    def test_task_execute(self):
        # Testing the execute method
        task = create_mock_task("task1", result=42)
        assert task.execute() == 42, "Task execution did not return expected result"

    def test_task_priority(self):
        task = create_mock_task("task1")
        assert task.priority() == 0, "Default priority should be 0"

    def test_task_group_label(self):
        task = create_mock_task("task1")
        assert task.group_label() is None, "Default group label should be None"


# Test cases for the Executor implementation
class TestExecutorClass:
    def test_executor_initialization(self):
        # Testing initialization with single task
        task = create_mock_task("task1")
        executor = Executor([task])
        assert executor.targets == [
            task
        ], "Executor did not initialize with correct targets"

    def test_executor_empty_list(self):
        list(Executor([]).run())

    def test_executor_scheduling(self):
        # Testing scheduling with dependencies
        task1 = create_mock_task("task1", result=1)
        task2 = create_mock_task("task2", result=2, dependencies={"task1": task1})
        executor = Executor([task2])
        assert (
            len(executor._make_schedule([task2])) == 2
        ), "Schedule should include two tasks"

    def test_executor_dependency_building(self):
        # Testing dependency building
        task1 = create_mock_task("task1")
        task2 = create_mock_task("task2", dependencies={"task1": task1})
        executor = Executor([task2])
        dependencies = executor._build_dependencies([task2])
        assert task1 in dependencies[task2], "Task1 should be a dependency of Task2"

    def test_executor_run(self):
        # Testing execution through the run method
        task1 = create_mock_task("task1", result=10)
        task2 = create_mock_task("task2", result=20, dependencies={"task1": task1})
        executor = Executor([task2])
        results = list(executor.run())
        assert (
            len(results) == 1 and results[0][1] == 20
        ), "Executor run did not yield correct results"

    def test_executor_execute(self):
        # Testing execute method for side effects
        task1 = create_mock_task("task1", result=10)
        executor = Executor([task1])
        # No assert needed; we're ensuring no exceptions are raised and method completes
        executor.execute()

    def test_dependency_ordering(self):
        # Testing the order of task execution respects dependencies
        task1 = create_mock_task("task1", result=1)
        task2 = create_mock_task("task2", result=2, dependencies={"task1": task1})
        task3 = create_mock_task("task3", result=3, dependencies={"task2": task2})
        executor = Executor([task3])

        schedule = executor._make_schedule([task3])
        assert schedule.index(task1) < schedule.index(
            task2
        ), "Task1 should be scheduled before Task2"
        assert schedule.index(task2) < schedule.index(
            task3
        ), "Task2 should be scheduled before Task3"


class TestExecutorGroupLabel:
    def test_group_label_scheduling(self):
        # Create tasks with group labels and dependencies
        task1 = create_mock_task("task1", group_label="group1")
        task2 = create_mock_task(
            "task2", dependencies={"task1": task1}, group_label="group1"
        )
        task3 = create_mock_task("task3", group_label="group2")
        task4 = create_mock_task(
            "task4", dependencies={"task2": task2, "task3": task3}, group_label="group1"
        )

        # Initialize Executor with the tasks
        executor = Executor([task4])

        # Get the scheduled tasks
        schedule = executor._make_schedule([task4])

        # Check if tasks with the same group label are scheduled consecutively when possible
        group_labels_in_order = [
            task.group_label() for task in schedule if task.group_label()
        ]
        assert group_labels_in_order == [
            "group1",
            "group1",
            "group2",
            "group1",
        ], "Tasks with same group label are not scheduled consecutively"

    def test_group_label_with_dependencies(self):
        # Creating tasks with dependencies and group labels
        task1 = create_mock_task("task1", result=1, group_label="group1")
        task2 = create_mock_task(
            "task2", result=2, dependencies={"task1": task1}, group_label="group2"
        )
        task3 = create_mock_task(
            "task3", result=3, dependencies={"task2": task2}, group_label="group1"
        )

        executor = Executor([task3])
        schedule = executor._make_schedule([task3])
        scheduled_labels = [
            task.group_label() for task in schedule if task.group_label()
        ]

        # Check if task3 is scheduled after task1 and task2 due to dependency, even though it has the same group label as task1
        group1_indices = [
            i for i, label in enumerate(scheduled_labels) if label == "group1"
        ]
        group2_index = scheduled_labels.index("group2")

        assert (
            group1_indices[-1] > group2_index
        ), "Task with the same group label but later dependency was not scheduled after different group label"


class TestExecutorSingleExecution:
    def test_single_execution_per_task(self):
        EXECUTION_COUNTS.clear()

        shared_task = create_mock_task("shared_task", result=100)
        task1 = create_mock_task("task1", dependencies={"shared": shared_task})
        task2 = create_mock_task("task2", dependencies={"shared": shared_task})
        task3 = create_mock_task("task3", dependencies={"task1": task1, "task2": task2})

        Executor([task3]).execute()

        assert shared_task in EXECUTION_COUNTS, "Dependency not executed"
        assert (
            EXECUTION_COUNTS[shared_task] == 1
        ), "Shared dependency should be executed exactly once"


class CircularTask(Task):
    def arguments(self) -> Dict[str, Task]:
        return {"its_a_me": self}

    def execute(self, **_kwargs) -> Any:
        assert False, "Task with circular dependency executed"


class TestExecutorCircularDependency:
    def test_circular_dependency(self):
        with pytest.raises(networkx.NetworkXUnfeasible):
            Executor([CircularTask()]).execute()


if __name__ == "__main__":
    pytest.main()
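For context on the circular-dependency test above: networkx raises NetworkXUnfeasible when asked to topologically order a graph that contains a cycle, which is exactly what a task listing itself in its own arguments() produces (assuming Executor builds its schedule from a networkx topological sort). A standalone illustration of that networkx behavior, independent of mergekit:

import networkx as nx

graph = nx.DiGraph()
graph.add_edge("its_a_me", "its_a_me")  # a self-dependency is a one-node cycle

try:
    list(nx.topological_sort(graph))  # generator; forcing it triggers the cycle check
except nx.NetworkXUnfeasible as err:
    print("no valid schedule:", err)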
51 changes: 51 additions & 0 deletions tests/test_sparsify.py
@@ -0,0 +1,51 @@
import pytest
import torch

from mergekit.sparsify import SparsificationMethod, sparsify


@pytest.fixture
def sample_tensor():
    res = torch.randn(128, 64)
    res[res == 0] = 7  # very low chance, but hey!
    return res


class TestMagnitude:
    def test_full_density(self, sample_tensor):
        assert torch.equal(
            sparsify(sample_tensor, density=1, method=SparsificationMethod.magnitude),
            sample_tensor,
        )

    def test_zero_density(self, sample_tensor):
        with pytest.raises(AssertionError):
            sparsify(sample_tensor, density=0, method=SparsificationMethod.magnitude)

    def test_partial_density(self, sample_tensor):
        result = sparsify(
            sample_tensor, density=0.5, method=SparsificationMethod.magnitude
        )
        assert torch.count_nonzero(result) == sample_tensor.view(-1).shape[0] // 2


class TestBernoulli:
    NUM_ITERATIONS = 1000

    def test_bernoulli_with_rescale(self, sample_tensor):
        ref_abs_sum = sample_tensor.abs().sum()
        avg_abs_sum = torch.zeros_like(ref_abs_sum)
        for _ in range(TestBernoulli.NUM_ITERATIONS):
            rescaled = sparsify(
                sample_tensor, density=0.5, method=SparsificationMethod.rescaled_random
            )
            avg_abs_sum += rescaled.abs().sum()
        avg_abs_sum /= TestBernoulli.NUM_ITERATIONS

        assert torch.isclose(avg_abs_sum, ref_abs_sum, rtol=0.01)

    def test_bernoulli_without_rescale(self, sample_tensor):
        result = sparsify(
            sample_tensor, density=0.5, method=SparsificationMethod.random
        )
        assert 0 < torch.count_nonzero(result) <= sample_tensor.view(-1).shape[0]
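The rescale test works because dividing the surviving elements by the keep probability makes Bernoulli masking unbiased: for a mask m ~ Bernoulli(p), E[m * w / p] = w, so the average absolute sum over many draws should converge to the original tensor's. Below is a sketch of that statistic, assuming rescaled_random behaves like plain Bernoulli masking followed by rescaling; it illustrates the property being tested, not mergekit's implementation.

import torch


def bernoulli_rescale(tensor: torch.Tensor, density: float) -> torch.Tensor:
    # Keep each element with probability `density`, then rescale so the
    # expected value of every element is unchanged.
    mask = torch.bernoulli(torch.full_like(tensor, density))
    return tensor * mask / density


weights = torch.randn(128, 64)
draws = torch.stack(
    [bernoulli_rescale(weights, 0.5).abs().sum() for _ in range(1000)]
)
print(weights.abs().sum(), draws.mean())  # close, matching the rtol=0.01 assertion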
