diff --git a/.gitignore b/.gitignore index 68bc17f9..4c24f09f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,15 @@ +*.pkl +*.h5 +offload_folder/ + +# Environment + +mergekit/bin/ +mergekit/share/ +mergekit/etc/ +merged/ +.vscode/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/examples/metrics-llama-1v2.yml b/examples/metrics-llama-1v2.yml new file mode 100644 index 00000000..9cca1ae8 --- /dev/null +++ b/examples/metrics-llama-1v2.yml @@ -0,0 +1,8 @@ +models: + - model: huggyllama/llama-7b + - model: TheBloke/Llama-2-7B-fp16 + +metric_method: all +parameters: + intra_model_metrics: true +dtype: float32 diff --git a/examples/metrics-llama-codev2.yml b/examples/metrics-llama-codev2.yml new file mode 100644 index 00000000..13413ce3 --- /dev/null +++ b/examples/metrics-llama-codev2.yml @@ -0,0 +1,6 @@ +models: + - model: meta-llama/CodeLlama-7b-Python-hf + - model: meta-llama/Llama-2-7b-hf + +metric_method: all +dtype: float32 diff --git a/examples/metrics-small.yml b/examples/metrics-small.yml new file mode 100644 index 00000000..85e3b5ef --- /dev/null +++ b/examples/metrics-small.yml @@ -0,0 +1,9 @@ +models: + - model: BEE-spoke-data/smol_llama-220M-GQA + - model: BEE-spoke-data/smol_llama-220M-openhermes + +metric_method: all +parameters: + intra_model_metrics: true + inter_model_metrics: true +dtype: float32 diff --git a/mergekit/_data/architectures/llama.json b/mergekit/_data/architectures/llama.json index c418f055..afae52eb 100644 --- a/mergekit/_data/architectures/llama.json +++ b/mergekit/_data/architectures/llama.json @@ -21,21 +21,27 @@ { "name": "model.layers.${layer_index}.self_attn.q_proj.weight", "input_space": "h_${layer_index}", + "num_attention_heads": "${num_attention_heads}", "output_space": "attn_qk_${layer_index}" }, { "name": "model.layers.${layer_index}.self_attn.k_proj.weight", "input_space": "h_${layer_index}", - "output_space": "attn_qk_${layer_index}" + "output_space": "attn_qk_${layer_index}", + "num_attention_heads": "${num_attention_heads}", + "num_key_value_heads": "${num_key_value_heads}" }, { "name": "model.layers.${layer_index}.self_attn.v_proj.weight", "input_space": "h_${layer_index}", - "output_space": "attn_v_${layer_index}" + "output_space": "attn_v_${layer_index}", + "num_attention_heads": "${num_attention_heads}", + "num_key_value_heads": "${num_key_value_heads}" }, { "name": "model.layers.${layer_index}.self_attn.o_proj.weight", "input_space": "attn_v_${layer_index}", + "num_attention_heads": "${num_attention_heads}", "output_space": "post_attn_${layer_index}" }, { diff --git a/mergekit/_data/models_and_datasets.json b/mergekit/_data/models_and_datasets.json new file mode 100644 index 00000000..5c49773e --- /dev/null +++ b/mergekit/_data/models_and_datasets.json @@ -0,0 +1,12 @@ +{ + "model_names": [ + "Qwen/Qwen2-7B-Instruct", + "arcee-ai/qwen2-7b-math-tess", + "arcee-ai/qwen2-mmamo-2", + "arcee-ai/Qwen2-Merger" + ], + "dataset_names": [ + "microsoft/orca-math-word-problems-200k", + "MuskumPillerum/General-Knowledge" + ] +} \ No newline at end of file diff --git a/mergekit/_data/models_and_datasets.py b/mergekit/_data/models_and_datasets.py new file mode 100644 index 00000000..6ebfda65 --- /dev/null +++ b/mergekit/_data/models_and_datasets.py @@ -0,0 +1,45 @@ +import json +from pathlib import Path + +def save_model_and_dataset(model_name, dataset_name): + models_and_datasets = presets().load() + if model_name not in models_and_datasets['model_names']: + models_and_datasets['model_names'].append(model_name) + if 
dataset_name not in models_and_datasets['dataset_names']:
+        models_and_datasets['dataset_names'].append(dataset_name)
+    presets().save(models_and_datasets['model_names'], models_and_datasets['dataset_names'])
+
+def model_and_dataset_to_index(model_name, dataset_name):
+    models_and_datasets = presets().load()
+    model_index = models_and_datasets['model_names'].index(model_name) if model_name in models_and_datasets['model_names'] else []
+    dataset_index = models_and_datasets['dataset_names'].index(dataset_name) if dataset_name in models_and_datasets['dataset_names'] else []
+
+    return model_index, dataset_index
+
+def index_to_model_and_dataset(model_index, dataset_index):
+    models_and_datasets = presets().load()
+    model_name = models_and_datasets['model_names'][model_index] if len(models_and_datasets['model_names']) > model_index else []
+    dataset_name = models_and_datasets['dataset_names'][dataset_index] if len(models_and_datasets['dataset_names']) > dataset_index else []
+    return model_name, dataset_name
+
+class presets():
+    def __init__(self):
+        self.FILE_PATH = Path(__file__).parent / 'models_and_datasets.json'
+
+    def load(self):
+        """Load the lists from the JSON file."""
+        if self.FILE_PATH.exists():
+            with open(self.FILE_PATH, 'r') as file:
+                data = json.load(file)
+            return data
+        print(f"File {self.FILE_PATH} does not exist.")
+        # Fall back to empty lists so callers can still index by key
+        return {'model_names': [], 'dataset_names': []}
+
+    def save(self, model_names, dataset_names):
+        """Save the lists to the JSON file."""
+        data = {
+            'model_names': model_names,
+            'dataset_names': dataset_names
+        }
+        with open(self.FILE_PATH, 'w') as file:
+            json.dump(data, file, indent=4)
diff --git a/mergekit/architecture.py b/mergekit/architecture.py
index 653f1ac3..72100af3 100644
--- a/mergekit/architecture.py
+++ b/mergekit/architecture.py
@@ -43,6 +43,10 @@ class WeightInfo(BaseModel, frozen=True):
         List of alternative names for the weight, if applicable.
     force_dtype (Optional[str]):
         Mandatory dtype for the weight, if applicable.
+    num_key_value_heads (Union[int, str, None]):
+        Number of key-value heads for the weight, relevant for GQA; may be a template string before substitution.
+    num_attention_heads (Union[int, str, None]):
+        Number of attention heads for the weight; may be a template string before substitution.
     """

     name: str
@@ -53,6 +57,9 @@ class WeightInfo(BaseModel, frozen=True):
     aliases: Optional[Tuple[str, ...]] = None
     force_dtype: Optional[str] = None

+    num_key_value_heads: Union[int, str, None] = None
+    num_attention_heads: Union[int, str, None] = None
+

 class ProceduralSpaceInfo(BaseModel, frozen=True):
     """Defines a procedural space computed from one or more other spaces.
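
To make the new head-count fields concrete, a small illustrative sketch (model name reused from the examples above; `num_key_value_heads` may be absent on older configs, hence the guarded lookup):

    from transformers import AutoConfig
    from mergekit.architecture import WeightInfo

    cfg = AutoConfig.from_pretrained("TheBloke/Llama-2-7B-fp16")
    wi = WeightInfo(
        name="model.layers.0.self_attn.k_proj.weight",
        num_attention_heads=cfg.num_attention_heads,
        num_key_value_heads=getattr(cfg, "num_key_value_heads", None),
    )
    print(wi.num_attention_heads, wi.num_key_value_heads)
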
@@ -172,30 +179,29 @@ class JSONArchitectureDefinition(BaseModel, frozen=True):
 class TemplateWithArithmetic(string.Template):
     idpattern = r"(?a:[_a-z][_a-z0-9]*([+-]1)?)"

-
-def _template_substitution(
-    template: str, num_layers: int, layer_idx: Optional[int] = None
+def _generic_template_substitution(
+    template: str, match_str: str = "num_layers", replace_with: int = 0
 ) -> str:
-    if "{" not in template:
+    if match_str not in template:
         return template

     substitutions = {
-        "num_layers": num_layers,
-        "num_layers+1": num_layers + 1,
-        "num_layers-1": num_layers - 1,
+        match_str: replace_with,
+        f"{match_str}+1": replace_with + 1,
+        f"{match_str}-1": replace_with - 1,
     }

-    if layer_idx is not None:
-        substitutions.update(
-            {
-                "layer_index": layer_idx,
-                "layer_index+1": layer_idx + 1,
-                "layer_index-1": layer_idx - 1,
-            }
-        )
-
     return TemplateWithArithmetic(template).substitute(substitutions)

+
+def _template_substitution(
+    template: str, substitute: Dict[str, Optional[int]]
+) -> str:
+    for key, value in substitute.items():
+        if value is not None:
+            template = _generic_template_substitution(template, key, value)
+
+    return template
+

 class JsonArchitectureInfo(ArchitectureInfo, BaseModel, frozen=True):
     definition: JSONArchitectureDefinition
@@ -206,18 +212,23 @@ def _substitute(
         config: PretrainedConfig,
         layer_idx: Optional[int] = None,
     ) -> Union[WeightInfo, ProceduralSpaceInfo]:
-        num_layers = self.num_layers(config)
-
-        obj_dict = item.model_dump(mode="json", exclude_unset=True)
+        substitute = {
+            "num_layers": self.num_layers(config),
+            "layer_index": layer_idx,
+            "num_attention_heads": getattr(config, "num_attention_heads", None),
+            "num_key_value_heads": getattr(config, "num_key_value_heads", None),
+        }
+
+        obj_dict = item.model_dump(mode="json", exclude_unset=False)
         for key in obj_dict:
             if isinstance(obj_dict[key], str):
                 obj_dict[key] = _template_substitution(
-                    obj_dict[key], num_layers, layer_idx
+                    obj_dict[key], substitute
                 )
             elif isinstance(obj_dict[key], list):
                 obj_dict[key] = [
                     (
-                        _template_substitution(s, num_layers, layer_idx)
+                        _template_substitution(s, substitute)
                         if isinstance(s, str)
                         else s
                     )
diff --git a/mergekit/config.py b/mergekit/config.py
index 5c79de7c..a391476b 100644
--- a/mergekit/config.py
+++ b/mergekit/config.py
@@ -83,7 +83,8 @@ class OutputSliceDefinition(BaseModel):


 class MergeConfiguration(BaseModel):
-    merge_method: str
+    merge_method: Optional[str] = None
+    metric_method: Optional[str] = None
     slices: Optional[List[OutputSliceDefinition]] = None
     models: Optional[List[InputModelDefinition]] = None
     parameters: Optional[Dict[str, ParameterSetting]] = None
@@ -114,6 +115,14 @@ def validate_inputs(self):
         if ((not self.slices) and (not self.models)) or (self.slices and self.models):
             raise RuntimeError("Must specify either output slices or models to merge")
         return self
+
+    @model_validator(mode="after")
+    def validate_methods(self):
+        if not self.merge_method and not self.metric_method:
+            raise ValueError("Either 'merge_method' or 'metric_method' must be provided.")
+        if self.merge_method and self.metric_method:
+            raise ValueError("Only one of 'merge_method' or 'metric_method' can be provided, not both.")
+        return self

     @model_validator(mode="after")
     def validate_tokenizer(self):
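
How the new validators behave in practice, as a sketch (pydantic v2 assumed; model names are placeholders):

    from mergekit.config import MergeConfiguration

    base = {"models": [{"model": "model-a"}, {"model": "model-b"}], "dtype": "float32"}

    MergeConfiguration.model_validate({**base, "metric_method": "all"})    # valid
    MergeConfiguration.model_validate({**base, "merge_method": "linear"})  # valid
    # Supplying both (or neither) raises a ValueError from validate_methods:
    # MergeConfiguration.model_validate({**base, "merge_method": "linear", "metric_method": "all"})
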
diff --git a/mergekit/graph.py b/mergekit/graph.py
index c81cb85b..fea69e37 100644
--- a/mergekit/graph.py
+++ b/mergekit/graph.py
@@ -37,6 +37,7 @@ class Task(ABC, BaseModel, Generic[ValueT], frozen=True):
     Abstract base class representing a task in a computational graph.
     This class should be extended to define specific tasks. Each task can have arguments (dependencies) and
     a defined execution strategy.
+    Note: as a frozen Pydantic BaseModel, all attributes must be defined at initialisation and cannot be changed afterwards.

     Attributes:
         Generic[ValueT] (TypeVar): The type of the value that the task returns upon execution.
@@ -106,7 +107,6 @@ def uses_accelerator(self) -> bool:
         """
         return False

-
 class Executor:
     """
     Schedules and executes a set of tasks and their dependencies.
@@ -241,13 +241,20 @@ def _make_schedule(self, targets: List[Task]) -> List[Task]:
             # they will be included in the final schedule
             edge_tups.append((Executor.DUMMY_TASK_VALUE, task))

+        def _pad_numbers(s):
+            # Zero-pad numeric components (e.g. layer indices) so lexicographic ordering matches numeric order
+            parts = s.split('.')
+            for i, part in enumerate(parts):
+                if part.isdigit():
+                    parts[i] = part.zfill(3)
+            return '.'.join(parts)
+
         def _compare_key(task: Union[Task, str]):
             if task == Executor.DUMMY_TASK_VALUE:
                 return ("", 0)
-            return (
-                task.group_label() or "",
-                -task.priority(),
-            )
+            group_label = task.group_label() or ""
+            padded_label = _pad_numbers(group_label)
+            priority = -task.priority()
+            return (padded_label, priority)

         graph = networkx.DiGraph(edge_tups)
         res = [
diff --git a/mergekit/merge.py b/mergekit/merge.py
index 60189f44..a94e55e5 100644
--- a/mergekit/merge.py
+++ b/mergekit/merge.py
@@ -92,11 +92,17 @@ def run_merge(
         storage_device="cuda" if options.low_cpu_memory else "cpu",
     )

+    metrics_out = []
     tokenizer = None
     for _task, value in exec.run(quiet=options.quiet):
+        if merge_config.metric_method is not None:
+            metrics_out.append((_task, value))
         if isinstance(value, TokenizerInfo):
             tokenizer = value.tokenizer

+    if metrics_out:
+        return metrics_out
+
     if tokenizer:
         _update_config_vocab(cfg_out, tokenizer)

diff --git a/mergekit/metric_methods/__init__.py b/mergekit/metric_methods/__init__.py
new file mode 100644
index 00000000..6027c8b6
--- /dev/null
+++ b/mergekit/metric_methods/__init__.py
@@ -0,0 +1,29 @@
+# Copyright (C) 2024 Charles O. Goddard
+#
+# This software is free software: you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This software is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see http://www.gnu.org/licenses/.
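
Usage of this module amounts to one lookup; a sketch (only "all" is registered in this PR):

    from mergekit import metric_methods

    method = metric_methods.get("all")   # returns an AllMetric instance
    # metric_methods.get("cka")          # raises RuntimeError: Unimplemented metric method cka
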
+
+from mergekit.metric_methods.base import MetricMethod
+from mergekit.metric_methods.all_metrics import AllMetric
+
+
+def get(method: str) -> MetricMethod:
+    if method == "all":
+        return AllMetric()
+    raise RuntimeError(f"Unimplemented metric method {method}")
+
+
+__all__ = [
+    "MetricMethod",
+    "get",
+]
diff --git a/mergekit/metric_methods/aggregator_metrics.py b/mergekit/metric_methods/aggregator_metrics.py
new file mode 100644
index 00000000..fb317406
--- /dev/null
+++ b/mergekit/metric_methods/aggregator_metrics.py
@@ -0,0 +1,386 @@
+import enum
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from sklearn.manifold import TSNE
+
+from mergekit.metric_methods.base import MeanStd, Histogram, Metric, ScatterPlot
+from mergekit.metric_methods.metrics import compute_histogram
+
+class ModelAnalysisType(enum.Enum):
+    INDIVIDUAL = "individual"
+    COMPARISON = "comparison"
+
+class LayerComparisonType(enum.Enum):
+    SINGLE = "single"                # Analyse layer i in isolation
+    BLOCK = "block"                  # Compare layer i with layer i+(block size) in the same model
+    CORRESPONDING = "corresponding"  # Compare layer i in model 1 with layer i in model 2
+    ALL = "all_layers"               # Compare layer i in model 1 with layer j in model 1 or 2
+
+class MetricAggregator():
+    def __init__(self, device: str = "cpu"):
+        self.device = device
+        self.valid_for = {
+            LayerComparisonType.SINGLE.value: False,
+            LayerComparisonType.BLOCK.value: False,
+            LayerComparisonType.CORRESPONDING.value: False,
+            LayerComparisonType.ALL.value: False
+        }
+
+    def process_batch(self, batch_a: torch.Tensor, batch_b: Optional[torch.Tensor]) -> None:
+        raise NotImplementedError
+
+    def aggregate(self) -> Metric:
+        raise NotImplementedError
+
+class Cosine_Similarity(MetricAggregator):
+    def __init__(self, device: str = "cpu"):
+        super().__init__(device=device)
+        self.cosine_similarities = torch.tensor([], device=self.device)
+        self.valid_for.update({
+            LayerComparisonType.BLOCK.value: True,
+            LayerComparisonType.CORRESPONDING.value: True,
+            LayerComparisonType.ALL.value: True
+        })
+
+    def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None:
+        batch_similarities = F.cosine_similarity(batch_a, batch_b, dim=1)
+        self.cosine_similarities = torch.cat((self.cosine_similarities, batch_similarities))
+
+    def aggregate(self) -> Metric:
+        hist = compute_histogram(self.cosine_similarities, 100)
+        mean_std = MeanStd(
+            mean=self.cosine_similarities.mean().item(),
+            std=self.cosine_similarities.std().item()
+        )
+        histogram = Histogram(
+            count=hist[0],
+            edges=hist[1],
+            widths=hist[2]
+        )
+        self.__init__(self.device)  # Reset ready for next layer
+        return Metric(
+            histogram=histogram,
+            mean_std=mean_std
+        )
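
The aggregators in this file all follow the same drive loop; a minimal sketch with random stand-in activations:

    import torch
    from mergekit.metric_methods.aggregator_metrics import Cosine_Similarity

    agg = Cosine_Similarity(device="cpu")
    for _ in range(10):  # stand-in for batches of paired representations
        agg.process_batch(torch.randn(32, 64), torch.randn(32, 64))
    metric = agg.aggregate()  # also resets the aggregator for the next layer
    print(metric.mean_std.mean, metric.mean_std.std)
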
+class MSE(MetricAggregator):
+    def __init__(self, device: str = "cpu"):
+        super().__init__(device=device)
+        self.square_errors = torch.tensor([], device=self.device)
+        self.valid_for.update({
+            LayerComparisonType.BLOCK.value: True,
+            LayerComparisonType.CORRESPONDING.value: True,
+        })
+
+    def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None:
+        batch_square_errors = torch.square(batch_a - batch_b).flatten()
+        self.square_errors = torch.cat((self.square_errors, batch_square_errors))
+
+    def aggregate(self) -> Metric:
+        hist = compute_histogram(self.square_errors, 100)
+        mean_std = MeanStd(
+            mean=self.square_errors.mean().item(),
+            std=self.square_errors.std().item()
+        )
+        histogram = Histogram(
+            count=hist[0],
+            edges=hist[1],
+            widths=hist[2]
+        )
+        self.__init__(self.device)  # Reset ready for next layer
+        return Metric(
+            histogram=histogram,
+            mean_std=mean_std
+        )
+
+class Linearity_Score(MetricAggregator):
+    def __init__(self, device: str = "cpu"):
+        super().__init__(device=device)
+        self.iterations = 0
+        self.max_iterations = 50
+        self.A = None
+        self.optimiser = None
+        self.initialised = False
+        self.done = False
+        self.losses = []
+
+        self.absolute_square_sum = 0
+        self.num_elements = 0
+
+        self.valid_for.update({
+            LayerComparisonType.BLOCK.value: True
+        })
+
+    def _first_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None:
+        batch_size, dimension = batch_a.size()
+
+        self.A = torch.empty(dimension, dimension, device=self.device)
+        torch.nn.init.normal_(self.A)
+        self.A = torch.nn.Parameter(self.A)
+
+        self.optimiser = torch.optim.SGD([self.A], lr=0.0001)
+        self.initialised = True
+
+    def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None:
+        batch_a = batch_a / torch.norm(batch_a, dim=1, keepdim=True)
+        batch_b = batch_b / torch.norm(batch_b, dim=1, keepdim=True)  # Check dimensionality (X)
+        if not self.initialised:
+            self._first_batch(batch_a, batch_b)
+        if self.done:  # Stop training A and evaluate
+            residuals = batch_a @ self.A - batch_b
+            self.absolute_square_sum += torch.abs(residuals).sum().item()
+            self.num_elements += residuals.numel()
+        else:
+            self.optimiser.zero_grad()  # Clear gradients accumulated by the previous step
+            loss = torch.norm(batch_a @ self.A - batch_b) ** 2
+            loss.backward()
+            self.losses.append(loss.item())
+            self.optimiser.step()
+
+            self.iterations += 1
+
+            if self.iterations >= self.max_iterations:
+                self.done = True
+
+    def aggregate(self) -> Metric:
+        linearity_score = 1 - self.absolute_square_sum / self.num_elements
+        self.__init__(self.device)  # Reset ready for next layer
+        return Metric(mean_std=MeanStd(mean=linearity_score))
+
+class _CKA(object):
+    def __init__(self):
+        self.kernel_functions = {
+            'inner_product': self.inner_product,
+            'rbf': self.rbf
+        }
+
+    def inner_product(self, X):
+        return X @ X.T
+
+    def rbf(self, X, sigma=None):
+        GX = torch.mm(X, X.t())
+        diag_GX = torch.diag(GX).unsqueeze(1)
+        KX = diag_GX - GX + (diag_GX - GX).t()
+
+        if sigma is None:
+            mdist = torch.median(KX[KX != 0])
+            sigma = torch.sqrt(mdist)
+
+        KX *= -0.5 / (sigma * sigma)
+        KX = torch.exp(KX)
+
+        return KX
+
+    def centering(self, K):
+        n = K.shape[0]
+        unit = torch.ones([n, n]).to(K.device)
+        I = torch.eye(n).to(K.device)
+        H = I - unit / n
+        return H @ K @ H
+
+    def hsic(self, K_x, K_y):
+        """
+        Hilbert-Schmidt Independence Criterion
+        Input: K_x, K_y: *Centered* kernel matrices
+
+        Returns: HSIC(K_x, K_y)
+        """
+        return torch.trace(K_x.T @ K_y) / ((K_x.shape[0] - 1) ** 2)
+    def cka(self, X, Y, kernel_function='inner_product'):
+        K_x = self.kernel_functions[kernel_function](X)
+        K_y = self.kernel_functions[kernel_function](Y)
+
+        K_x = self.centering(K_x)
+        K_y = self.centering(K_y)
+
+        hsic_xy = self.hsic(K_x, K_y)
+        hsic_xx = self.hsic(K_x, K_x)
+        hsic_yy = self.hsic(K_y, K_y)
+
+        return hsic_xy / torch.sqrt(hsic_xx * hsic_yy)
+
+    def align(self, X, Y, knn_x, knn_y):
+        """
+        Input: X, Y: Centered kernel matrices
+        """
+        assert X.shape == Y.shape
+        num_rows, num_cols = X.shape
+        rows, cols = torch.meshgrid(torch.arange(num_rows), torch.arange(num_cols), indexing='ij')
+
+        # Check if each element in the meshgrid is a mutual nearest neighbor
+        mutual_nn_mask = torch.isin(rows.to(knn_x[0].device), knn_x.indices[cols]) & \
+                         torch.isin(cols.to(knn_x[0].device), knn_y.indices[rows])
+
+        return torch.sum(mutual_nn_mask * X * Y)
+
+    def cknna(self, X, Y, kernel_function='inner_product', k=5):
+        K_x = self.kernel_functions[kernel_function](X)
+        K_y = self.kernel_functions[kernel_function](Y)
+
+        K_x = self.centering(K_x)
+        K_y = self.centering(K_y)
+
+        k_nearest_neighbors_x = torch.topk(K_x, k=k, dim=1, largest=True, sorted=False)
+        k_nearest_neighbors_y = torch.topk(K_y, k=k, dim=1, largest=True, sorted=False)
+
+        align_xy = self.align(K_x, K_y, k_nearest_neighbors_x, k_nearest_neighbors_y)
+        align_xx = self.align(K_x, K_x, k_nearest_neighbors_x, k_nearest_neighbors_x)
+        align_yy = self.align(K_y, K_y, k_nearest_neighbors_y, k_nearest_neighbors_y)
+
+        return align_xy / torch.sqrt(align_xx * align_yy)
+
+class CKA(MetricAggregator):
+    def __init__(self, device: str = "cpu"):
+        super().__init__(device=device)
+        self.cka = _CKA()
+        self.batches_a = []
+        self.batches_b = []
+        self.stop = False
+        self.max_batches = 20
+
+        self.valid_for.update({
+            LayerComparisonType.BLOCK.value: True,
+            LayerComparisonType.CORRESPONDING.value: True,
+            LayerComparisonType.ALL.value: True
+        })
+
+    def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None:
+        if not self.stop:
+            self.batches_a.append(batch_a)
+            self.batches_b.append(batch_b)
+
+            if len(self.batches_a) >= self.max_batches:
+                self.stop = True
+
+    def aggregate(self) -> Metric:
+        result = self.cka.cka(torch.concat(self.batches_a),
+                              torch.concat(self.batches_b))
+
+        self.__init__(self.device)  # Reset ready for next layer
+        return Metric(mean_std=MeanStd(mean=result.cpu().item()))
+
+class CKNNA(MetricAggregator):
+    def __init__(self, device: str = "cpu"):
+        super().__init__(device=device)
+        self.cka = _CKA()
+        self.batches_a = []
+        self.batches_b = []
+        self.stop = False
+        self.max_batches = 20
+
+        self.valid_for.update({
+            LayerComparisonType.BLOCK.value: True,
+            LayerComparisonType.CORRESPONDING.value: True,
+            LayerComparisonType.ALL.value: True
+        })
+
+    def process_batch(self, batch_a: torch.Tensor, batch_b: torch.Tensor) -> None:
+        if not self.stop:
+            self.batches_a.append(batch_a)
+            self.batches_b.append(batch_b)
+
+            if len(self.batches_a) >= self.max_batches:
+                self.stop = True
+
+    def aggregate(self) -> Metric:
+        result = self.cka.cknna(torch.concat(self.batches_a),
+                                torch.concat(self.batches_b))
+
+        self.__init__(self.device)  # Reset ready for next layer
+        return Metric(mean_std=MeanStd(mean=result.cpu().item()))
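
A quick sanity check for the kernel path above (a sketch; shapes arbitrary): CKA of a matrix with itself is 1 by construction, and with the inner-product kernel it is invariant to orthogonal transforms.

    import torch
    from mergekit.metric_methods.aggregator_metrics import _CKA

    X = torch.randn(64, 16)
    cka = _CKA()
    print(cka.cka(X, X))        # tensor(1.)
    Q, _ = torch.linalg.qr(torch.randn(16, 16))
    print(cka.cka(X, X @ Q))    # also ~1: (XQ)(XQ)^T == X X^T
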
+class t_SNE(MetricAggregator):
+    def __init__(self, device: str = "cpu"):
+        super().__init__(device=device)
+        self.tsne = TSNE(n_components=2, random_state=42)
+        self.batches = []
+        self.max_batches = 20
+        self.stop = False
+
+        self.valid_for.update({
+            LayerComparisonType.SINGLE.value: True,
+        })
+
+    def process_batch(self, batch: torch.Tensor) -> None:
+        if not self.stop:
+            self.batches.append(batch.cpu().numpy())
+
+            if len(self.batches) >= self.max_batches:
+                self.stop = True
+
+    def aggregate(self) -> Metric:
+        data = np.concatenate(self.batches)
+        self.result = self.tsne.fit_transform(data)
+
+        metric = Metric(
+            scatter_plot=ScatterPlot(
+                x=self.result[:, 0],
+                y=self.result[:, 1],
+            )
+        )
+        self.__init__(self.device)  # Reset ready for next layer
+        return metric
+
+class PCA_Projection(MetricAggregator):
+    def __init__(self, device: str = "cpu"):
+        super().__init__(device=device)
+        self.batches = []
+        self.max_batches = 20
+        self.stop = False
+
+        self.valid_for.update({
+            LayerComparisonType.SINGLE.value: True,
+        })
+
+    def process_batch(self, batch: torch.Tensor) -> None:
+        if not self.stop:
+            self.batches.append(batch)
+
+            if len(self.batches) >= self.max_batches:
+                self.stop = True
+
+    def aggregate(self) -> Metric:
+        data = torch.cat(self.batches, dim=0)
+        mean = torch.mean(data, dim=0)
+        data -= mean
+        U, S, V = torch.pca_lowrank(data, q=2)
+        result = torch.matmul(data, V[:, :2])
+        result = result.cpu().numpy()
+        metric = Metric(
+            scatter_plot=ScatterPlot(
+                x=result[:, 0],
+                y=result[:, 1],
+            )
+        )
+        self.__init__(self.device)  # Reset ready for next layer
+        return metric
+
+METRICS_TABLE = {
+    'cosine_similarity': Cosine_Similarity,
+    'mse': MSE,
+    'linearity_score': Linearity_Score,
+    'cka': CKA,
+    'cknna': CKNNA,
+    't-sne': t_SNE,
+    'pca_projection': PCA_Projection
+}
\ No newline at end of file
diff --git a/mergekit/metric_methods/all_metrics.py b/mergekit/metric_methods/all_metrics.py
new file mode 100644
index 00000000..0782dde4
--- /dev/null
+++ b/mergekit/metric_methods/all_metrics.py
@@ -0,0 +1,327 @@
+# Copyright (C) 2024 Charles O. Goddard
+#
+# This software is free software: you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This software is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see http://www.gnu.org/licenses/.
+
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+from mergekit.architecture import WeightInfo
+from mergekit.common import ModelReference
+from mergekit.graph import Task
+from mergekit.io.tasks import GatherTensors
+from mergekit.metric_methods.base import MetricMethod, Layer
+from mergekit.metric_methods.metrics import cosine_similarity, smape, scale, mse, weight_magnitude, numerical_rank
+
+def validate_tensors(tensors: List[torch.Tensor], weight_info: WeightInfo, expected_tensors: Optional[int] = 2):
+    """Validate tensor shapes and count."""
+    unique_shapes = set(t.shape for t in tensors)
+    if len(unique_shapes) != 1:
+        raise RuntimeError(f"Tensor size mismatch for {weight_info.name}, sizes: {list(unique_shapes)}")
+    if expected_tensors:
+        if len(tensors) != expected_tensors:
+            raise RuntimeError(f"Expected {expected_tensors} tensors, got {len(tensors)}")
+
+def ungroup_tensor(input_tensor: torch.Tensor, gqa_groups: int) -> torch.Tensor:
+    """
+    Ungroup a grouped (GQA) tensor by repeating its rows.
+    """
+    # Repeat each KV row gqa_groups times so grouped rows line up one-to-one with query heads
+    return torch.repeat_interleave(input_tensor, gqa_groups, dim=0).to(input_tensor.device)
+
+def restructure_tensor(input_tensor: torch.Tensor, num_rows: int) -> torch.Tensor:
+    """
+    Restructure a tensor by splitting its rows and permuting the dimensions.
+
+    This is used so that the attention weights can be grouped by head in the first dimension.
+ """ + rows, cols = input_tensor.shape + new_rows = rows // num_rows + reshaped_tensor = input_tensor.view(new_rows, num_rows, cols) + restructured_tensor = reshaped_tensor.permute(1, 0, 2) + + return restructured_tensor + +def group_attn_head_weights(k_proj: torch.Tensor, + q_proj: torch.Tensor, + v_proj: torch.Tensor, + o_proj: torch.Tensor, + weight_info: WeightInfo) -> tuple[torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor]: + + num_heads = weight_info.num_attention_heads + assert num_heads is not None, "Number of attention heads is not defined" + + if getattr(weight_info, 'num_key_value_heads', None) and getattr(weight_info, 'num_key_value_heads', None) != 0: + gqa_groups = num_heads // weight_info.num_key_value_heads + + k_proj = ungroup_tensor(k_proj, gqa_groups) + v_proj = ungroup_tensor(v_proj, gqa_groups) + + k_proj = restructure_tensor(k_proj, num_heads) + v_proj = restructure_tensor(v_proj, num_heads) + q_proj = restructure_tensor(q_proj, num_heads) + o_proj = restructure_tensor(o_proj.T, num_heads) # Output weights are split into heads by rows, not columns + + return k_proj, v_proj, q_proj, o_proj + +# Tasks + +class MLPTask(Task[torch.Tensor]): + gather_tensors: GatherTensors + weight_info: WeightInfo + intra_model_metrics: bool = False + + def uses_accelerator(self) -> bool: + return True + + def arguments(self) -> Dict[str, Task]: + return {"tensors": self.gather_tensors} + + def execute( + self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs + ) -> torch.Tensor: + weights = list(tensors.values()) + layer_results = Layer(metrics={}, + weight_info=self.weight_info) + + if self.intra_model_metrics: + validate_tensors(weights, self.weight_info, expected_tensors=1) + layer_results.add_metric(weight_magnitude(weights[0]), name='weight_magnitude') + layer_results.add_metric(numerical_rank(weights[0]), name='numerical_rank') + else: + validate_tensors(weights, self.weight_info, expected_tensors=2) + layer_results.add_metric(cosine_similarity(weights, return_heatmap=False), name = 'cosine_similarity') + layer_results.add_metric(smape(weights), name = 'smape') + layer_results.add_metric(scale(weights, return_heatmap=False), name = 'scale') + layer_results.add_metric(mse(weights, return_heatmap=False), name = 'mse') + + + return layer_results + + def group_label(self) -> Optional[str]: + return self.gather_tensors.group_label() + +class AttnTask(Task[torch.Tensor]): + weights: Dict[str, GatherTensors] + weight_infos: Dict[str, WeightInfo] + weight_info: WeightInfo + intra_model_metrics: bool = False + + def uses_accelerator(self) -> bool: + return True + + def arguments(self) -> Dict[str, Task]: + + return self.weights + + def execute( + self, k_proj: torch.Tensor, v_proj: torch.Tensor, q_proj: torch.Tensor, o_proj: torch.Tensor, **_kwargs + ) -> torch.Tensor: + # Add metrics for attention weights + + model_references = list(q_proj.keys()) + + k_proj_0, v_proj_0, q_proj_0, o_proj_0 = group_attn_head_weights(k_proj[model_references[0]], + q_proj[model_references[0]], + v_proj[model_references[0]], + o_proj[model_references[0]], + self.weight_info) + model_0_heads = torch.cat([k_proj_0, v_proj_0, q_proj_0, o_proj_0], dim=1) + layer_results = Layer(metrics={}, + weight_info=self.weight_info) + + + if self.intra_model_metrics: + + layer_results.add_metric( + metric=weight_magnitude(model_0_heads), + name='weight_magnitude' + ) + else: + + k_proj_1, v_proj_1, q_proj_1, o_proj_1 = group_attn_head_weights(k_proj[model_references[1]], + q_proj[model_references[1]], + 
+class AttnTask(Task[torch.Tensor]):
+    weights: Dict[str, GatherTensors]
+    weight_infos: Dict[str, WeightInfo]
+    weight_info: WeightInfo
+    intra_model_metrics: bool = False
+
+    def uses_accelerator(self) -> bool:
+        return True
+
+    def arguments(self) -> Dict[str, Task]:
+        return self.weights
+
+    def execute(
+        self, k_proj: torch.Tensor, v_proj: torch.Tensor, q_proj: torch.Tensor, o_proj: torch.Tensor, **_kwargs
+    ) -> torch.Tensor:
+        # Add metrics for attention weights
+        model_references = list(q_proj.keys())
+
+        k_proj_0, v_proj_0, q_proj_0, o_proj_0 = group_attn_head_weights(
+            k_proj[model_references[0]],
+            q_proj[model_references[0]],
+            v_proj[model_references[0]],
+            o_proj[model_references[0]],
+            self.weight_info,
+        )
+        model_0_heads = torch.cat([k_proj_0, v_proj_0, q_proj_0, o_proj_0], dim=1)
+        layer_results = Layer(metrics={}, weight_info=self.weight_info)
+
+        if self.intra_model_metrics:
+            layer_results.add_metric(
+                metric=weight_magnitude(model_0_heads),
+                name='weight_magnitude'
+            )
+        else:
+            k_proj_1, v_proj_1, q_proj_1, o_proj_1 = group_attn_head_weights(
+                k_proj[model_references[1]],
+                q_proj[model_references[1]],
+                v_proj[model_references[1]],
+                o_proj[model_references[1]],
+                self.weight_info,
+            )
+            model_1_heads = torch.cat([k_proj_1, v_proj_1, q_proj_1, o_proj_1], dim=1)
+
+            # Flatten each head to a single row before comparing the two models head-by-head
+            heads_0 = model_0_heads.view(model_0_heads.shape[0], -1)
+            heads_1 = model_1_heads.view(model_1_heads.shape[0], -1)
+
+            layer_results.add_metric(cosine_similarity([heads_0, heads_1], return_heatmap=True), name='cosine_similarity')
+            layer_results.add_metric(smape([heads_0, heads_1]), name='smape')
+            layer_results.add_metric(scale([heads_0, heads_1], return_heatmap=True), name='scale')
+            layer_results.add_metric(mse([heads_0, heads_1], return_heatmap=False), name='mse')
+
+        return layer_results
+
+    def group_label(self) -> Optional[str]:
+        return max([gather_tensor.group_label() for gather_tensor in list(self.weights.values())])
+
+    def __hash__(self):
+        return hash(self.weight_info)
+
+    def __eq__(self, other):
+        if not isinstance(other, AttnTask):
+            return False
+        return self.weight_info == other.weight_info
+
+class LayerNormTask(Task[torch.Tensor]):
+    gather_tensors: GatherTensors
+    weight_info: WeightInfo
+    intra_model_metrics: bool = False
+
+    def uses_accelerator(self) -> bool:
+        return True
+
+    def arguments(self) -> Dict[str, Task]:
+        return {"tensors": self.gather_tensors}
+
+    def execute(
+        self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs
+    ) -> torch.Tensor:
+        tensors = list(tensors.values())
+
+        assert tensors[0].dim() == 1, "LayerNorm tensors must be 1D"
+
+        layer_results = Layer(metrics={}, weight_info=self.weight_info)
+
+        if self.intra_model_metrics:
+            layer_results.add_metric(
+                metric=weight_magnitude(tensors[0].unsqueeze(1)),
+                name='weight_magnitude'
+            )
+        else:
+            assert tensors[1].dim() == 1, "LayerNorm tensors must be 1D"
+            layer_results.add_metric(cosine_similarity([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)], return_heatmap=True), name='cosine_similarity')
+            layer_results.add_metric(smape([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)]), name='smape')
+            layer_results.add_metric(scale([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)], return_heatmap=True), name='scale')
+            layer_results.add_metric(mse([tensors[0].unsqueeze(1), tensors[1].unsqueeze(1)], return_heatmap=True), name='mse')
+
+        return layer_results
+
+    def group_label(self) -> Optional[str]:
+        return self.gather_tensors.group_label()
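
The name-based routing used by AllMetric.make_task below can be checked in isolation; a sketch with a hypothetical helper mirroring that dispatch:

    def route(weight_name: str) -> str:
        # Hypothetical mirror of the dispatch in AllMetric.make_task, for illustration only
        if 'self_attn' in weight_name:
            return 'AttnTask (buffered until q/k/v/o are all seen)'
        if 'mlp' in weight_name:
            return 'MLPTask'
        if 'layernorm' in weight_name:
            return 'LayerNormTask'
        return 'DummyTask'

    print(route('model.layers.0.self_attn.q_proj.weight'))
    print(route('model.layers.0.mlp.up_proj.weight'))
    print(route('model.layers.0.input_layernorm.weight'))
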
+class DummyTask(Task[torch.Tensor]):
+    gather_tensors: GatherTensors
+    weight_info: WeightInfo
+
+    def uses_accelerator(self) -> bool:
+        return True
+
+    def arguments(self) -> Dict[str, Task]:
+        return {"tensors": self.gather_tensors}
+
+    def execute(self, **_kwargs) -> torch.Tensor:
+        return
+
+    def group_label(self) -> Optional[str]:
+        return self.gather_tensors.group_label()
+
+
+from mergekit.merge_methods.base import ConfigParameterDef
+
+# Metric method
+class AllMetric(MetricMethod):
+    def __init__(self) -> None:
+        super().__init__()
+        self.attn_weight_dict: Optional[Dict[str, torch.Tensor]] = {}
+        self.attn_info_dict: Optional[Dict[str, WeightInfo]] = {}
+        self.attn_parts: Optional[List[str]] = ['k_proj', 'v_proj', 'q_proj', 'o_proj']  # hard-coded for now
+        self.block_count: Optional[int] = 0
+
+    def make_task(
+        self,
+        *,
+        output_weight: WeightInfo,
+        parameters: Optional[Dict[str, Any]] = None,
+        tensors: GatherTensors,
+        **_kwargs,
+    ) -> Task:
+        if 'self_attn' in output_weight.name:
+            return self.group_attn_heads(tensors, output_weight, parameters)
+        elif 'mlp' in output_weight.name:
+            return MLPTask(
+                gather_tensors=tensors,
+                weight_info=output_weight,
+                intra_model_metrics=parameters['intra_model_metrics']
+            )
+        elif 'layernorm' in output_weight.name:
+            return LayerNormTask(
+                gather_tensors=tensors,
+                weight_info=output_weight,
+                intra_model_metrics=parameters['intra_model_metrics']
+            )
+        else:
+            # Executor expects a task to be returned
+            return DummyTask(gather_tensors=tensors, weight_info=output_weight)
+
+    def group_attn_heads(self, tensors: GatherTensors, output_weight: WeightInfo, parameters: Dict[str, Any]):
+        # Collect all attention weights
+        for part in self.attn_parts:  # also check only one key
+            if part in output_weight.name:
+                assert self.attn_weight_dict.get(part) is None, f"Duplicate attention part {part}"
+                self.attn_weight_dict[part] = tensors
+                self.attn_info_dict[part] = output_weight
+
+        # If all attention weights are collected, create the attention task
+        if set(self.attn_weight_dict.keys()) == set(self.attn_parts):
+            weights, infos = self.attn_weight_dict, self.attn_info_dict
+            self.attn_weight_dict, self.attn_info_dict = {}, {}
+            weight_info = WeightInfo(
+                name=f"Attention Block {self.block_count}",
+                force_dtype=None,
+                optional=False,
+                aliases=None,
+                num_key_value_heads=int(infos['k_proj'].num_key_value_heads),
+                num_attention_heads=int(infos['k_proj'].num_attention_heads)
+            )
+            self.block_count += 1
+            return AttnTask(weights=weights, weight_infos=infos, weight_info=weight_info, intra_model_metrics=parameters['intra_model_metrics'])
+        else:
+            # Executor expects a task to be returned
+            return DummyTask(gather_tensors=tensors, weight_info=output_weight)
+
+    def parameters(self) -> List[ConfigParameterDef]:
+        return [
+            ConfigParameterDef(name="intra_model_metrics", required=False, default_value=False),
+        ]
diff --git a/mergekit/metric_methods/base.py b/mergekit/metric_methods/base.py
new file mode 100644
index 00000000..cdc655d8
--- /dev/null
+++ b/mergekit/metric_methods/base.py
@@ -0,0 +1,265 @@
+# Copyright (C) 2024 Charles O. Goddard
+#
+# This software is free software: you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This software is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see http://www.gnu.org/licenses/.
+
+from typing import Any, List, Optional, Dict
+
+from mergekit.architecture import WeightInfo
+from mergekit.common import ModelReference
+
+from mergekit.merge_methods.base import MergeMethod
+from dataclasses import dataclass, field
+import torch
+from pathlib import Path
+import pickle
+
+class MetricMethod(MergeMethod):
+    pass
+
+# Structure of the results object:
+#
+# Results
+# ├── model_path: Optional[List[str]]   # One for an individual model, two for a comparison
+# └── layers: Dict[str, Layer]
+#     └── Layer
+#         ├── weight_info: WeightInfo (remove?)
+#         └── metrics: Dict[str, Metric]
+#             └── Metric
+#                 ├── histogram: Optional[Histogram]
+#                 ├── mean_std: Optional[MeanStd]
+#                 ├── scatter_plot: Optional[ScatterPlot]
+#                 └── heatmap: Optional[Heatmap]
+from enum import Enum
+
+class PlotType(Enum):
+    HISTOGRAM = 'histogram'
+    MEAN_STD = 'mean_std'
+    SCATTER_PLOT = 'scatter_plot'
+    HEATMAP = 'heatmap'
+
+@dataclass
+class MeanStd:
+    mean: float
+    std: Optional[float] = None
+
+@dataclass
+class Heatmap:
+    data: torch.Tensor
+    plot_details: Optional[Dict] = None
+
+@dataclass
+class Histogram:
+    count: List[float]
+    edges: List[float]
+    widths: List[float]
+
+@dataclass
+class ScatterPlot:
+    x: List[float]
+    y: List[float]
+
+@dataclass
+class Metric:
+    histogram: Optional[Histogram] = None
+    mean_std: Optional[MeanStd] = None
+    heatmap: Optional[Heatmap] = None
+    scatter_plot: Optional[ScatterPlot] = None
+
+    def filled_attributes(self) -> List[str]:
+        filled_attrs = []
+        for attr, value in self.__dict__.items():
+            if value is not None:
+                filled_attrs.append(attr)
+        return filled_attrs
+
+@dataclass
+class Layer:
+    weight_info: WeightInfo
+    metrics: Dict[str, Metric] = field(default_factory=dict)
+
+    def metrics_with_attribute(self, attribute: str) -> List[str]:
+        return [name for name, metric in self.metrics.items() if attribute in metric.filled_attributes()]
+
+    def add_metric(self, metric: Metric, name: str):
+        if name not in self.metrics.keys():
+            self.metrics[name] = metric
+        else:
+            raise ValueError(f"Metric with name {name} already exists in layer {self.weight_info.name}.")
+
+def expand_to_fit(all_layer_names: List[str], values: List[float], subset_layer_names: List[str]) -> List[float]:
+    """
+    Expands a list of values to fit a larger list of layer names, filling in missing values with None.
+
+    Args:
+        all_layer_names (List[str]): List of all layer names.
+        values (List[float]): List of values to expand.
+        subset_layer_names (List[str]): List of layer names that the values correspond to.
+
+    Returns:
+        List[float]: Expanded list of values, with None values for missing layers.
+    """
+    result = [None] * len(all_layer_names)
+    subset_dict = dict(zip(subset_layer_names, values))
+
+    for i, layer in enumerate(all_layer_names):
+        if layer in subset_dict:
+            result[i] = subset_dict[layer]
+
+    return result
+
+from typing import List, Tuple
+from mergekit.graph import Task
+from mergekit._data.models_and_datasets import save_model_and_dataset, model_and_dataset_to_index, index_to_model_and_dataset
+
+class Results:
+    # Class to store the statistics for each layer
+    def __init__(self):
+        self.layers: Dict[str, Layer] = {}
+        self.across_layer_metrics: Dict[str, Metric] = {}
+        # self.model_paths: Optional[List[str]] = None
+        self.representations_details: Optional[List[Tuple]] = []  # List of (model_name, dataset_name) tuples
+
+    def add_layer(self, layer: Layer, name: str):
+        if name not in self.layers.keys():
+            self.layers[name] = layer
+
+    def load_metrics(self, metrics: List[Tuple[Task, Layer]], model_paths: Optional[List[str]] = None):  # Maybe remove (#)
+        self.model_paths = model_paths
+        for task, metric in metrics:
+            if metric is not None:
+                self.add_layer(metric, name=task.weight_info.name)
+        return self
+
+    def load_representations_details_from_path(self, representations_path: str):
+        representations_path = Path(representations_path)
+        if representations_path.exists() and representations_path.is_file():
+            file_name = representations_path.name
+            assert file_name.endswith('.h5'), f"File {file_name} is not an HDF5 file."
+ assert len(file_name.split('_')) == 4, f"File {file_name} does not follow the naming convention: '(model_num)_(dataset_num)_id_(unique_id).h5'" + + model_idx, dataset_idx = int(file_name.split('_')[0]), int(file_name.split('_')[1]) + model_name, dataset_name = index_to_model_and_dataset(model_idx, dataset_idx) + assert model_and_dataset_to_index(model_name, dataset_name) != ([], []), f"Model and dataset {model_name, dataset_name} not found in presets." + + self.representations_details.append((model_name, dataset_name)) + + def get_lineplot_data(self, metric_name: str): + means, stds = [],[] + + available_line_plots = self.available_plot_types(PlotType.MEAN_STD.value) + if metric_name not in available_line_plots: + return [], [] + + layers_with_data = available_line_plots[metric_name] + means = [self.layers[layer].metrics[metric_name].mean_std.mean for layer in layers_with_data] + stds = [self.layers[layer].metrics[metric_name].mean_std.std for layer in layers_with_data] + + means = expand_to_fit(all_layer_names=list(self.layers.keys()), values=means, subset_layer_names=layers_with_data) + stds = expand_to_fit(all_layer_names=list(self.layers.keys()), values=stds, subset_layer_names=layers_with_data) + + return means, stds + + def available_metrics(self) -> Dict[str, Dict[str, Any]]: + all_metrics = set() + for layer in self.layers.values(): + all_metrics.update(layer.metrics.keys()) + + metric_info = {} + for metric in all_metrics: + info = { + 'layers': [], + PlotType.MEAN_STD.value: False, + PlotType.HISTOGRAM.value: False, + PlotType.HEATMAP.value: False, + PlotType.SCATTER_PLOT.value: False + } + for layer_name, layer in self.layers.items(): + if metric in layer.metrics: + info['layers'].append(layer_name) + m = layer.metrics[metric] + if m.mean_std: + info[PlotType.MEAN_STD.value] = True + if m.histogram: + info[PlotType.HISTOGRAM.value] = True + if m.heatmap: + info[PlotType.HEATMAP.value] = True + if m.scatter_plot: + info[PlotType.SCATTER_PLOT.value] = True + metric_info[metric] = info + return metric_info + + def available_plot_types(self, plot_type: str) -> Dict[str, List[str]]: + # Returns dictionary with key metric_name and value: list of layers for which that metric has data + metric_info = self.available_metrics() + out = {} + plot_type = 'mean_std' if plot_type == 'line_plot' else plot_type + assert plot_type in [p.value for p in PlotType], f"Plot type {plot_type} is not valid. 
Must be one of {[p.value for p in PlotType]}" + for metric_name, info in metric_info.items(): + if info[plot_type]: + out[metric_name] = info['layers'] + return out + + def available_metrics_at_layer(self, layer_name: str) -> List[str]: + if layer_name in self.layers: + return list(self.layers[layer_name].metrics.keys()) + else: + return [] + + def print_metric_summary(self): + metric_info = self.available_metrics() + print("Available Metrics Summary:") + for metric, info in metric_info.items(): + print(f"\nMetric: {metric}") + print(f" Has mean/std: {'Yes' if info[PlotType.MEAN_STD.value] else 'No'}") + print(f" Has histogram: {'Yes' if info[PlotType.HISTOGRAM.value] else 'No'}") + print(f" Has heatmap: {'Yes' if info[PlotType.HEATMAP.value] else 'No'}") + print(f" Has scatter plot: {'Yes' if info[PlotType.SCATTER_PLOT.value] else 'No'}") + + def finalise(self): + self.layer_names = list(self.layers.keys()) + self.metric_names = list(set([metric for layer in self.layers.values() for metric in layer.metrics.keys()])) + + def save(self, out_dir: str, suffix: Optional[str] = None): + out_dir = Path(out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + file_name = 'details_' + for i, (model_name, dataset_name) in enumerate(self.representations_details): + if i > 0: + file_name += '_and_' + m_idx, d_idx = model_and_dataset_to_index(model_name, dataset_name) + file_name += f"model_{m_idx}_dataset_{d_idx}" + + if suffix: + file_name += f"_{suffix}" + + path = out_dir / f"{file_name}.pkl" + + if not path.suffix or path.suffix != '.pkl': + path = path.with_suffix('.pkl') + + with path.open('wb') as f: + pickle.dump(self, f) + + def load(self, path: str): + path_obj = Path(path).resolve() + if path_obj.exists() and path_obj.is_file(): + with open(path_obj, 'rb') as f: + results = pickle.load(f) + assert isinstance(results, Results), "Loaded object is not a Results object" + return results + else: + raise FileNotFoundError(f"The path {path} does not exist or is not a file.") \ No newline at end of file diff --git a/mergekit/metric_methods/metrics.py b/mergekit/metric_methods/metrics.py new file mode 100644 index 00000000..8dc3618b --- /dev/null +++ b/mergekit/metric_methods/metrics.py @@ -0,0 +1,193 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. 
+
+from typing import List
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from mergekit.metric_methods.base import MeanStd, Heatmap, Histogram, Metric
+
+# Helper functions
+
+def compute_histogram(tensor: torch.Tensor, n_bins: int) -> List[np.ndarray]:
+    bin_counts, bin_edges = np.histogram(tensor.cpu().numpy(), bins=n_bins)
+    bin_widths = np.diff(bin_edges)
+    return bin_counts, bin_edges, bin_widths
+
+def cosine_similarity_heatmap(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
+    # Normalise the rows of both matrices
+    A_norm = A / A.norm(dim=1, keepdim=True)
+    B_norm = B / B.norm(dim=1, keepdim=True)
+
+    # Compute the full pairwise cosine similarity matrix
+    similarity_matrix = torch.mm(A_norm, B_norm.t())
+
+    return similarity_matrix
+
+
+# Tensor Comparisons (require exactly 2 tensors)
+
+def smape(
+    tensors: List[torch.Tensor], **_kwargs
+) -> Metric:
+    """Symmetric Mean Absolute Percentage Error (SMAPE)."""
+    numerator = torch.abs(tensors[0] - tensors[1])
+    # Guard against 0/0 when both entries are zero
+    denominator = (torch.abs(tensors[0]) + torch.abs(tensors[1])).clamp_min(1e-10)
+    smape = torch.mean(torch.div(numerator, denominator), dim=1)
+
+    hist_info = compute_histogram(smape, 100)
+
+    return Metric(
+        histogram=Histogram(count=hist_info[0], edges=hist_info[1], widths=hist_info[2]),
+        mean_std=MeanStd(mean=smape.mean().item(), std=smape.std().item())
+    )
+
+def cosine_similarity(
+    tensors: List[torch.Tensor], return_heatmap=False, **_kwargs
+) -> Metric:
+    """Cosine similarity between corresponding rows of two tensors."""
+    cosine_similarity = F.cosine_similarity(tensors[0], tensors[1], dim=1)
+
+    # isclose is False for NaN compared with itself, so this flags NaNs
+    assert torch.isclose(cosine_similarity, cosine_similarity, atol=1e-6).all(), "NaNs in cosine similarity"
+
+    if return_heatmap:
+        heatmap = cosine_similarity_heatmap(tensors[0], tensors[1])
+        assert torch.isclose(cosine_similarity, heatmap.diagonal(), atol=1e-2).all(), "Diagonal elements of cosine similarity matrix do not match"
+
+    hist_info = compute_histogram(cosine_similarity, 100)
+    return Metric(
+        histogram=Histogram(count=hist_info[0], edges=hist_info[1], widths=hist_info[2]),
+        mean_std=MeanStd(mean=cosine_similarity.mean().item(), std=cosine_similarity.std().item()),
+        heatmap=Heatmap(data=heatmap) if return_heatmap else None
+    )
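
A toy check of the two comparisons above (hand-picked values; assumes the zero-denominator guard in smape):

    import torch
    from mergekit.metric_methods.metrics import smape, cosine_similarity

    a = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
    b = torch.tensor([[2.0, 0.0], [0.0, 2.0]])

    print(smape([a, b]).mean_std.mean)              # ~0.083: only one coordinate differs
    print(cosine_similarity([a, b]).mean_std.mean)  # 1.0: corresponding rows are parallel
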
+def scale(
+    tensors: List[torch.Tensor], return_heatmap=False, **_kwargs
+) -> Metric:
+    """
+    Scale difference: ratio of absolute difference to average scale.
+    Complementary to cosine similarity, which measures the angle between two vectors and is invariant to scale.
+
+    Values close to 0 indicate that the scales of the two vectors are similar.
+    """
+    norm_0 = torch.norm(tensors[0], dim=1)
+    norm_1 = torch.norm(tensors[1], dim=1)
+
+    scale_diff = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1) / 2)
+
+    if return_heatmap:
+        norm_0 = norm_0.unsqueeze(1)  # Shape becomes [num_heads, 1]
+        norm_1 = norm_1.unsqueeze(0)  # Shape becomes [1, num_heads]
+
+        # Compute the scale difference between each pair of heads by broadcasting
+        heatmap = torch.abs(norm_0 - norm_1) / ((norm_0 + norm_1 + 1e-10) / 2)
+
+        assert torch.isclose(scale_diff, heatmap.diagonal(), atol=1e-4).all(), "Diagonal elements of scale difference matrix do not match"
+
+    hist_info = compute_histogram(scale_diff, 100)
+
+    return Metric(
+        histogram=Histogram(count=hist_info[0], edges=hist_info[1], widths=hist_info[2]),
+        mean_std=MeanStd(mean=scale_diff.mean().item(), std=scale_diff.std().item()),
+        heatmap=Heatmap(data=heatmap) if return_heatmap else None
+    )
+
+def mse(
+    tensors: List[torch.Tensor], return_heatmap: bool = False, **_kwargs
+) -> Metric:
+    """Mean squared error (MSE)."""
+    if return_heatmap:
+        # Expand dimensions for broadcasting
+        tensors_0_exp = tensors[0].unsqueeze(1)  # Shape becomes [num_heads, 1, ...]
+        tensors_1_exp = tensors[1].unsqueeze(0)  # Shape becomes [1, num_heads, ...]
+
+        # Compute squared differences
+        diffs = (tensors_0_exp - tensors_1_exp) ** 2
+
+        # Compute mean over all dimensions except the first two
+        heatmap = diffs.mean(dim=tuple(range(2, diffs.dim()))).cpu().numpy()
+
+    squared_diff = (tensors[0] - tensors[1]) ** 2
+    mse_per_neuron = torch.mean(squared_diff, dim=1)
+
+    hist_info = compute_histogram(mse_per_neuron, 100)
+
+    return Metric(
+        histogram=Histogram(count=hist_info[0], edges=hist_info[1], widths=hist_info[2]),
+        mean_std=MeanStd(mean=mse_per_neuron.mean().item(), std=mse_per_neuron.std().item()),
+        heatmap=Heatmap(data=heatmap) if return_heatmap else None
+    )
+
+# Tensor Analysis (number of tensors can vary)
+
+def weight_magnitude(tensor: torch.Tensor) -> Metric:
+    weight_magnitudes = torch.abs(tensor.flatten())
+    hist_info = compute_histogram(weight_magnitudes, 100)
+    return Metric(
+        histogram=Histogram(count=hist_info[0],
+                            edges=hist_info[1],
+                            widths=hist_info[2]),
+        mean_std=MeanStd(mean=weight_magnitudes.mean().item(),
+                         std=weight_magnitudes.std().item()),
+    )
+
+def numerical_rank(tensor: torch.Tensor, epsilon: float = 1e-5) -> Metric:
+    """
+    Computes the numerical rank of the representations matrix X based on the singular values
+    of its sample covariance matrix. The rank is determined as the number of singular values
+    above a threshold, defined as the highest singular value times a given epsilon.
+
+    Parameters:
+    - tensor : torch.Tensor
+        The representations matrix from which the sample covariance matrix is computed.
+    - epsilon : float, optional
+        The factor multiplied with the highest singular value to set the threshold (default is 1e-5).
+
+    Returns:
+    - Metric
+        The numerical rank of the matrix, stored as mean_std.mean.
+ + Implemented according to description in the paper: + The Tunnel Effect: Building Data Representations in Deep Neural Networks + https://arxiv.org/pdf/2305.19753.pdf + + """ + + # Center the data by subtracting the mean + X_centered = tensor - torch.mean(tensor, dim=0) + X_std = torch.std(X_centered, dim=0, unbiased=False) + X_centered /= X_std + + # Compute the sample covariance matrix + covariance_matrix = X_centered.t() @ X_centered / (tensor.shape[0] - 1) + # Compute singular values using SVD on the covariance matrix + U, singular_values, V = torch.svd(covariance_matrix.cpu()) + # Determine the threshold + threshold = singular_values[0] * epsilon + # Count singular values greater than the threshold + num_rank = torch.sum(singular_values > threshold).item() + + value = int(num_rank) + + return Metric( + mean_std=MeanStd( + mean=value, + std=None), + ) + diff --git a/mergekit/plan.py b/mergekit/plan.py index bdcd7004..80177225 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -17,7 +17,7 @@ from functools import lru_cache from typing import Any, List, Optional, Tuple -from mergekit import merge_methods +from mergekit import merge_methods, metric_methods from mergekit.architecture import ( ArchitectureInfo, ConfiguredArchitectureInfo, @@ -40,6 +40,7 @@ TensorWriterTask, ) from mergekit.merge_methods import MergeMethod +from mergekit.metric_methods import MetricMethod from mergekit.options import MergeOptions from mergekit.tokenizer import BuildTokenizer, PermutedEmbeddings @@ -65,7 +66,10 @@ def __init__( self.arch_info = arch_info self.options = options self.out_model_config = out_model_config - self._method = merge_methods.get(config.merge_method) + if getattr(config, "merge_method", None): + self._method = merge_methods.get(config.merge_method) + elif getattr(config, "metric_method", None): + self._method = metric_methods.get(config.metric_method) token_cfg = {} tokenizer_source = config.tokenizer_source @@ -254,10 +258,24 @@ def plan_slice(self, definition: OutputSliceDefinition): cfg_reader=cfg_reader, ) + def metrics_plan_to_disk(self) -> List[Task]: + """Plan the metrics to be streamed to disk, returning a list of tasks.""" + save_tasks = [] + for weight, tensor_task in self._tensors: + save_tasks.append( + tensor_task + ) + + return save_tasks + def plan_to_disk(self, out_path: str) -> List[Task]: """Plan the merge to be streamed to disk, returning a list of tasks.""" self._plan() + if self.config.metric_method: + return self.metrics_plan_to_disk() + + writer_task = TensorWriterTask( out_path=out_path, max_shard_size=self.options.out_shard_size, @@ -287,6 +305,10 @@ def plan_to_disk(self, out_path: str) -> List[Task]: def plan_in_memory(self) -> List[ReturnTensor]: """Plan the merge to be performed in memory.""" self._plan() + + if self.config.metric_method: + return self.metrics_plan_to_disk() + return [ ReturnTensor( weight_info=w, diff --git a/mergekit/plot_tools/plot_tools.py b/mergekit/plot_tools/plot_tools.py new file mode 100644 index 00000000..6fb3d204 --- /dev/null +++ b/mergekit/plot_tools/plot_tools.py @@ -0,0 +1,393 @@ +import numpy as np +from typing import List, Dict, Optional, Any, Tuple +from mergekit.graph import Task +import networkx as nx +import plotly.graph_objects as go +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors +import dash +from dash import dcc, html +from dash.dependencies import Input, Output, State +from mergekit.metric_methods.all_metrics import Layer +from mergekit.metric_methods.base import Results, PlotType +from 
mergekit.common import ModelReference
+from plotly.subplots import make_subplots
+
+global_colours_list = ['blue', 'red', 'green', 'purple', 'orange', 'pink']
+global_shapes_list = ['circle', 'square', 'diamond', 'cross', 'x', 'triangle-up', 'triangle-down', 'pentagon', 'hexagon', 'star']
+
+class ResultsHandler:
+    def __init__(self):
+        self.intra_model_results: Dict[ModelReference, Results] = {}
+        self.inter_model_results: Results = None
+        self.available_layer_plots = {
+            'mean_std': [],
+            'histogram': [],
+            'heatmap': [],
+            'scatter_plot': []
+        }
+
+    def load_results(self, results: Results):  # Generalise to handle both representations and model weights
+        results.finalise()
+        if len(results.representations_details) == 2:
+            self.inter_model_results = results
+        elif len(results.representations_details) == 1:
+            key = len(self.intra_model_results)
+            self.intra_model_results[key] = results
+        else:
+            raise ValueError(f"Results should have either 1 or 2 inputs, got {len(results.representations_details)}")
+
+        for plot_type in self.available_layer_plots.keys():
+            if self.inter_model_results is not None:
+                self.available_layer_plots[plot_type] += list(self.inter_model_results.available_plot_types(plot_type).keys())
+            if self.intra_model_results:
+                for model_id, results in self.intra_model_results.items():
+                    self.available_layer_plots[plot_type] += list(results.available_plot_types(plot_type).keys())
+
+            self.available_layer_plots[plot_type] = list(set(self.available_layer_plots[plot_type]))
+
+        # Exclude inter-model results until they have been loaded
+        self.all_results = list(self.intra_model_results.values()) + ([self.inter_model_results] if self.inter_model_results is not None else [])
+
+    def categorise_layers(self, layer_names) -> List[str]:
+        # Hardcoded layer names for now - can be extended to include more categories or further generalised based on config
+        categories = []
+        for name in layer_names:
+            if 'Attention Block' in name:
+                categories.append('Attention Block')
+            elif 'mlp' in name:
+                categories.append('MLP')
+            elif 'layernorm' in name:
+                categories.append('LayerNorm')
+            else:
+                categories.append('Other')
+        return categories
+
+    def plotly_line_plots(self, metric_name: str):
+        traces = []
+        layer_names = []
+        if self.inter_model_results is not None and metric_name in self.inter_model_results.available_plot_types('line_plot'):  # bring if case into loop? (X)
+            layer_names = self.inter_model_results.layer_names
+            means, stds = self.inter_model_results.get_lineplot_data(metric_name)
+            categorised_layers = self.categorise_layers(layer_names)  # Different category for each layer type
+            unique_categories = list(set(categorised_layers))
+            traces = self._plotly_line_plot(layer_names, means, stds, categorised_layers, unique_categories)
+        else:
+            unique_categories = list(self.intra_model_results.keys())
+            for i, (model_id, results) in enumerate(self.intra_model_results.items()):
+                layer_names = results.layer_names
+                means, stds = results.get_lineplot_data(metric_name)
+                if means:
+                    categorised_layers = [model_id] * len(layer_names)  # One category per model; every layer in a model shares it
+                    shape = global_shapes_list[i % len(global_shapes_list)]
+                    traces.extend(self._plotly_line_plot(layer_names, means, stds, categorised_layers, unique_categories, shape))
+
+        return traces, layer_names
+
+    def _plotly_line_plot(self, x_values, means, stds, categorised_layers, unique_categories, shape: str = 'circle', **kwargs):
+        """
+        Returns:
+            List[go.Scatter]: List of Plotly Scatter objects.
+        Returns:
+            List[go.Scatter]: List of Plotly Scatter objects.
+        """
+
+        # Assign a unique colour to each category
+        cmap = plt.get_cmap('Set1', len(unique_categories))
+        colors = [mcolors.to_hex(cmap(i)) for i in range(len(unique_categories))]
+
+        category_styles = {cat: colors[i % len(colors)] for i, cat in enumerate(unique_categories)}
+
+        traces = []
+
+        for category in unique_categories:
+            y_category = [means[i] if categorised_layers[i] == category else None for i in range(len(categorised_layers))]
+            std_category = [stds[i] if categorised_layers[i] == category else None for i in range(len(categorised_layers))]
+            if all(y is None for y in y_category):
+                continue
+
+            traces.append(go.Scatter(
+                x=x_values,
+                y=y_category,
+                error_y=dict(
+                    type='data',
+                    array=std_category,
+                    visible=True
+                ),
+                mode='markers',
+                name=str(category),
+                marker=dict(color=category_styles[category]),
+                marker_symbol=shape
+            ))
+        return traces
+
+    def plotly_layer_plot(self, layer_name: str, metric_name: str, plot_type: str):
+        assert plot_type in [p.value for p in PlotType], f"Plot type {plot_type} not in {[p.value for p in PlotType]}"
+        data = []
+
+        for result in self.all_results:
+            valid_metrics = result.available_plot_types(plot_type)
+            if metric_name in valid_metrics.keys():
+                if layer_name in valid_metrics[metric_name]:
+                    data.append(getattr(result.layers[layer_name].metrics[metric_name], plot_type))
+
+        return self.get_traces(data, plot_type)
+
+    def get_traces(self, data: List, plot_type: str):  # TODO: could infer the plot type from the type of the data
+        if plot_type == PlotType.HEATMAP.value:
+            traces = [go.Heatmap(
+                z=d.data,
+                colorscale='Bluered',
+            ) for d in data]
+        elif plot_type == PlotType.SCATTER_PLOT.value:
+            traces = [go.Scatter(
+                x=d.x,
+                y=d.y,
+                mode='markers'
+            ) for d in data]
+        elif plot_type == PlotType.HISTOGRAM.value:
+            traces = []
+            for i, d in enumerate(data):
+                count, edges, widths = d.count, d.edges, d.widths
+                traces.append(go.Bar(
+                    x=edges[:-1],
+                    y=count,
+                    width=widths,
+                    marker=dict(
+                        color=global_colours_list[i],
+                        opacity=0.75,
+                        line=dict(
+                            color='black',
+                            width=1
+                        )
+                    )))
+        else:
+            raise ValueError(f'{plot_type} not valid for layer-specific plot')
+
+        return traces
+
+    def layer_plot_options(self, layer_name: str):
+        metric_options = []
+        avoid_duplicates = []
+        for plot_type in PlotType:
+            if plot_type == PlotType.MEAN_STD:
+                continue
+            for result in self.all_results:
+                if layer_name in result.layers:
+                    metrics = result.layers[layer_name].metrics_with_attribute(plot_type.value)
+                    for metric in metrics:
+                        if (metric, plot_type) not in avoid_duplicates:
+                            metric_options.append({"label": f"{metric.title()} {plot_type.value}",
+                                                   "value": [metric, plot_type.value]
+                                                   })
+                            avoid_duplicates.append((metric, plot_type))
+        return metric_options
+
+def create_app(results_handler):
+    app = dash.Dash(__name__, external_stylesheets=['https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css'])
+
+    app.layout = html.Div([
+        create_header(),
+        create_line_plot_section(results_handler),
+        create_single_layer_section(),
+        create_across_layers_section(results_handler)
+    ])
+
+    register_callbacks(app, results_handler)
+
+    return app
+
+def create_header():
+    return html.H1('Network Weights Similarity Visualization',
+                   style={'textAlign': 'center', 'padding': '20px'})
+
+def create_line_plot_section(results_handler):
+    return html.Div([
+        dcc.Dropdown(
+            id='line-plot-dropdown',
+            options=[{'label': metric_name.replace('_', ' ').title(), 'value': metric_name}
+                     for metric_name in results_handler.available_layer_plots['mean_std']],
+            value='cosine_similarity',
+            style={'width': '50%', 'margin': 'auto', 'display': 'block', 'font-family': 'Arial'}
+        ),
+        dcc.Graph(id='line-plot', style={'width': '100%', 'height': '100vh'})
+    ], className='container-fluid')
+
+def create_single_layer_section():
+    return html.Div([
+        html.H3('Layer Metrics', style={'textAlign': 'center'}),
+        dcc.Dropdown(
+            id='metric-dropdown',
+            options=[],
+            style={'width': '50%', 'margin': 'auto', 'display': 'block', 'font-family': 'Arial'},
+            value=None
+        ),
+        dcc.Graph(id='layer-details-plot', style={'width': '100%', 'height': '80vh', 'textAlign': 'center'})
+    ], className='container-fluid')
+
+def create_across_layers_section(results_handler):
+    plot_sections = []
+
+    for result in results_handler.all_results:
+        if getattr(result, 'across_layer_metrics', None):
+            for metric_name, metric in result.across_layer_metrics.items():
+                for plot_type in ['histogram', 'heatmap', 'scatter_plot']:
+                    if getattr(metric, plot_type, None):
+                        plot_sections.append(html.Div([
+                            html.H3(f'{metric_name.replace("_", " ").title()} {plot_type.replace("_", " ").title()}', style={'textAlign': 'center'}),
+                            dcc.Graph(id=f"{plot_type}-plot-{metric_name}-{''.join(item for tup in result.representations_details for item in tup)}", style={'width': '50%', 'height': '50%', 'position': 'relative'})
+                        ], className='container-fluid'))
+
+    return html.Div(plot_sections)
+
+def default_option(options, current_value):
+    if not options:
+        return None
+    if current_value.lower() in [o.lower() for o in options]:
+        return current_value
+    for option in options:
+        if option.lower() in current_value.lower() or current_value.lower() in option.lower():
+            return option
+    return options[0]
+
+def register_callbacks(app, results_handler):
+    @app.callback(
+        Output('metric-dropdown', 'options'),
+        Output('metric-dropdown', 'value'),
+        Input('line-plot', 'clickData'),
+        Input('line-plot-dropdown', 'value')
+    )
+    def update_metric_dropdown_options(clickData, selected_metric):
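+        """Populate the layer-metric dropdown with the plots available for the layer clicked in the line plot."""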
+        if not clickData:
+            return [], None
+
+        try:
+            layer_name = clickData['points'][0]['x']
+            options = results_handler.layer_plot_options(layer_name)
+            default_label = default_option([option['label'] for option in options], selected_metric)
+            default_value = [option['value'] for option in options if option['label'] == default_label][0]
+            return options, default_value
+
+        except (KeyError, IndexError, AttributeError) as e:
+            print(f"Error processing clickData: {e}")
+            return [], None
+
+    @app.callback(
+        Output('layer-details-plot', 'figure'),
+        Input('metric-dropdown', 'value'),
+        State('line-plot', 'clickData')
+    )
+    def display_layer_data(selected_metric, clickData):
+        if not clickData:
+            return go.Figure()
+
+        try:
+            layer_name = clickData['points'][0]['x']
+            if not selected_metric:
+                selected_metric = results_handler.layer_plot_options(layer_name)[0]['value']
+
+            metric_name, plot_type = selected_metric
+
+            xaxis_title, yaxis_title = "", ""
+            if plot_type.lower() in ["heatmap", "scatter_plot"]:
+                xaxis_title = "Model 1"
+                yaxis_title = "Model 0"
+            elif plot_type.lower() == 'histogram':
+                xaxis_title = "Value"
+                yaxis_title = "Count"
+
+            traces = results_handler.plotly_layer_plot(layer_name, metric_name, plot_type)
+
+            return create_figure(traces=traces,
+                                 title=f"{plot_type.title()} for {layer_name} | {metric_name}",
+                                 xaxis_title=xaxis_title,
+                                 yaxis_title=yaxis_title,
+                                 plot_type=plot_type
+                                 )
+
+        except (KeyError, IndexError, AttributeError) as e:
+            print(f"Error processing layer data: {e}")
+            return go.Figure()
+
+    @app.callback(
+        Output('line-plot', 'figure'),
+        Input('line-plot-dropdown', 'value')
+    )
+    def update_line_plot(selected_metric):
+        if not selected_metric:
+            return go.Figure()
+
+        traces, layer_names = results_handler.plotly_line_plots(metric_name=selected_metric)
+        fig = go.Figure()
+        for trace in traces:
+            if trace:
+                fig.add_trace(trace)
+
+        fig.update_layout(
+            title=f"{selected_metric.replace('_', ' ').title()} Across Layers",
+            xaxis=dict(
+                title='Layer',
+                tickvals=list(range(len(layer_names))),
+                ticktext=layer_names
+            ),
+            yaxis=dict(title=selected_metric.replace('_', ' ').title())
+        )
+        return fig
+
+    for result in results_handler.all_results:
+        if getattr(result, 'across_layer_metrics', None):
+            for metric_name, metric in result.across_layer_metrics.items():
+                for plot_type in ['histogram', 'heatmap', 'scatter_plot']:
+                    if getattr(metric, plot_type, None):
+                        plot_id = f"{plot_type}-plot-{metric_name}-{''.join(item for tup in result.representations_details for item in tup)}"
+
+                        @app.callback(
+                            Output(plot_id, 'figure'),
+                            Input(plot_id, 'id')
+                        )
+                        def update_across_layers_plot(_, plot_id=plot_id, metric_name=metric_name, plot_type=plot_type, metric=metric, rep_details=result.representations_details):
+                            # Bind the loop variables as defaults so each registered callback keeps its own values
+                            traces = results_handler.get_traces(data=[getattr(metric, plot_type)], plot_type=plot_type)
+                            xaxis_title = f"Model {rep_details[0][0]} {rep_details[0][1]}"
+                            yaxis_title = f"Model {rep_details[1][0]} {rep_details[1][1]}" if len(rep_details) > 1 else xaxis_title
+                            if getattr(metric, 'plot_details', None):
+                                xaxis_title = metric.plot_details.get('xaxis_title', xaxis_title)
+                                yaxis_title = metric.plot_details.get('yaxis_title', yaxis_title)
+                            return create_figure(traces=traces,
+                                                 title=f"{plot_id} | {metric_name}",
+                                                 xaxis_title=xaxis_title,
+                                                 yaxis_title=yaxis_title,
+                                                 plot_type=plot_type
+                                                 )
+
+def create_figure(traces, xaxis_title, yaxis_title, plot_type, title=None, subplot_titles=None):
+    if plot_type in ["scatter_plot", "heatmap"]:
+        num_plots = len(traces)
+
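+        # Arrange multiple traces on a grid: two columns once there is more than one plot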
+        num_cols = 2 if num_plots > 1 else 1
+
+        num_rows = (num_plots + 1) // num_cols if num_plots > 1 else 1
+
+        fig = make_subplots(rows=num_rows,
+                            cols=num_cols,
+                            subplot_titles=subplot_titles
+                            if subplot_titles
+                            else [f"Plot {i+1}" for i in range(num_plots)])
+
+        for i, trace in enumerate(traces):
+            row = (i // num_cols) + 1
+            col = (i % num_cols) + 1
+            fig.add_trace(trace, row=row, col=col)
+            fig.update_xaxes(title_text=xaxis_title, row=row, col=col)
+            fig.update_yaxes(title_text=yaxis_title, row=row, col=col)
+    else:
+        fig = go.Figure()
+        for trace in traces:
+            fig.add_trace(trace)
+
+        fig.update_layout(
+            title=title,
+            xaxis=dict(title=xaxis_title),
+            yaxis=dict(title=yaxis_title)
+        )
+
+    return fig
diff --git a/mergekit/pyvenv.cfg b/mergekit/pyvenv.cfg
new file mode 100644
index 00000000..4c61b08f
--- /dev/null
+++ b/mergekit/pyvenv.cfg
@@ -0,0 +1,5 @@
+home = /Users/estein/anaconda3/bin
+include-system-site-packages = false
+version = 3.11.5
+executable = /Users/estein/anaconda3/bin/python3.11
+command = /Users/estein/anaconda3/bin/python -m venv /Users/estein/Documents/repos/mergekit/mergekit
diff --git a/mergekit/scripts/run_metrics.py b/mergekit/scripts/run_metrics.py
new file mode 100644
index 00000000..4529edf3
--- /dev/null
+++ b/mergekit/scripts/run_metrics.py
@@ -0,0 +1,81 @@
+import click
+import torch
+import yaml
+
+from mergekit.config import MergeConfiguration
+from mergekit.merge import MergeOptions
+from mergekit.merge import run_merge
+from mergekit.plot_tools.plot_tools import create_app, ResultsHandler
+from mergekit.metric_methods.base import Results
+
+def create_temp_config(config_yml, **kwargs):
+    with open(config_yml, "r", encoding="utf-8") as config_file:
+        config = yaml.safe_load(config_file)
+    config.update(kwargs)
+    return MergeConfiguration.model_validate(config)
+
+@click.command()
+@click.option('--output_path', default="./merged", help='folder to store the result in.')
+@click.option('--config_yml', default="./examples/metrics-small.yml", help='merge configuration file.')
+@click.option('--copy_tokenizer', default=True, help='copy a tokenizer to the output folder.')
+@click.option('--lazy_unpickle', default=False, help='experimental low-memory model loader.')
+@click.option('--low_cpu_memory', default=False, help='enable if you somehow have more VRAM than RAM+swap')
+def main(output_path, config_yml, copy_tokenizer, lazy_unpickle, low_cpu_memory):
+    with open(config_yml, "r", encoding="utf-8") as config_file:
+        config = yaml.safe_load(config_file)
+    metric_config = MergeConfiguration.model_validate(config)
+
+    models = metric_config.models
+    parameters = config.get('parameters') or {}
+    intra_model = parameters.get('intra_model_metrics', False)
+    inter_model = parameters.get('inter_model_metrics', False)
+
+    intra_results = {}
+    inter_results = None
+    if intra_model:
+        print(f"Running intra-model metrics for {len(models)} models: {models}")
+        for model in models:
+            temp_config = create_temp_config(config_yml, models=[{'model': model.model.model.path}], parameters={'intra_model_metrics': True, 'inter_model_metrics': False})
+            print(f"  {model}")
+            metrics_out = run_merge(
+                temp_config,
+                out_path=output_path,
+                options=MergeOptions(
+                    cuda=torch.cuda.is_available(),
+                    copy_tokenizer=copy_tokenizer,
+                    lazy_unpickle=lazy_unpickle,
+                    low_cpu_memory=low_cpu_memory,
+                ),
+            )
+            intra_results[model.model] = Results().load_metrics(metrics_out, model_paths=[model.model.model.path])
+
+    if inter_model:
+        assert len(models) == 2, "Inter-model metrics require exactly 2 models"
+        print(f"Running inter-model metrics for {models}")
+        temp_config = create_temp_config(config_yml, parameters={'intra_model_metrics': False, 'inter_model_metrics': True})
+
+        metrics_out = run_merge(
+            temp_config,
+            out_path=output_path,
+            options=MergeOptions(
+                cuda=torch.cuda.is_available(),
+                copy_tokenizer=copy_tokenizer,
+                lazy_unpickle=lazy_unpickle,
+                low_cpu_memory=low_cpu_memory,
+            ),
+        )
+        inter_results = Results().load_metrics(metrics_out, model_paths=[m.model.model.path for m in models])
+
+    handler = ResultsHandler()
+
+    if inter_results is not None:
+        handler.load_results(inter_results)
+    for result in intra_results.values():
+        handler.load_results(result)
+
+    app = create_app(results_handler=handler)
+    app.run_server()
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index a8a339a7..fde49e1e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,8 @@
 dev = ["black~=24.4.2", "isort~=5.13.2", "pre-commit~=3.7.1"]
 test = ["pytest~=8.2.1"]
 evolve = ["ray", "cma", "lm_eval", "wandb"]
 vllm = ["vllm==0.3.2", "lm_eval[vllm]"]
+interactive_plot = ["networkx", "plotly", "matplotlib", "dash"]
+representations = ["h5py", "datasets", "bitsandbytes", "scikit-learn"]
 
 [project.urls]
 repository = "https://github.com/cg123/mergekit"
@@ -51,6 +53,8 @@
 packages = [
     "mergekit",
     "mergekit.io",
     "mergekit.merge_methods",
+    "mergekit.metric_methods",
+    "mergekit.plot_tools",
     "mergekit.moe",
     "mergekit.scripts",
     "mergekit.evo",
diff --git a/representations/configs/config.yml b/representations/configs/config.yml
new file mode 100644
index 00000000..e049ef62
--- /dev/null
+++ b/representations/configs/config.yml
@@ -0,0 +1,11 @@
+metrics:
+  cosine_similarity: true
+  mse: true
+  linearity_score: true
+  cka: true
+  cknna: true
+  t-sne: true
+  pca_projection: true
+
+analysis_type: "individual"
+comparison_type: "block"
\ No newline at end of file
diff --git a/representations/configs/config_comp_all.yml b/representations/configs/config_comp_all.yml
new file mode 100644
index 00000000..16b04e8b
--- /dev/null
+++ b/representations/configs/config_comp_all.yml
@@ -0,0 +1,11 @@
+metrics:
+  cosine_similarity: true
+  mse: true
+  linearity_score: true
+  cka: true
+  cknna: true
+  t-sne: true
+  pca_projection: true
+
+analysis_type: "comparison"
+comparison_type: "all"
\ No newline at end of file
diff --git a/representations/configs/config_comp_corresponding.yml b/representations/configs/config_comp_corresponding.yml
new file mode 100644
index 00000000..0cf8a5d9
--- /dev/null
+++ b/representations/configs/config_comp_corresponding.yml
@@ -0,0 +1,11 @@
+metrics:
+  cosine_similarity: true
+  mse: true
+  linearity_score: true
+  cka: true
+  cknna: true
+  t-sne: true
+  pca_projection: true
+
+analysis_type: "comparison"
+comparison_type: "corresponding"
\ No newline at end of file
diff --git a/representations/configs/config_i_all.yml b/representations/configs/config_i_all.yml
new file mode 100644
index 00000000..e2764960
--- /dev/null
+++ b/representations/configs/config_i_all.yml
@@ -0,0 +1,11 @@
+metrics:
+  cosine_similarity: true
+  mse: true
+  linearity_score: true
+  cka: true
+  cknna: true
+  t-sne: true
+  pca_projection: true
+
+analysis_type: "individual"
+comparison_type: "all"
\ No newline at end of file
diff --git a/representations/configs/config_i_block.yml b/representations/configs/config_i_block.yml
new file mode 100644
index 00000000..e049ef62
--- /dev/null
+++ b/representations/configs/config_i_block.yml
@@ -0,0 +1,11 @@
+metrics:
+  cosine_similarity: true
+  mse: true
+  linearity_score: true
+  cka: true
+  cknna: true
+  t-sne: true
+  
pca_projection: true
+
+analysis_type: "individual"
+comparison_type: "block"
\ No newline at end of file
diff --git a/representations/configs/config_i_single.yml b/representations/configs/config_i_single.yml
new file mode 100644
index 00000000..fe653e8d
--- /dev/null
+++ b/representations/configs/config_i_single.yml
@@ -0,0 +1,11 @@
+metrics:
+  cosine_similarity: true
+  mse: true
+  linearity_score: true
+  cka: true
+  cknna: true
+  t-sne: true
+  pca_projection: true
+
+analysis_type: "individual"
+comparison_type: "single"
\ No newline at end of file
diff --git a/representations/experiment_setup.py b/representations/experiment_setup.py
new file mode 100644
index 00000000..ca3a61f4
--- /dev/null
+++ b/representations/experiment_setup.py
@@ -0,0 +1,367 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Dict, Any, Optional
+
+import numpy as np
+import torch
+from tqdm import tqdm
+import h5py
+
+from mergekit.architecture import WeightInfo
+from mergekit.metric_methods.base import Results, Layer
+from mergekit.metric_methods.aggregator_metrics import ModelAnalysisType, LayerComparisonType, METRICS_TABLE, MetricAggregator
+
+from mergekit.metric_methods.base import Heatmap, Metric
+
+def check_memory(h5file, h5file_2=None):
+    # TODO: check whether the full data actually fits in memory; for now assume it does
+    return True
+
+def convert_to_2d_array(arrays):
+    """
+    Convert a list of 1D numpy arrays into a single 2D array.
+
+    Parameters:
+    arrays (list of np.ndarray): List of 1D numpy arrays.
+
+    Returns:
+    np.ndarray: 2D numpy array with dimensions N x max_length, filled with NaNs where necessary.
+    """
+    # Determine the length of the longest array
+    max_length = max(len(arr) for arr in arrays)
+
+    # Create an empty 2D array filled with NaNs
+    result = np.full((len(arrays), max_length), np.nan)
+
+    # Populate the 2D array with the values from the 1D arrays
+    for i, arr in enumerate(arrays):
+        result[i, :len(arr)] = arr
+
+    return result
+
+class LayerByIndex:
+    def __init__(self, reps_path: str, load_into_memory: bool = True):
+        self.reps_path = reps_path
+        self.representations = None
+        self.layers = None
+        self.load_into_memory = load_into_memory
+        self.in_memory_data = None
+        self.device = 'cuda' if torch.cuda.is_available() else \
+                      'mps' if torch.backends.mps.is_available() else 'cpu'
+
+    def __enter__(self):
+        self.representations = h5py.File(self.reps_path, 'r')
+        self.layers = list(self.representations.keys())
+
+        if self.load_into_memory:
+            print("Loading representations into memory")
+            self.in_memory_data = {layer: {} for layer in self.layers}
+            for layer_name in tqdm(self.layers, leave=False, initial=1, desc='Loading Layer', total=len(self.layers)):
+                for batch_name, batch_data in self.representations[layer_name].items():
+                    data = torch.tensor(batch_data[...]).to(self.device)
+                    self.in_memory_data[layer_name][batch_name] = data
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        if self.representations:
+            self.representations.close()
+        self.in_memory_data = None
+
+    def __getitem__(self, idx: int):
+        if self.load_into_memory:
+            return self.in_memory_data[self.layers[idx]]
+        else:
+            return self.representations[self.layers[idx]]
+
+    def __len__(self) -> int:
+        return len(self.layers)
+
+    def __iter__(self):
+        if self.load_into_memory:
+            return ((layer, self.in_memory_data[layer]) for layer in self.layers)
+        else:
+            return ((layer, self.representations[layer]) for layer in self.layers)
+
+def valid_experiment(analysis_type, comparison_type):
+    if comparison_type == LayerComparisonType.ALL:
+        raise ValueError("Comparison type 'all_layers' is not supported")
+    if analysis_type == ModelAnalysisType.COMPARISON and comparison_type in [LayerComparisonType.BLOCK, LayerComparisonType.SINGLE]:
+        raise ValueError("Comparison types 'single' and 'block' are only supported for individual analysis")
+    if analysis_type == ModelAnalysisType.INDIVIDUAL and comparison_type == LayerComparisonType.CORRESPONDING:
+        raise ValueError("Comparison type 'corresponding' is only supported for comparison analysis")
+
+# Experiment Loops
+def single(representations: LayerByIndex, metric_classes: List[MetricAggregator], results: Optional[Results], device='cpu'):
+    if not results:
+        results = Results()
+    for layer_name, layer in tqdm(representations, desc='Analysing Layer',
+                                  total=len(representations), leave=False, initial=1):
+        # NOTE: layer is a dictionary of batches here, so it doesn't have a .name attribute;
+        # layer_name comes from the iterator instead.
+        metrics = [metric_class(device=device) for metric_class in metric_classes]
+
+        for batch in tqdm(layer.values(), desc='Batch',
+                          total=len(layer), leave=False, initial=1):
+            # Batches are already torch tensors on the target device (see LayerByIndex)
+
+            # Calculate the metrics for each batch
+            for metric in metrics:
+                metric.process_batch(batch)
+
+        layer_results = Layer(WeightInfo(name=layer_name))
+        # Aggregate over the batches and add to the layer results
+        for metric in metrics:
+            layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower())
+
+        results.add_layer(layer_results, layer_name)
+
+    return results
+
+def corresponding(representations_0: LayerByIndex, representations_1: LayerByIndex, metric_classes: List[MetricAggregator], results: Optional[Results], device='cpu'):
+    if not results:
+        results = Results()
+
+    for (layer_0_name, layer_0), (layer_1_name, layer_1) in tqdm(zip(representations_0, representations_1),
+                                                                 desc='Comparing Representations at layer',
+                                                                 total=len(representations_0), initial=1):
+
+        if layer_0_name != layer_1_name:
+            raise ValueError(f'Layer mismatch: {layer_0_name} != {layer_1_name}')
+
+        metrics = [metric_class(device=device) for metric_class in metric_classes]
+
+        for batch_0, batch_1 in tqdm(zip(layer_0.values(), layer_1.values()),
+                                     desc='Batch', total=len(layer_0), leave=False, initial=1):
+
+            # Calculate the metrics for each batch
+            for metric in metrics:
+                metric.process_batch(batch_0, batch_1)
+
+        layer_results = Layer(WeightInfo(name=layer_0_name))
+        # Aggregate over the batches and add to the layer results
+        for metric in metrics:
+            layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower())
+
+        results.add_layer(layer_results, layer_0_name)
+
+    return results
+
+def block(representations, block_size, metric_classes, device='cpu'):
+    out = {metric.__name__.lower(): [] for metric in metric_classes}
+    for idx, (block_start_name, block_start) in tqdm(enumerate(representations), desc=f'Comparing {block_size}-block, Block Start at Layer',
+                                                     total=len(representations) - block_size, leave=False, initial=1):
+        if idx + block_size >= 
len(representations):
+            break
+
+        # Create metrics
+        metrics = [metric_class(device=device) for metric_class in metric_classes]
+        block_end = representations[idx + block_size]
+
+        for batch_0, batch_1 in tqdm(zip(block_start.values(), block_end.values()), desc='Batch',
+                                     total=len(block_start), leave=False, initial=1):
+            for metric in metrics:
+                metric.process_batch(batch_0, batch_1)
+
+        # Aggregate metrics and add to results
+        for metric in metrics:
+            out[metric.__class__.__name__.lower()].append(metric.aggregate())
+
+    for metric_name in out:
+        out[metric_name] = np.array([m.mean_std.mean for m in out[metric_name]])
+    return out
+
+def all_layers(representations_0, representations_1, metric_classes, results, device='cpu'):
+    for layer_0_name, layer_0 in tqdm(representations_0, desc='Model 0 Layers',
+                                      total=len(representations_0), leave=False, initial=1):
+        for layer_1_name, layer_1 in tqdm(representations_1, desc='Model 1 Layers',
+                                          total=len(representations_1), leave=False, initial=1):
+            if len(layer_0) != len(layer_1):
+                raise ValueError(f'Batch count mismatch: {len(layer_0)} != {len(layer_1)}')
+
+            metrics = [metric_class(device=device) for metric_class in metric_classes]
+
+            for batch_0, batch_1 in tqdm(zip(layer_0.values(), layer_1.values()), desc='Batch',
+                                         total=len(layer_0), leave=False, initial=1):
+
+                for metric in metrics:
+                    metric.process_batch(batch_0, batch_1)
+
+            layer_results = Layer(WeightInfo(name=f"{layer_0_name} - {layer_1_name}"))
+            for metric in metrics:
+                layer_results.add_metric(metric.aggregate(), metric.__class__.__name__.lower())
+
+            results.add_layer(layer_results, f"{layer_0_name} - {layer_1_name}")
+
+    return results
+
+@dataclass
+class Configuration:
+    representation_paths: List[Path]
+    metrics: Dict[str, bool]
+    analysis_type: ModelAnalysisType
+    comparison_type: LayerComparisonType
+    out_dir: Path
+    device: str
+    data: Optional[LayerByIndex] = None
+
+    @classmethod
+    def from_dict(cls, config_dict: Dict):
+        return cls(
+            representation_paths=[path for path in config_dict['representations_to_analyse'].iterdir() if path.suffix == '.h5'],
+            metrics=config_dict['metrics'],
+            analysis_type=ModelAnalysisType(config_dict['analysis_type']),
+            comparison_type=LayerComparisonType(config_dict['comparison_type']),
+            out_dir=Path(config_dict['out_dir']),
+            device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
+        ).validate()
+
+    def validate(self):
+        if self.comparison_type == LayerComparisonType.ALL:
+            raise ValueError("Comparison type 'all_layers' is not supported")
+        if self.analysis_type == ModelAnalysisType.COMPARISON and self.comparison_type in [LayerComparisonType.BLOCK, LayerComparisonType.SINGLE]:
+            raise ValueError("Comparison types 'single' and 'block' are only supported for individual analysis")
+        if self.analysis_type == ModelAnalysisType.INDIVIDUAL and self.comparison_type == LayerComparisonType.CORRESPONDING:
+            raise ValueError("Comparison type 'corresponding' is only supported for comparison analysis")
+        return self
+
+class Experiment(ABC):
+    @abstractmethod
+    def run(self, config: Configuration):
+        pass
+
+class SingleExperiment(Experiment):
+    def run(self, config: Configuration):
+        # Implementation for single experiment
+        use_metrics = {name: METRICS_TABLE[name]
+                       for name, enabled in config.metrics.items()
+                       if enabled}
+
+        metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.SINGLE.value]]
+        for representation_path in config.representation_paths:
+            if not representation_path.exists():
+                raise FileNotFoundError(f"Representation file {representation_path} not found")
+
+            individual_results = Results()
+            individual_results.load_representations_details_from_path(representation_path)
+
+            with LayerByIndex(representation_path, load_into_memory=check_memory(representation_path)) as representations:
+                individual_results = single(representations,
+                                            metrics,
+                                            results=individual_results,
+                                            device=config.device)
+
+            individual_results.save(config.out_dir, suffix=f"{config.analysis_type.value}+{config.comparison_type.value}")
+
+class CorrespondingExperiment(Experiment):
+    def run(self, config: Configuration):
+        use_metrics = {name: METRICS_TABLE[name]
+                       for name, enabled in config.metrics.items()
+                       if enabled}
+
+        metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.CORRESPONDING.value]]
+        comparison_results = Results()
+        stop = False
+        for rep_0 in config.representation_paths:
+            for rep_1 in config.representation_paths:
+                is_intra = rep_0 == rep_1 and config.analysis_type == ModelAnalysisType.INDIVIDUAL
+                is_inter = rep_0 != rep_1 and config.analysis_type == ModelAnalysisType.COMPARISON
+                if (is_intra or is_inter) and not stop:
+                    if rep_0 != rep_1:
+                        stop = True
+
+                    with LayerByIndex(rep_0) as representations_0, LayerByIndex(rep_1) as representations_1:
+                        comparison_results.load_representations_details_from_path(rep_0)
+                        comparison_results.load_representations_details_from_path(rep_1)
+                        comparison_results = corresponding(representations_0=representations_0,
+                                                           representations_1=representations_1,
+                                                           metric_classes=metrics,
+                                                           results=comparison_results,
+                                                           device=config.device)
+
+        comparison_results.save(config.out_dir, suffix=f"{config.analysis_type.value}+{config.comparison_type.value}")
+
+class BlockExperiment(Experiment):
+    def run(self, config: Configuration):
+        use_metrics = {name: METRICS_TABLE[name]
+                       for name, enabled in config.metrics.items()
+                       if enabled}
+
+        metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.BLOCK.value]]
+        for representation_path in config.representation_paths:
+            if not representation_path.exists():
+                raise FileNotFoundError(f"Representation file {representation_path} not found")
+
+            heatmaps = {}
+            block_results = Results()
+            block_results.load_representations_details_from_path(representation_path)
+            for metric in metrics:
+                heatmaps[metric.__name__.lower()] = []
+            with LayerByIndex(representation_path, load_into_memory=check_memory(representation_path)) as representations:
+                for block_size in range(1, 9):  # Compare blocks spanning 1 to 8 layers
+                    block_res = block(representations=representations,
+                                      block_size=block_size,
+                                      metric_classes=metrics,
+                                      device=config.device)
+                    for metric in metrics:
+                        heatmaps[metric.__name__.lower()].append(block_res[metric.__name__.lower()])
+
+            for metric in metrics:
+                metric_name = metric.__name__.lower()
+                block_results.across_layer_metrics[metric_name] = Metric(
+                    heatmap=Heatmap(
+                        data=convert_to_2d_array(heatmaps[metric_name]),
+                        plot_details={'title': f'{metric.__name__} across N-blocks', 'xlabel': 'Layer', 'ylabel': 'Block Size'}
+                    )
+                )
+            block_results.save(config.out_dir, suffix=f"{config.analysis_type.value}+{config.comparison_type.value}")
+
+class AllLayersExperiment(Experiment):
+    def run(self, config: Configuration):
+        use_metrics = {name: METRICS_TABLE[name]
+                       for name, enabled in config.metrics.items()
+                       if enabled}
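+
+        # Keep only the metrics whose aggregators support full pairwise (all-layers) comparison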
+        metrics = [metric for metric in use_metrics.values() if metric().valid_for[LayerComparisonType.ALL.value]]
+
+        comparison_results = Results()
+        stop = False
+        for rep_0 in config.representation_paths:
+            for rep_1 in config.representation_paths:
+                is_intra = rep_0 == rep_1 and config.analysis_type == ModelAnalysisType.INDIVIDUAL
+                is_inter = rep_0 != rep_1 and config.analysis_type == ModelAnalysisType.COMPARISON
+                if (is_intra or is_inter) and not stop:
+                    if rep_0 != rep_1:
+                        stop = True
+                    load_into_memory = check_memory(rep_0, rep_1)
+                    with LayerByIndex(rep_0, load_into_memory) as representations_0, \
+                         LayerByIndex(rep_1, load_into_memory) as representations_1:
+
+                        comparison_results = all_layers(representations_0=representations_0,
+                                                        representations_1=representations_1,
+                                                        metric_classes=metrics,
+                                                        results=comparison_results,
+                                                        device=config.device)
+                        comparison_results.load_representations_details_from_path(rep_0)
+                        comparison_results.load_representations_details_from_path(rep_1)
+
+        comparison_results.save(config.out_dir, suffix=f"{config.analysis_type.value}+{config.comparison_type.value}")
+
+
+class ExperimentFactory:
+    experiments: Dict[str, Experiment] = {
+        "single": SingleExperiment,
+        "corresponding": CorrespondingExperiment,
+        "block": BlockExperiment,
+        "all_layers": AllLayersExperiment,
+    }
+
+    @classmethod
+    def create(cls, experiment_type: str) -> Experiment:
+        experiment_class = cls.experiments.get(experiment_type)
+        if not experiment_class:
+            raise ValueError(f"Unknown experiment type: {experiment_type}")
+        return experiment_class()
diff --git a/representations/representation_metrics.py b/representations/representation_metrics.py
new file mode 100644
index 00000000..10f6e2ad
--- /dev/null
+++ b/representations/representation_metrics.py
@@ -0,0 +1,25 @@
+from pathlib import Path
+
+import click
+import yaml
+
+from experiment_setup import Configuration, ExperimentFactory
+
+def run(config_yml: str = "config"):
+    mergekit_root = Path(__file__).parent.parent
+
+    config_path = (mergekit_root / 'representations' / 'configs' / config_yml).with_suffix('.yml')
+    with open(config_path, 'r') as config_file:
+        config = yaml.safe_load(config_file)
+    config['out_dir'] = mergekit_root / 'representations' / 'results_out'
+    config['representations_to_analyse'] = mergekit_root / 'representations' / 'representations_to_analyse'
+    config = Configuration.from_dict(config)
+
+    experiment = ExperimentFactory.create(config.comparison_type.name.lower())
+    experiment.run(config)
+
+@click.command()
+@click.option('--config_yml', default="config_i_block", help='path to the configuration file.')
+def main(config_yml: str):
+    run(config_yml)
+
+if __name__ == "__main__":
+    main()
diff --git a/representations/store_representations.py b/representations/store_representations.py
new file mode 100644
index 00000000..f7c8621d
--- /dev/null
+++ b/representations/store_representations.py
@@ -0,0 +1,115 @@
+import click
+import h5py
+import numpy as np
+from tqdm import tqdm
+import torch
+from torch.utils.data import DataLoader
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import datasets
+import os
+import random
+from pathlib import Path
+from mergekit._data.models_and_datasets import save_model_and_dataset, model_and_dataset_to_index
+from typing import List
+import gc
+import uuid
+
+def set_seed(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+def get_last_non_padded_tokens(hidden_states, attention_mask) -> List[torch.Tensor]:
+    last_non_padded_hidden_states = []
+    for layer in hidden_states:
+        batch_size, _, _ = layer.size()
+        batch_last_tokens = []
+        for batch in range(batch_size):
+            last_non_pad_index = attention_mask[batch].nonzero(as_tuple=True)[0].max()
+            last_token = layer[batch, last_non_pad_index, :]
+            batch_last_tokens.append(last_token.unsqueeze(0))
+        last_non_padded_hidden_states.append(torch.cat(batch_last_tokens, dim=0))
+    return last_non_padded_hidden_states
+
+def store_representations(model_name, output_dir, dataset_name, batch_size, max_length, dataset_size, dataset_column, dataset_subset):
+    # Generate a short unique ID for this run using UUID
+    unique_id = uuid.uuid4().hex[:4]
+
+    #!important: Set seed for consistent batch order across runs
+    set_seed(42)
+    save_model_and_dataset(model_name, dataset_name)
+    model_index, dataset_index = model_and_dataset_to_index(model_name, dataset_name)
+    output_name = Path(output_dir) / f'{model_index}_{dataset_index}_id_{unique_id}.h5'.replace("/", "_")
+    os.makedirs(output_name.parent, exist_ok=True)
+    assert not output_name.exists(), f'{output_name} already exists.'
+    for reps_name in output_name.parent.iterdir():
+        if f'{model_index}_{dataset_index}' in reps_name.name:
+            raise ValueError(f'Representations for model {model_index} and dataset {dataset_index} already exist in {output_name.parent}')
+
+    dataset = datasets.load_dataset(dataset_name, split=dataset_subset)
+    if dataset_size:
+        dataset = dataset.select(range(dataset_size))
+    device = "cuda" if torch.cuda.is_available() \
+        else "mps" if torch.backends.mps.is_available() \
+        else "cpu"
+
+    model = AutoModelForCausalLM.from_pretrained(model_name,
+                                                 device_map="auto",
+                                                 output_hidden_states=True)
+    model.eval()
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    if not tokenizer.pad_token:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    dataloader = DataLoader(dataset[dataset_column], batch_size=batch_size, shuffle=False, drop_last=True)
+
+    with h5py.File(output_name, 'w') as h5file:
+        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Processing batches")):
+            inputs = tokenizer(batch, return_tensors="pt", padding="longest", max_length=max_length, truncation=True).to(device)
+            with torch.no_grad():
+                outputs = model(**inputs)
+            attention_mask = inputs["attention_mask"]
+            hidden_states = outputs.hidden_states
+
+            # Remove the first element to account for the input layer not being considered a model hidden layer
+            last_non_padded_hidden_states = get_last_non_padded_tokens(hidden_states, attention_mask)[1:]
+
+            for layer, hidden_state in enumerate(last_non_padded_hidden_states):
+                layer_group = h5file.require_group(f'layer_{layer:03d}')
+                file_name = f'batch_{batch_idx}.pt'
+
+                layer_group.create_dataset(file_name, data=hidden_state.to('cpu'), compression="gzip")
+
+            # Ensure that the length of last_non_padded_hidden_states matches the number of model hidden layers
+            assert len(last_non_padded_hidden_states) == model.config.num_hidden_layers, "Length of last_non_padded_hidden_states \
+                does not match expected number of hidden layers."
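+
+    # Free the model and cached GPU memory once all batches have been written to disk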
+ + if torch.cuda.is_available(): + # Clear GPU memory + del model + torch.cuda.empty_cache() + torch.cuda.synchronize() + gc.collect() + +@click.command() +@click.option('--model_name', default="BEE-spoke-data/smol_llama-220M-GQA", help='model to use.') +@click.option('--output_dir', default="./representations/representations_store", help='folder to store the result in.') +@click.option('--dataset_name', default="arcee-ai/sec-data-mini", help='dataset to use.') +@click.option('--batch_size', default=8, help='batch size.') +@click.option('--max_length', default=1024, help='maximum length of the input.') +@click.option('--dataset_size', default=4000, help='size of the dataset.') +@click.option('--dataset_column', default="text", help='column of the dataset to use.') +@click.option('--dataset_subset', default="train", help='subset of the dataset to use.') +def main(model_name, output_dir, dataset_name, batch_size, max_length, dataset_size, dataset_column, dataset_subset): + store_representations(model_name, output_dir, dataset_name, batch_size, max_length, dataset_size, dataset_column, dataset_subset) +if __name__ == "__main__": + main() diff --git a/representations/visualise_representation_results.py b/representations/visualise_representation_results.py new file mode 100644 index 00000000..2e26114a --- /dev/null +++ b/representations/visualise_representation_results.py @@ -0,0 +1,21 @@ +import click +from mergekit.plot_tools.plot_tools import create_app, ResultsHandler +from mergekit.metric_methods.base import Results +from pathlib import Path + +def main(input_dir): + handler = ResultsHandler() + for res in Path(input_dir).iterdir(): + results = Results() + results = results.load(res.absolute()) + + handler.load_results(results) + + app = create_app(results_handler=handler) + app.run_server() + +if __name__ == '__main__': + mergekit_root = Path(__file__).parent.parent + input_dir = mergekit_root / 'representations' / 'results_to_visualise' + + main(input_dir) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..4f1e252b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,893 @@ +# +# This file is autogenerated by pip-compile with Python 3.11 +# by the following command: +# +# pip-compile --generate-hashes --output-file=requirements.txt pyproject.toml +# +accelerate==0.30.1 \ + --hash=sha256:8dd4edd532a4dac72558c5fe6fe8cb70d0c8ec9e8733f48db97d51ee41cbe763 \ + --hash=sha256:96779c618889646b86dc928c9e55e86e50a7ccab59e1692e22096481977ae682 + # via + # mergekit (pyproject.toml) + # peft +annotated-types==0.7.0 \ + --hash=sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53 \ + --hash=sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89 + # via pydantic +certifi==2024.2.2 \ + --hash=sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f \ + --hash=sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1 + # via requests +charset-normalizer==3.3.2 \ + --hash=sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027 \ + --hash=sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087 \ + --hash=sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786 \ + --hash=sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8 \ + --hash=sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09 \ + --hash=sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185 \ + 
--hash=sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574 \ + --hash=sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e \ + --hash=sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519 \ + --hash=sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898 \ + --hash=sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269 \ + --hash=sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3 \ + --hash=sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f \ + --hash=sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6 \ + --hash=sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8 \ + --hash=sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a \ + --hash=sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73 \ + --hash=sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc \ + --hash=sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714 \ + --hash=sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2 \ + --hash=sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc \ + --hash=sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce \ + --hash=sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d \ + --hash=sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e \ + --hash=sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6 \ + --hash=sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269 \ + --hash=sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96 \ + --hash=sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d \ + --hash=sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a \ + --hash=sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4 \ + --hash=sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77 \ + --hash=sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d \ + --hash=sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0 \ + --hash=sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed \ + --hash=sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068 \ + --hash=sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac \ + --hash=sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25 \ + --hash=sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8 \ + --hash=sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab \ + --hash=sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26 \ + --hash=sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2 \ + --hash=sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db \ + --hash=sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f \ + --hash=sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5 \ + --hash=sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99 \ + --hash=sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c \ + --hash=sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d \ + --hash=sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811 \ + 
--hash=sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa \ + --hash=sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a \ + --hash=sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03 \ + --hash=sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b \ + --hash=sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04 \ + --hash=sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c \ + --hash=sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001 \ + --hash=sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458 \ + --hash=sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389 \ + --hash=sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99 \ + --hash=sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985 \ + --hash=sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537 \ + --hash=sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238 \ + --hash=sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f \ + --hash=sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d \ + --hash=sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796 \ + --hash=sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a \ + --hash=sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143 \ + --hash=sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8 \ + --hash=sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c \ + --hash=sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5 \ + --hash=sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5 \ + --hash=sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711 \ + --hash=sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4 \ + --hash=sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6 \ + --hash=sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c \ + --hash=sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7 \ + --hash=sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4 \ + --hash=sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b \ + --hash=sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae \ + --hash=sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12 \ + --hash=sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c \ + --hash=sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae \ + --hash=sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8 \ + --hash=sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887 \ + --hash=sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b \ + --hash=sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4 \ + --hash=sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f \ + --hash=sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5 \ + --hash=sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33 \ + --hash=sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519 \ + --hash=sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561 + # via requests +click==8.1.7 \ + 
--hash=sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28 \ + --hash=sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de + # via mergekit (pyproject.toml) +filelock==3.14.0 \ + --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ + --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a + # via + # huggingface-hub + # torch + # transformers +fsspec==2024.5.0 \ + --hash=sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a \ + --hash=sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c + # via + # huggingface-hub + # torch +huggingface-hub==0.23.2 \ + --hash=sha256:48727a16e704d409c4bb5913613308499664f22a99743435dc3a13b23c485827 \ + --hash=sha256:f6829b62d5fdecb452a76fdbec620cba4c1573655a8d710c1df71735fd9edbd2 + # via + # accelerate + # mergekit (pyproject.toml) + # peft + # tokenizers + # transformers +idna==3.7 \ + --hash=sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc \ + --hash=sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0 + # via requests +immutables==0.20 \ + --hash=sha256:085ac48ee3eef7baf070f181cae574489bbf65930a83ec5bbd65c9940d625db3 \ + --hash=sha256:1009a4e00e2e69a9b40c2f1272795f5a06ad72c9bf4638594d518e9cbd7a721a \ + --hash=sha256:1d2f83e6a6a8455466cd97b9a90e2b4f7864648616dfa6b19d18f49badac3876 \ + --hash=sha256:2761e3dc2a6406943ce77b3505e9b3c1187846de65d7247548dc7edaa202fcba \ + --hash=sha256:2837b1078abc66d9f009bee9085cf62515d5516af9a5c9ea2751847e16efd236 \ + --hash=sha256:2bcea81e7516bd823b4ed16f4f794531097888675be13e833b1cc946370d5237 \ + --hash=sha256:2dd0dcef2f8d4523d34dbe1d2b7804b3d2a51fddbd104aad13f506a838a2ea15 \ + --hash=sha256:380e2957ba3d63422b2f3fbbff0547c7bbe6479d611d3635c6411005a4264525 \ + --hash=sha256:393dde58ffd6b4c089ffdf4cef5fe73dad37ce4681acffade5f5d5935ec23c93 \ + --hash=sha256:47f56aea56e597ecf6631f24a4e26007b6a5f4fe30278b96eb90bc1f60506164 \ + --hash=sha256:4ba726b7a3a696b9d4b122fa2c956bc68e866f3df1b92765060c88c64410ff82 \ + --hash=sha256:525fb361bd7edc8a891633928d549713af8090c79c25af5cc06eb90b48cb3c64 \ + --hash=sha256:5302ce9c7827f8300f3dc34a695abb71e4a32bab09e65e5ad6e454785383347f \ + --hash=sha256:532be32c7a25dae6cade28825c76d3004cf4d166a0bfacf04bda16056d59ba26 \ + --hash=sha256:5a88adf1dcc9d8ab07dba5e74deefcd5b5e38bc677815cbf9365dc43b69f1f08 \ + --hash=sha256:5bb32aee1ea16fbb90f58f8bd96016bca87aba0a8e574e5fa218d0d83b142851 \ + --hash=sha256:62f8a7a22939278127b7a206d05679b268b9cf665437125625348e902617cbad \ + --hash=sha256:65954eb861c61af48debb1507518d45ae7d594b4fba7282785a70b48c5f51f9b \ + --hash=sha256:83794712f0507416f2818edc63f84305358b8656a93e5b9e2ab056d9803c7507 \ + --hash=sha256:85dd9765b068f7beb297553fddfcf7f904bd58a184c520830a106a58f0c9bfb4 \ + --hash=sha256:96899994842c37cf4b9d6d2bedf685aae7810bd73f1538f8cba5426e2d65cb85 \ + --hash=sha256:9cd2ee9c10bf00be3c94eb51854bc0b761326bd0a7ea0dad4272a3f182269ae6 \ + --hash=sha256:a606410b2ccb6ae339c3f26cccc9a92bcb16dc06f935d51edfd8ca68cf687e50 \ + --hash=sha256:a82afc3945e9ceb9bcd416dc4ed9b72f92760c42787e26de50610a8b81d48120 \ + --hash=sha256:ac86f4372f4cfaa00206c12472fd3a78753092279e0552b7e1880944d71b04fe \ + --hash=sha256:b0436cc831b47e26bef637bcf143cf0273e49946cfb7c28c44486d70513a3080 \ + --hash=sha256:b51aec54b571ae466113509d4dc79a2808dc2ae9263b71fd6b37778cb49eb292 \ + --hash=sha256:c086ccb44d9d3824b9bf816365d10b1b82837efc7119f8bab56bd7a27ed805a9 \ + 
--hash=sha256:c1214b5a175df783662b7de94b4a82db55cc0ee206dd072fa9e279fb8895d8df \ + --hash=sha256:cc51a01a64a6d2cd7db210a49ad010c2ac2e9e026745f23fd31e0784096dcfff \ + --hash=sha256:d4f78cb748261f852953620ed991de74972446fd484ec69377a41e2f1a1beb75 \ + --hash=sha256:d6449186ea91b7c17ec8e7bd9bf059858298b1db5c053f5d27de8eba077578ce \ + --hash=sha256:d828e7580f1fa203ddeab0b5e91f44bf95706e7f283ca9fbbcf0ae08f63d3084 \ + --hash=sha256:dea0ae4d7f31b145c18c16badeebc2f039d09411be4a8febb86e1244cf7f1ce0 \ + --hash=sha256:e3a5462f6d3549bbf7d02ce929fb0cb6df9539445f0589105de4e8b99b906e69 \ + --hash=sha256:e771198edc11a9e02ffa693911b3918c6cde0b64ad2e6672b076dbe005557ad8 \ + --hash=sha256:e8e82754f72823085643a2c0e6a4c489b806613e94af205825fa81df2ba147a0 \ + --hash=sha256:f063f53b5c0e8f541ae381f1d828f3d05bbed766a2d6c817f9218b8b37a4cb66 \ + --hash=sha256:f17f25f21e82a1c349a61191cfb13e442a348b880b74cb01b00e0d1e848b63f4 \ + --hash=sha256:f349a7e0327b92dcefb863e49ace086f2f26e6689a4e022c98720c6e9696e763 \ + --hash=sha256:fc739fc07cff5df2e4f31addbd48660b5ac0da56e9f719f8bb45da8ddd632c63 + # via mergekit (pyproject.toml) +jinja2==3.1.4 \ + --hash=sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369 \ + --hash=sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d + # via torch +markupsafe==2.1.5 \ + --hash=sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf \ + --hash=sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff \ + --hash=sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f \ + --hash=sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3 \ + --hash=sha256:1f3fbcb7ef1f16e48246f704ab79d79da8a46891e2da03f8783a5b6fa41a9532 \ + --hash=sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f \ + --hash=sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617 \ + --hash=sha256:2d2d793e36e230fd32babe143b04cec8a8b3eb8a3122d2aceb4a371e6b09b8df \ + --hash=sha256:30b600cf0a7ac9234b2638fbc0fb6158ba5bdcdf46aeb631ead21248b9affbc4 \ + --hash=sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906 \ + --hash=sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f \ + --hash=sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4 \ + --hash=sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8 \ + --hash=sha256:4096e9de5c6fdf43fb4f04c26fb114f61ef0bf2e5604b6ee3019d51b69e8c371 \ + --hash=sha256:4275d846e41ecefa46e2015117a9f491e57a71ddd59bbead77e904dc02b1bed2 \ + --hash=sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465 \ + --hash=sha256:4f11aa001c540f62c6166c7726f71f7573b52c68c31f014c25cc7901deea0b52 \ + --hash=sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6 \ + --hash=sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169 \ + --hash=sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad \ + --hash=sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2 \ + --hash=sha256:5dedb4db619ba5a2787a94d877bc8ffc0566f92a01c0ef214865e54ecc9ee5e0 \ + --hash=sha256:619bc166c4f2de5caa5a633b8b7326fbe98e0ccbfacabd87268a2b15ff73a029 \ + --hash=sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f \ + --hash=sha256:656f7526c69fac7f600bd1f400991cc282b417d17539a1b228617081106feb4a \ + --hash=sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced \ + 
--hash=sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5 \ + --hash=sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c \ + --hash=sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf \ + --hash=sha256:7b2e5a267c855eea6b4283940daa6e88a285f5f2a67f2220203786dfa59b37e9 \ + --hash=sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb \ + --hash=sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad \ + --hash=sha256:8dd717634f5a044f860435c1d8c16a270ddf0ef8588d4887037c5028b859b0c3 \ + --hash=sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1 \ + --hash=sha256:97cafb1f3cbcd3fd2b6fbfb99ae11cdb14deea0736fc2b0952ee177f2b813a46 \ + --hash=sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc \ + --hash=sha256:a549b9c31bec33820e885335b451286e2969a2d9e24879f83fe904a5ce59d70a \ + --hash=sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee \ + --hash=sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900 \ + --hash=sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5 \ + --hash=sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea \ + --hash=sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f \ + --hash=sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5 \ + --hash=sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e \ + --hash=sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a \ + --hash=sha256:c8b29db45f8fe46ad280a7294f5c3ec36dbac9491f2d1c17345be8e69cc5928f \ + --hash=sha256:ce409136744f6521e39fd8e2a24c53fa18ad67aa5bc7c2cf83645cce5b5c4e50 \ + --hash=sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a \ + --hash=sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b \ + --hash=sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4 \ + --hash=sha256:daa4ee5a243f0f20d528d939d06670a298dd39b1ad5f8a72a4275124a7819eff \ + --hash=sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2 \ + --hash=sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46 \ + --hash=sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b \ + --hash=sha256:ec6a563cff360b50eed26f13adc43e61bc0c04d94b8be985e6fb24b81f6dcfdf \ + --hash=sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5 \ + --hash=sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5 \ + --hash=sha256:fa9db3f79de01457b03d4f01b34cf91bc0048eb2c3846ff26f66687c2f6d16ab \ + --hash=sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd \ + --hash=sha256:ffee1f21e5ef0d712f9033568f8344d5da8cc2869dbd08d87c84656e6a2d2f68 + # via jinja2 +mpmath==1.3.0 \ + --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ + --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c + # via sympy +networkx==3.3 \ + --hash=sha256:0c127d8b2f4865f59ae9cb8aafcd60b5c70f3241ebd66f7defad7c4ab90126c9 \ + --hash=sha256:28575580c6ebdaf4505b22c6256a2b9de86b316dc63ba9e93abde3d78dfdbcf2 + # via torch +numpy==1.26.4 \ + --hash=sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b \ + --hash=sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818 \ + --hash=sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20 \ + 
--hash=sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0 \ + --hash=sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010 \ + --hash=sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a \ + --hash=sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea \ + --hash=sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c \ + --hash=sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71 \ + --hash=sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110 \ + --hash=sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be \ + --hash=sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a \ + --hash=sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a \ + --hash=sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5 \ + --hash=sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed \ + --hash=sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd \ + --hash=sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c \ + --hash=sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e \ + --hash=sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0 \ + --hash=sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c \ + --hash=sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a \ + --hash=sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b \ + --hash=sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0 \ + --hash=sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6 \ + --hash=sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2 \ + --hash=sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a \ + --hash=sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30 \ + --hash=sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218 \ + --hash=sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5 \ + --hash=sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07 \ + --hash=sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2 \ + --hash=sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4 \ + --hash=sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764 \ + --hash=sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef \ + --hash=sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3 \ + --hash=sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f + # via + # accelerate + # peft + # transformers +packaging==24.0 \ + --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ + --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 + # via + # accelerate + # huggingface-hub + # peft + # transformers +peft==0.11.1 \ + --hash=sha256:76f2d2a4c9e0644e2741465663b8a02097775e9725d26d7b41551e6f1e72e7dd \ + --hash=sha256:c1a04462e589a1305a06f7c118be0b8602b829f9bfc2104b5c6514c7678c2310 + # via mergekit (pyproject.toml) +protobuf==5.27.0 \ + --hash=sha256:07f2b9a15255e3cf3f137d884af7972407b556a7a220912b252f26dc3121e6bf \ + --hash=sha256:2f83bf341d925650d550b8932b71763321d782529ac0eaf278f5242f513cc04e \ + --hash=sha256:56937f97ae0dcf4e220ff2abb1456c51a334144c9960b23597f044ce99c29c89 \ + 
--hash=sha256:587be23f1212da7a14a6c65fd61995f8ef35779d4aea9e36aad81f5f3b80aec5 \ + --hash=sha256:673ad60f1536b394b4fa0bcd3146a4130fcad85bfe3b60eaa86d6a0ace0fa374 \ + --hash=sha256:744489f77c29174328d32f8921566fb0f7080a2f064c5137b9d6f4b790f9e0c1 \ + --hash=sha256:7cb65fc8fba680b27cf7a07678084c6e68ee13cab7cace734954c25a43da6d0f \ + --hash=sha256:a17f4d664ea868102feaa30a674542255f9f4bf835d943d588440d1f49a3ed15 \ + --hash=sha256:aabbbcf794fbb4c692ff14ce06780a66d04758435717107c387f12fb477bf0d8 \ + --hash=sha256:b276e3f477ea1eebff3c2e1515136cfcff5ac14519c45f9b4aa2f6a87ea627c4 \ + --hash=sha256:f51f33d305e18646f03acfdb343aac15b8115235af98bc9f844bf9446573827b + # via mergekit (pyproject.toml) +psutil==5.9.8 \ + --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ + --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ + --hash=sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8 \ + --hash=sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2 \ + --hash=sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e \ + --hash=sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36 \ + --hash=sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7 \ + --hash=sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c \ + --hash=sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee \ + --hash=sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421 \ + --hash=sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf \ + --hash=sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81 \ + --hash=sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0 \ + --hash=sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631 \ + --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ + --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 + # via + # accelerate + # peft +pydantic==2.7.1 \ + --hash=sha256:e029badca45266732a9a79898a15ae2e8b14840b1eabbb25844be28f0b33f3d5 \ + --hash=sha256:e9dbb5eada8abe4d9ae5f46b9939aead650cd2b68f249bb3a8139dbe125803cc + # via mergekit (pyproject.toml) +pydantic-core==2.18.2 \ + --hash=sha256:0098300eebb1c837271d3d1a2cd2911e7c11b396eac9661655ee524a7f10587b \ + --hash=sha256:042473b6280246b1dbf530559246f6842b56119c2926d1e52b631bdc46075f2a \ + --hash=sha256:05b7133a6e6aeb8df37d6f413f7705a37ab4031597f64ab56384c94d98fa0e90 \ + --hash=sha256:0680b1f1f11fda801397de52c36ce38ef1c1dc841a0927a94f226dea29c3ae3d \ + --hash=sha256:0d69b4c2f6bb3e130dba60d34c0845ba31b69babdd3f78f7c0c8fae5021a253e \ + --hash=sha256:1404c69d6a676245199767ba4f633cce5f4ad4181f9d0ccb0577e1f66cf4c46d \ + --hash=sha256:182245ff6b0039e82b6bb585ed55a64d7c81c560715d1bad0cbad6dfa07b4027 \ + --hash=sha256:1a388a77e629b9ec814c1b1e6b3b595fe521d2cdc625fcca26fbc2d44c816804 \ + --hash=sha256:1d90c3265ae107f91a4f279f4d6f6f1d4907ac76c6868b27dc7fb33688cfb347 \ + --hash=sha256:20aca1e2298c56ececfd8ed159ae4dde2df0781988c97ef77d5c16ff4bd5b400 \ + --hash=sha256:219da3f096d50a157f33645a1cf31c0ad1fe829a92181dd1311022f986e5fbe3 \ + --hash=sha256:22057013c8c1e272eb8d0eebc796701167d8377441ec894a8fed1af64a0bf399 \ + --hash=sha256:223ee893d77a310a0391dca6df00f70bbc2f36a71a895cecd9a0e762dc37b349 \ + --hash=sha256:224c421235f6102e8737032483f43c1a8cfb1d2f45740c44166219599358c2cd \ + --hash=sha256:2334ce8c673ee93a1d6a65bd90327588387ba073c17e61bf19b4fd97d688d63c 
\ + --hash=sha256:269322dcc3d8bdb69f054681edff86276b2ff972447863cf34c8b860f5188e2e \ + --hash=sha256:2728b01246a3bba6de144f9e3115b532ee44bd6cf39795194fb75491824a1413 \ + --hash=sha256:2b8ed04b3582771764538f7ee7001b02e1170223cf9b75dff0bc698fadb00cf3 \ + --hash=sha256:2e29d20810dfc3043ee13ac7d9e25105799817683348823f305ab3f349b9386e \ + --hash=sha256:36789b70d613fbac0a25bb07ab3d9dba4d2e38af609c020cf4d888d165ee0bf3 \ + --hash=sha256:390193c770399861d8df9670fb0d1874f330c79caaca4642332df7c682bf6b91 \ + --hash=sha256:3a6515ebc6e69d85502b4951d89131ca4e036078ea35533bb76327f8424531ce \ + --hash=sha256:3f9a801e7c8f1ef8718da265bba008fa121243dfe37c1cea17840b0944dfd72c \ + --hash=sha256:43f0f463cf89ace478de71a318b1b4f05ebc456a9b9300d027b4b57c1a2064fb \ + --hash=sha256:4456f2dca97c425231d7315737d45239b2b51a50dc2b6f0c2bb181fce6207664 \ + --hash=sha256:470b94480bb5ee929f5acba6995251ada5e059a5ef3e0dfc63cca287283ebfa6 \ + --hash=sha256:4774f3184d2ef3e14e8693194f661dea5a4d6ca4e3dc8e39786d33a94865cefd \ + --hash=sha256:4b4356d3538c3649337df4074e81b85f0616b79731fe22dd11b99499b2ebbdf3 \ + --hash=sha256:553ef617b6836fc7e4df130bb851e32fe357ce36336d897fd6646d6058d980af \ + --hash=sha256:6132dd3bd52838acddca05a72aafb6eab6536aa145e923bb50f45e78b7251043 \ + --hash=sha256:6a46e22a707e7ad4484ac9ee9f290f9d501df45954184e23fc29408dfad61350 \ + --hash=sha256:6e5c584d357c4e2baf0ff7baf44f4994be121e16a2c88918a5817331fc7599d7 \ + --hash=sha256:75250dbc5290e3f1a0f4618db35e51a165186f9034eff158f3d490b3fed9f8a0 \ + --hash=sha256:75f7e9488238e920ab6204399ded280dc4c307d034f3924cd7f90a38b1829563 \ + --hash=sha256:78363590ef93d5d226ba21a90a03ea89a20738ee5b7da83d771d283fd8a56761 \ + --hash=sha256:7ca4ae5a27ad7a4ee5170aebce1574b375de390bc01284f87b18d43a3984df72 \ + --hash=sha256:800d60565aec896f25bc3cfa56d2277d52d5182af08162f7954f938c06dc4ee3 \ + --hash=sha256:82d5d4d78e4448683cb467897fe24e2b74bb7b973a541ea1dcfec1d3cbce39fb \ + --hash=sha256:852e966fbd035a6468fc0a3496589b45e2208ec7ca95c26470a54daed82a0788 \ + --hash=sha256:868649da93e5a3d5eacc2b5b3b9235c98ccdbfd443832f31e075f54419e1b96b \ + --hash=sha256:886eec03591b7cf058467a70a87733b35f44707bd86cf64a615584fd72488b7c \ + --hash=sha256:8b172601454f2d7701121bbec3425dd71efcb787a027edf49724c9cefc14c038 \ + --hash=sha256:95b9d5e72481d3780ba3442eac863eae92ae43a5f3adb5b4d0a1de89d42bb250 \ + --hash=sha256:98758d627ff397e752bc339272c14c98199c613f922d4a384ddc07526c86a2ec \ + --hash=sha256:997abc4df705d1295a42f95b4eec4950a37ad8ae46d913caeee117b6b198811c \ + --hash=sha256:9b5155ff768083cb1d62f3e143b49a8a3432e6789a3abee8acd005c3c7af1c74 \ + --hash=sha256:9e08e867b306f525802df7cd16c44ff5ebbe747ff0ca6cf3fde7f36c05a59a81 \ + --hash=sha256:9fdad8e35f278b2c3eb77cbdc5c0a49dada440657bf738d6905ce106dc1de439 \ + --hash=sha256:a1874c6dd4113308bd0eb568418e6114b252afe44319ead2b4081e9b9521fe75 \ + --hash=sha256:a8309f67285bdfe65c372ea3722b7a5642680f3dba538566340a9d36e920b5f0 \ + --hash=sha256:ae0a8a797a5e56c053610fa7be147993fe50960fa43609ff2a9552b0e07013e8 \ + --hash=sha256:b14d82cdb934e99dda6d9d60dc84a24379820176cc4a0d123f88df319ae9c150 \ + --hash=sha256:b1bd7e47b1558ea872bd16c8502c414f9e90dcf12f1395129d7bb42a09a95438 \ + --hash=sha256:b3ef08e20ec49e02d5c6717a91bb5af9b20f1805583cb0adfe9ba2c6b505b5ae \ + --hash=sha256:b89ed9eb7d616ef5714e5590e6cf7f23b02d0d539767d33561e3675d6f9e3857 \ + --hash=sha256:c4fcf5cd9c4b655ad666ca332b9a081112cd7a58a8b5a6ca7a3104bc950f2038 \ + --hash=sha256:c6fdc8627910eed0c01aed6a390a252fe3ea6d472ee70fdde56273f198938374 \ + 
--hash=sha256:c9bd70772c720142be1020eac55f8143a34ec9f82d75a8e7a07852023e46617f \ + --hash=sha256:ca7b0c1f1c983e064caa85f3792dd2fe3526b3505378874afa84baf662e12241 \ + --hash=sha256:cbca948f2d14b09d20268cda7b0367723d79063f26c4ffc523af9042cad95592 \ + --hash=sha256:cc1cfd88a64e012b74e94cd00bbe0f9c6df57049c97f02bb07d39e9c852e19a4 \ + --hash=sha256:ccdd111c03bfd3666bd2472b674c6899550e09e9f298954cfc896ab92b5b0e6d \ + --hash=sha256:cfeecd1ac6cc1fb2692c3d5110781c965aabd4ec5d32799773ca7b1456ac636b \ + --hash=sha256:d4d938ec0adf5167cb335acb25a4ee69a8107e4984f8fbd2e897021d9e4ca21b \ + --hash=sha256:d7d904828195733c183d20a54230c0df0eb46ec746ea1a666730787353e87182 \ + --hash=sha256:d91cb5ea8b11607cc757675051f61b3d93f15eca3cefb3e6c704a5d6e8440f4e \ + --hash=sha256:d9319e499827271b09b4e411905b24a426b8fb69464dfa1696258f53a3334641 \ + --hash=sha256:e0e8b1be28239fc64a88a8189d1df7fad8be8c1ae47fcc33e43d4be15f99cc70 \ + --hash=sha256:e18609ceaa6eed63753037fc06ebb16041d17d28199ae5aba0052c51449650a9 \ + --hash=sha256:e1b395e58b10b73b07b7cf740d728dd4ff9365ac46c18751bf8b3d8cca8f625a \ + --hash=sha256:e23ec367a948b6d812301afc1b13f8094ab7b2c280af66ef450efc357d2ae543 \ + --hash=sha256:e25add29b8f3b233ae90ccef2d902d0ae0432eb0d45370fe315d1a5cf231004b \ + --hash=sha256:e6dac87ddb34aaec85f873d737e9d06a3555a1cc1a8e0c44b7f8d5daeb89d86f \ + --hash=sha256:ef26c9e94a8c04a1b2924149a9cb081836913818e55681722d7f29af88fe7b38 \ + --hash=sha256:eff2de745698eb46eeb51193a9f41d67d834d50e424aef27df2fcdee1b153845 \ + --hash=sha256:f0a21cbaa69900cbe1a2e7cad2aa74ac3cf21b10c3efb0fa0b80305274c0e8a2 \ + --hash=sha256:f459a5ce8434614dfd39bbebf1041952ae01da6bed9855008cb33b875cb024c0 \ + --hash=sha256:f93a8a2e3938ff656a7c1bc57193b1319960ac015b6e87d76c76bf14fe0244b4 \ + --hash=sha256:fb2bd7be70c0fe4dfd32c951bc813d9fe6ebcbfdd15a07527796c8204bd36242 + # via pydantic +pyyaml==6.0.1 \ + --hash=sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5 \ + --hash=sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc \ + --hash=sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df \ + --hash=sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741 \ + --hash=sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206 \ + --hash=sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27 \ + --hash=sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595 \ + --hash=sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62 \ + --hash=sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98 \ + --hash=sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696 \ + --hash=sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290 \ + --hash=sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9 \ + --hash=sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d \ + --hash=sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6 \ + --hash=sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867 \ + --hash=sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47 \ + --hash=sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486 \ + --hash=sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6 \ + --hash=sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3 \ + --hash=sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007 \ + 
--hash=sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938 \ + --hash=sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0 \ + --hash=sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c \ + --hash=sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735 \ + --hash=sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d \ + --hash=sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28 \ + --hash=sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4 \ + --hash=sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba \ + --hash=sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8 \ + --hash=sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef \ + --hash=sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5 \ + --hash=sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd \ + --hash=sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3 \ + --hash=sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0 \ + --hash=sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515 \ + --hash=sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c \ + --hash=sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c \ + --hash=sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924 \ + --hash=sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34 \ + --hash=sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43 \ + --hash=sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859 \ + --hash=sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673 \ + --hash=sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54 \ + --hash=sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a \ + --hash=sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b \ + --hash=sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab \ + --hash=sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa \ + --hash=sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c \ + --hash=sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585 \ + --hash=sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d \ + --hash=sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f + # via + # accelerate + # huggingface-hub + # peft + # transformers +regex==2024.5.15 \ + --hash=sha256:0721931ad5fe0dda45d07f9820b90b2148ccdd8e45bb9e9b42a146cb4f695649 \ + --hash=sha256:10002e86e6068d9e1c91eae8295ef690f02f913c57db120b58fdd35a6bb1af35 \ + --hash=sha256:10e4ce0dca9ae7a66e6089bb29355d4432caed736acae36fef0fdd7879f0b0cb \ + --hash=sha256:119af6e56dce35e8dfb5222573b50c89e5508d94d55713c75126b753f834de68 \ + --hash=sha256:1337b7dbef9b2f71121cdbf1e97e40de33ff114801263b275aafd75303bd62b5 \ + --hash=sha256:13cdaf31bed30a1e1c2453ef6015aa0983e1366fad2667657dbcac7b02f67133 \ + --hash=sha256:1595f2d10dff3d805e054ebdc41c124753631b6a471b976963c7b28543cf13b0 \ + --hash=sha256:16093f563098448ff6b1fa68170e4acbef94e6b6a4e25e10eae8598bb1694b5d \ + --hash=sha256:1878b8301ed011704aea4c806a3cadbd76f84dece1ec09cc9e4dc934cfa5d4da \ + --hash=sha256:19068a6a79cf99a19ccefa44610491e9ca02c2be3305c7760d3831d38a467a6f \ + 
--hash=sha256:19dfb1c504781a136a80ecd1fff9f16dddf5bb43cec6871778c8a907a085bb3d \ + --hash=sha256:1b5269484f6126eee5e687785e83c6b60aad7663dafe842b34691157e5083e53 \ + --hash=sha256:1c1c174d6ec38d6c8a7504087358ce9213d4332f6293a94fbf5249992ba54efa \ + --hash=sha256:2431b9e263af1953c55abbd3e2efca67ca80a3de8a0437cb58e2421f8184717a \ + --hash=sha256:287eb7f54fc81546346207c533ad3c2c51a8d61075127d7f6d79aaf96cdee890 \ + --hash=sha256:2b4c884767504c0e2401babe8b5b7aea9148680d2e157fa28f01529d1f7fcf67 \ + --hash=sha256:35cb514e137cb3488bce23352af3e12fb0dbedd1ee6e60da053c69fb1b29cc6c \ + --hash=sha256:391d7f7f1e409d192dba8bcd42d3e4cf9e598f3979cdaed6ab11288da88cb9f2 \ + --hash=sha256:3ad070b823ca5890cab606c940522d05d3d22395d432f4aaaf9d5b1653e47ced \ + --hash=sha256:3cd7874d57f13bf70078f1ff02b8b0aa48d5b9ed25fc48547516c6aba36f5741 \ + --hash=sha256:3e507ff1e74373c4d3038195fdd2af30d297b4f0950eeda6f515ae3d84a1770f \ + --hash=sha256:455705d34b4154a80ead722f4f185b04c4237e8e8e33f265cd0798d0e44825fa \ + --hash=sha256:4a605586358893b483976cffc1723fb0f83e526e8f14c6e6614e75919d9862cf \ + --hash=sha256:4babf07ad476aaf7830d77000874d7611704a7fcf68c9c2ad151f5d94ae4bfc4 \ + --hash=sha256:4eee78a04e6c67e8391edd4dad3279828dd66ac4b79570ec998e2155d2e59fd5 \ + --hash=sha256:5397de3219a8b08ae9540c48f602996aa6b0b65d5a61683e233af8605c42b0f2 \ + --hash=sha256:5b5467acbfc153847d5adb21e21e29847bcb5870e65c94c9206d20eb4e99a384 \ + --hash=sha256:5eaa7ddaf517aa095fa8da0b5015c44d03da83f5bd49c87961e3c997daed0de7 \ + --hash=sha256:632b01153e5248c134007209b5c6348a544ce96c46005d8456de1d552455b014 \ + --hash=sha256:64c65783e96e563103d641760664125e91bd85d8e49566ee560ded4da0d3e704 \ + --hash=sha256:64f18a9a3513a99c4bef0e3efd4c4a5b11228b48aa80743be822b71e132ae4f5 \ + --hash=sha256:673b5a6da4557b975c6c90198588181029c60793835ce02f497ea817ff647cb2 \ + --hash=sha256:68811ab14087b2f6e0fc0c2bae9ad689ea3584cad6917fc57be6a48bbd012c49 \ + --hash=sha256:6e8d717bca3a6e2064fc3a08df5cbe366369f4b052dcd21b7416e6d71620dca1 \ + --hash=sha256:71a455a3c584a88f654b64feccc1e25876066c4f5ef26cd6dd711308aa538694 \ + --hash=sha256:72d7a99cd6b8f958e85fc6ca5b37c4303294954eac1376535b03c2a43eb72629 \ + --hash=sha256:7b59138b219ffa8979013be7bc85bb60c6f7b7575df3d56dc1e403a438c7a3f6 \ + --hash=sha256:7dbe2467273b875ea2de38ded4eba86cbcbc9a1a6d0aa11dcf7bd2e67859c435 \ + --hash=sha256:833616ddc75ad595dee848ad984d067f2f31be645d603e4d158bba656bbf516c \ + --hash=sha256:87e2a9c29e672fc65523fb47a90d429b70ef72b901b4e4b1bd42387caf0d6835 \ + --hash=sha256:8fe45aa3f4aa57faabbc9cb46a93363edd6197cbc43523daea044e9ff2fea83e \ + --hash=sha256:9e717956dcfd656f5055cc70996ee2cc82ac5149517fc8e1b60261b907740201 \ + --hash=sha256:9efa1a32ad3a3ea112224897cdaeb6aa00381627f567179c0314f7b65d354c62 \ + --hash=sha256:9ff11639a8d98969c863d4617595eb5425fd12f7c5ef6621a4b74b71ed8726d5 \ + --hash=sha256:a094801d379ab20c2135529948cb84d417a2169b9bdceda2a36f5f10977ebc16 \ + --hash=sha256:a0981022dccabca811e8171f913de05720590c915b033b7e601f35ce4ea7019f \ + --hash=sha256:a0bd000c6e266927cb7a1bc39d55be95c4b4f65c5be53e659537537e019232b1 \ + --hash=sha256:a32b96f15c8ab2e7d27655969a23895eb799de3665fa94349f3b2fbfd547236f \ + --hash=sha256:a81e3cfbae20378d75185171587cbf756015ccb14840702944f014e0d93ea09f \ + --hash=sha256:ac394ff680fc46b97487941f5e6ae49a9f30ea41c6c6804832063f14b2a5a145 \ + --hash=sha256:ada150c5adfa8fbcbf321c30c751dc67d2f12f15bd183ffe4ec7cde351d945b3 \ + --hash=sha256:b2b6f1b3bb6f640c1a92be3bbfbcb18657b125b99ecf141fb3310b5282c7d4ed \ + 
--hash=sha256:b802512f3e1f480f41ab5f2cfc0e2f761f08a1f41092d6718868082fc0d27143 \ + --hash=sha256:ba68168daedb2c0bab7fd7e00ced5ba90aebf91024dea3c88ad5063c2a562cca \ + --hash=sha256:bfc4f82cabe54f1e7f206fd3d30fda143f84a63fe7d64a81558d6e5f2e5aaba9 \ + --hash=sha256:c0c18345010870e58238790a6779a1219b4d97bd2e77e1140e8ee5d14df071aa \ + --hash=sha256:c3bea0ba8b73b71b37ac833a7f3fd53825924165da6a924aec78c13032f20850 \ + --hash=sha256:c486b4106066d502495b3025a0a7251bf37ea9540433940a23419461ab9f2a80 \ + --hash=sha256:c49e15eac7c149f3670b3e27f1f28a2c1ddeccd3a2812cba953e01be2ab9b5fe \ + --hash=sha256:c6a2b494a76983df8e3d3feea9b9ffdd558b247e60b92f877f93a1ff43d26656 \ + --hash=sha256:cab12877a9bdafde5500206d1020a584355a97884dfd388af3699e9137bf7388 \ + --hash=sha256:cac27dcaa821ca271855a32188aa61d12decb6fe45ffe3e722401fe61e323cd1 \ + --hash=sha256:cdd09d47c0b2efee9378679f8510ee6955d329424c659ab3c5e3a6edea696294 \ + --hash=sha256:cf2430df4148b08fb4324b848672514b1385ae3807651f3567871f130a728cc3 \ + --hash=sha256:d0a3d8d6acf0c78a1fff0e210d224b821081330b8524e3e2bc5a68ef6ab5803d \ + --hash=sha256:d0c0c0003c10f54a591d220997dd27d953cd9ccc1a7294b40a4be5312be8797b \ + --hash=sha256:d1f059a4d795e646e1c37665b9d06062c62d0e8cc3c511fe01315973a6542e40 \ + --hash=sha256:d347a741ea871c2e278fde6c48f85136c96b8659b632fb57a7d1ce1872547600 \ + --hash=sha256:d3ee02d9e5f482cc8309134a91eeaacbdd2261ba111b0fef3748eeb4913e6a2c \ + --hash=sha256:d99ceffa25ac45d150e30bd9ed14ec6039f2aad0ffa6bb87a5936f5782fc1569 \ + --hash=sha256:e38a7d4e8f633a33b4c7350fbd8bad3b70bf81439ac67ac38916c4a86b465456 \ + --hash=sha256:e4682f5ba31f475d58884045c1a97a860a007d44938c4c0895f41d64481edbc9 \ + --hash=sha256:e5bb9425fe881d578aeca0b2b4b3d314ec88738706f66f219c194d67179337cb \ + --hash=sha256:e64198f6b856d48192bf921421fdd8ad8eb35e179086e99e99f711957ffedd6e \ + --hash=sha256:e6662686aeb633ad65be2a42b4cb00178b3fbf7b91878f9446075c404ada552f \ + --hash=sha256:ec54d5afa89c19c6dd8541a133be51ee1017a38b412b1321ccb8d6ddbeb4cf7d \ + --hash=sha256:f5b1dff3ad008dccf18e652283f5e5339d70bf8ba7c98bf848ac33db10f7bc7a \ + --hash=sha256:f8ec0c2fea1e886a19c3bee0cd19d862b3aa75dcdfb42ebe8ed30708df64687a \ + --hash=sha256:f9ebd0a36102fcad2f03696e8af4ae682793a5d30b46c647eaf280d6cfb32796 + # via transformers +requests==2.32.2 \ + --hash=sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289 \ + --hash=sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c + # via + # huggingface-hub + # transformers +safetensors==0.4.3 \ + --hash=sha256:018b691383026a2436a22b648873ed11444a364324e7088b99cd2503dd828400 \ + --hash=sha256:01e4b22e3284cd866edeabe4f4d896229495da457229408d2e1e4810c5187121 \ + --hash=sha256:01feb3089e5932d7e662eda77c3ecc389f97c0883c4a12b5cfdc32b589a811c3 \ + --hash=sha256:02318f01e332cc23ffb4f6716e05a492c5f18b1d13e343c49265149396284a44 \ + --hash=sha256:02ef3a24face643456020536591fbd3c717c5abaa2737ec428ccbbc86dffa7a4 \ + --hash=sha256:03a4447c784917c9bf01d8f2ac5080bc15c41692202cd5f406afba16629e84d6 \ + --hash=sha256:084fc436e317f83f7071fc6a62ca1c513b2103db325cd09952914b50f51cf78f \ + --hash=sha256:0bf4f9d6323d9f86eef5567eabd88f070691cf031d4c0df27a40d3b4aaee755b \ + --hash=sha256:0d52c958dc210265157573f81d34adf54e255bc2b59ded6218500c9b15a750eb \ + --hash=sha256:0d5ffc6a80f715c30af253e0e288ad1cd97a3d0086c9c87995e5093ebc075e50 \ + --hash=sha256:0d9cd8e1560dfc514b6d7859247dc6a86ad2f83151a62c577428d5102d872721 \ + --hash=sha256:0dd37306546b58d3043eb044c8103a02792cc024b51d1dd16bd3dd1f334cb3ed \ + 
--hash=sha256:1139eb436fd201c133d03c81209d39ac57e129f5e74e34bb9ab60f8d9b726270 \ + --hash=sha256:19bbdf95de2cf64f25cd614c5236c8b06eb2cfa47cbf64311f4b5d80224623a3 \ + --hash=sha256:1ab6527a20586d94291c96e00a668fa03f86189b8a9defa2cdd34a1a01acc7d5 \ + --hash=sha256:1b89381517891a7bb7d1405d828b2bf5d75528299f8231e9346b8eba092227f9 \ + --hash=sha256:1f598b713cc1a4eb31d3b3203557ac308acf21c8f41104cdd74bf640c6e538e3 \ + --hash=sha256:22d21760dc6ebae42e9c058d75aa9907d9f35e38f896e3c69ba0e7b213033856 \ + --hash=sha256:22f3b5d65e440cec0de8edaa672efa888030802e11c09b3d6203bff60ebff05a \ + --hash=sha256:2a0deb16a1d3ea90c244ceb42d2c6c276059616be21a19ac7101aa97da448faf \ + --hash=sha256:2a1f4430cc0c9d6afa01214a4b3919d0a029637df8e09675ceef1ca3f0dfa0df \ + --hash=sha256:2d603846a8585b9432a0fd415db1d4c57c0f860eb4aea21f92559ff9902bae4d \ + --hash=sha256:2f85fc50c4e07a21e95c24e07460fe6f7e2859d0ce88092838352b798ce711c2 \ + --hash=sha256:309b10dbcab63269ecbf0e2ca10ce59223bb756ca5d431ce9c9eeabd446569da \ + --hash=sha256:3615a96dd2dcc30eb66d82bc76cda2565f4f7bfa89fcb0e31ba3cea8a1a9ecbb \ + --hash=sha256:38e2a8666178224a51cca61d3cb4c88704f696eac8f72a49a598a93bbd8a4af9 \ + --hash=sha256:393e6e391467d1b2b829c77e47d726f3b9b93630e6a045b1d1fca67dc78bf632 \ + --hash=sha256:3f9cdca09052f585e62328c1c2923c70f46814715c795be65f0b93f57ec98a02 \ + --hash=sha256:41a727a7f5e6ad9f1db6951adee21bbdadc632363d79dc434876369a17de6ad6 \ + --hash=sha256:420a98f593ff9930f5822560d14c395ccbc57342ddff3b463bc0b3d6b1951550 \ + --hash=sha256:446e9fe52c051aeab12aac63d1017e0f68a02a92a027b901c4f8e931b24e5397 \ + --hash=sha256:455d538aa1aae4a8b279344a08136d3f16334247907b18a5c3c7fa88ef0d3c46 \ + --hash=sha256:4f9bac020faba7f5dc481e881b14b6425265feabb5bfc552551d21189c0eddc3 \ + --hash=sha256:53c4879b9c6bd7cd25d114ee0ef95420e2812e676314300624594940a8d6a91f \ + --hash=sha256:5757e4688f20df083e233b47de43845d1adb7e17b6cf7da5f8444416fc53828d \ + --hash=sha256:585c9ae13a205807b63bef8a37994f30c917ff800ab8a1ca9c9b5d73024f97ee \ + --hash=sha256:5d07cbca5b99babb692d76d8151bec46f461f8ad8daafbfd96b2fca40cadae65 \ + --hash=sha256:5fc6775529fb9f0ce2266edd3e5d3f10aab068e49f765e11f6f2a63b5367021d \ + --hash=sha256:622afd28968ef3e9786562d352659a37de4481a4070f4ebac883f98c5836563e \ + --hash=sha256:6f9568f380f513a60139971169c4a358b8731509cc19112369902eddb33faa4d \ + --hash=sha256:70a5319ef409e7f88686a46607cbc3c428271069d8b770076feaf913664a07ac \ + --hash=sha256:74707624b81f1b7f2b93f5619d4a9f00934d5948005a03f2c1845ffbfff42212 \ + --hash=sha256:7c4fa560ebd4522adddb71dcd25d09bf211b5634003f015a4b815b7647d62ebe \ + --hash=sha256:7de32d0d34b6623bb56ca278f90db081f85fb9c5d327e3c18fd23ac64f465768 \ + --hash=sha256:840b7ac0eff5633e1d053cc9db12fdf56b566e9403b4950b2dc85393d9b88d67 \ + --hash=sha256:840caf38d86aa7014fe37ade5d0d84e23dcfbc798b8078015831996ecbc206a3 \ + --hash=sha256:8651c7299cbd8b4161a36cd6a322fa07d39cd23535b144d02f1c1972d0c62f3c \ + --hash=sha256:868ad1b6fc41209ab6bd12f63923e8baeb1a086814cb2e81a65ed3d497e0cf8f \ + --hash=sha256:88887f69f7a00cf02b954cdc3034ffb383b2303bc0ab481d4716e2da51ddc10e \ + --hash=sha256:89f9f17b0dacb913ed87d57afbc8aad85ea42c1085bd5de2f20d83d13e9fc4b2 \ + --hash=sha256:8c496c5401c1b9c46d41a7688e8ff5b0310a3b9bae31ce0f0ae870e1ea2b8caf \ + --hash=sha256:8cf18888606dad030455d18f6c381720e57fc6a4170ee1966adb7ebc98d4d6a3 \ + --hash=sha256:8d22c1a10dff3f64d0d68abb8298a3fd88ccff79f408a3e15b3e7f637ef5c980 \ + --hash=sha256:90964917f5b0fa0fa07e9a051fbef100250c04d150b7026ccbf87a34a54012e0 \ + 
--hash=sha256:9bfb92f82574d9e58401d79c70c716985dc049b635fef6eecbb024c79b2c46ad \ + --hash=sha256:9c6ad011c1b4e3acff058d6b090f1da8e55a332fbf84695cf3100c649cc452d1 \ + --hash=sha256:a11c374eb63a9c16c5ed146457241182f310902bd2a9c18255781bb832b6748b \ + --hash=sha256:a7cef55929dcbef24af3eb40bedec35d82c3c2fa46338bb13ecf3c5720af8a61 \ + --hash=sha256:a844cdb5d7cbc22f5f16c7e2a0271170750763c4db08381b7f696dbd2c78a361 \ + --hash=sha256:ae7613a119a71a497d012ccc83775c308b9c1dab454806291427f84397d852fd \ + --hash=sha256:b1648568667f820b8c48317c7006221dc40aced1869908c187f493838a1362bc \ + --hash=sha256:b1e31be7945f66be23f4ec1682bb47faa3df34cb89fc68527de6554d3c4258a4 \ + --hash=sha256:b277482120df46e27a58082df06a15aebda4481e30a1c21eefd0921ae7e03f65 \ + --hash=sha256:b7ffba80aa49bd09195145a7fd233a7781173b422eeb995096f2b30591639517 \ + --hash=sha256:b852e47eb08475c2c1bd8131207b405793bfc20d6f45aff893d3baaad449ed14 \ + --hash=sha256:bb4f8c5d0358a31e9a08daeebb68f5e161cdd4018855426d3f0c23bb51087055 \ + --hash=sha256:bbae3b4b9d997971431c346edbfe6e41e98424a097860ee872721e176040a893 \ + --hash=sha256:befdf0167ad626f22f6aac6163477fcefa342224a22f11fdd05abb3995c1783c \ + --hash=sha256:c0acbe31340ab150423347e5b9cc595867d814244ac14218932a5cf1dd38eb39 \ + --hash=sha256:c41e1893d1206aa7054029681778d9a58b3529d4c807002c156d58426c225173 \ + --hash=sha256:c59d51f182c729f47e841510b70b967b0752039f79f1de23bcdd86462a9b09ee \ + --hash=sha256:cd6fff9e56df398abc5866b19a32124815b656613c1c5ec0f9350906fd798aac \ + --hash=sha256:cdd0a3b5da66e7f377474599814dbf5cbf135ff059cc73694de129b58a5e8a2c \ + --hash=sha256:cf476bca34e1340ee3294ef13e2c625833f83d096cfdf69a5342475602004f95 \ + --hash=sha256:d0dd4a1db09db2dba0f94d15addc7e7cd3a7b0d393aa4c7518c39ae7374623c3 \ + --hash=sha256:d1456f814655b224d4bf6e7915c51ce74e389b413be791203092b7ff78c936dd \ + --hash=sha256:d14d30c25897b2bf19b6fb5ff7e26cc40006ad53fd4a88244fdf26517d852dd7 \ + --hash=sha256:d244bcafeb1bc06d47cfee71727e775bca88a8efda77a13e7306aae3813fa7e4 \ + --hash=sha256:d8815b5e1dac85fc534a97fd339e12404db557878c090f90442247e87c8aeaea \ + --hash=sha256:d88b33980222085dd6001ae2cad87c6068e0991d4f5ccf44975d216db3b57376 \ + --hash=sha256:d8c5093206ef4b198600ae484230402af6713dab1bd5b8e231905d754022bec7 \ + --hash=sha256:d9c289f140a9ae4853fc2236a2ffc9a9f2d5eae0cb673167e0f1b8c18c0961ac \ + --hash=sha256:dcf5705cab159ce0130cd56057f5f3425023c407e170bca60b4868048bae64fd \ + --hash=sha256:e011cc162503c19f4b1fd63dfcddf73739c7a243a17dac09b78e57a00983ab35 \ + --hash=sha256:e066e8861eef6387b7c772344d1fe1f9a72800e04ee9a54239d460c400c72aab \ + --hash=sha256:e0b2104df1579d6ba9052c0ae0e3137c9698b2d85b0645507e6fd1813b70931a \ + --hash=sha256:e375d975159ac534c7161269de24ddcd490df2157b55c1a6eeace6cbb56903f0 \ + --hash=sha256:e4119532cd10dba04b423e0f86aecb96cfa5a602238c0aa012f70c3a40c44b50 \ + --hash=sha256:e7dbbde64b6c534548696808a0e01276d28ea5773bc9a2dfb97a88cd3dffe3df \ + --hash=sha256:e9afd5358719f1b2cf425fad638fc3c887997d6782da317096877e5b15b2ce93 \ + --hash=sha256:ec4b52ce9a396260eb9731eb6aea41a7320de22ed73a1042c2230af0212758ce \ + --hash=sha256:edb5698a7bc282089f64c96c477846950358a46ede85a1c040e0230344fdde10 \ + --hash=sha256:ee463219d9ec6c2be1d331ab13a8e0cd50d2f32240a81d498266d77d07b7e71e \ + --hash=sha256:efcc860be094b8d19ac61b452ec635c7acb9afa77beb218b1d7784c6d41fe8ad \ + --hash=sha256:f5e6883af9a68c0028f70a4c19d5a6ab6238a379be36ad300a22318316c00cb0 \ + --hash=sha256:f9650713b2cfa9537a2baf7dd9fee458b24a0aaaa6cafcea8bdd5fb2b8efdc34 \ + 
--hash=sha256:faefeb3b81bdfb4e5a55b9bbdf3d8d8753f65506e1d67d03f5c851a6c87150e9 \ + --hash=sha256:fb9c65bd82f9ef3ce4970dc19ee86be5f6f93d032159acf35e663c6bea02b237 \ + --hash=sha256:fe746d03ed8d193674a26105e4f0fe6c726f5bb602ffc695b409eaf02f04763d \ + --hash=sha256:fef5d70683643618244a4f5221053567ca3e77c2531e42ad48ae05fae909f542 + # via + # accelerate + # mergekit (pyproject.toml) + # peft + # transformers +sentencepiece==0.2.0 \ + --hash=sha256:0461324897735512a32d222e3d886e24ad6a499761952b6bda2a9ee6e4313ea5 \ + --hash=sha256:0993dbc665f4113017892f1b87c3904a44d0640eda510abcacdfb07f74286d36 \ + --hash=sha256:0a91aaa3c769b52440df56fafda683b3aa48e3f2169cf7ee5b8c8454a7f3ae9b \ + --hash=sha256:0f67eae0dbe6f2d7d6ba50a354623d787c99965f068b81e145d53240198021b0 \ + --hash=sha256:1380ce6540a368de2ef6d7e6ba14ba8f3258df650d39ba7d833b79ee68a52040 \ + --hash=sha256:17982700c4f6dbb55fa3594f3d7e5dd1c8659a274af3738e33c987d2a27c9d5c \ + --hash=sha256:188779e1298a1c8b8253c7d3ad729cb0a9891e5cef5e5d07ce4592c54869e227 \ + --hash=sha256:1e0f9c4d0a6b0af59b613175f019916e28ade076e21242fd5be24340d8a2f64a \ + --hash=sha256:20813a68d4c221b1849c62c30e1281ea81687894d894b8d4a0f4677d9311e0f5 \ + --hash=sha256:22e37bac44dd6603388cb598c64ff7a76e41ca774646f21c23aadfbf5a2228ab \ + --hash=sha256:27f90c55a65013cbb8f4d7aab0599bf925cde4adc67ae43a0d323677b5a1c6cb \ + --hash=sha256:298f21cc1366eb60311aedba3169d30f885c363ddbf44214b0a587d2908141ad \ + --hash=sha256:2a3149e3066c2a75e0d68a43eb632d7ae728c7925b517f4c05c40f6f7280ce08 \ + --hash=sha256:2fde4b08cfe237be4484c6c7c2e2c75fb862cfeab6bd5449ce4caeafd97b767a \ + --hash=sha256:3212121805afc58d8b00ab4e7dd1f8f76c203ddb9dc94aa4079618a31cf5da0f \ + --hash=sha256:38aed822fb76435fa1f12185f10465a94ab9e51d5e8a9159e9a540ce926f0ffd \ + --hash=sha256:3f1ec95aa1e5dab11f37ac7eff190493fd87770f7a8b81ebc9dd768d1a3c8704 \ + --hash=sha256:4547683f330289ec4f093027bfeb87f9ef023b2eb6f879fdc4a8187c7e0ffb90 \ + --hash=sha256:4c378492056202d1c48a4979650981635fd97875a00eabb1f00c6a236b013b5e \ + --hash=sha256:536b934e244829e3fe6c4f198652cd82da48adb9aa145c9f00889542726dee3d \ + --hash=sha256:632f3594d3e7ac8b367bca204cb3fd05a01d5b21455acd097ea4c0e30e2f63d7 \ + --hash=sha256:6cf333625234f247ab357b0bd9836638405ea9082e1543d5b8408f014979dcbf \ + --hash=sha256:7140d9e5a74a0908493bb4a13f1f16a401297bd755ada4c707e842fbf6f0f5bf \ + --hash=sha256:787e480ca4c1d08c9985a7eb1eae4345c107729c99e9b5a9a00f2575fc7d4b4b \ + --hash=sha256:7a673a72aab81fef5ebe755c6e0cc60087d1f3a4700835d40537183c1703a45f \ + --hash=sha256:7b06b70af54daa4b4904cbb90b4eb6d35c9f3252fdc86c9c32d5afd4d30118d8 \ + --hash=sha256:7c867012c0e8bcd5bdad0f791609101cb5c66acb303ab3270218d6debc68a65e \ + --hash=sha256:7cd6175f7eaec7142d2bf6f6597ce7db4c9ac89acf93fcdb17410c3a8b781eeb \ + --hash=sha256:7fd6071249c74f779c5b27183295b9202f8dedb68034e716784364443879eaa6 \ + --hash=sha256:859ba1acde782609a0910a26a60e16c191a82bf39b5621107552c0cd79fad00f \ + --hash=sha256:89f65f69636b7e9c015b79dff9c9985a9bc7d19ded6f79ef9f1ec920fdd73ecf \ + --hash=sha256:926ef920ae2e8182db31d3f5d081ada57804e3e1d3a8c4ef8b117f9d9fb5a945 \ + --hash=sha256:98501e075f35dd1a1d5a20f65be26839fcb1938752ec61539af008a5aa6f510b \ + --hash=sha256:a1151d6a6dd4b43e552394aed0edfe9292820272f0194bd56c7c1660a0c06c3d \ + --hash=sha256:a52c19171daaf2e697dc6cbe67684e0fa341b1248966f6aebb541de654d15843 \ + --hash=sha256:b293734059ef656dcd65be62ff771507bea8fed0a711b6733976e1ed3add4553 \ + --hash=sha256:b99a308a2e5e569031ab164b74e6fab0b6f37dfb493c32f7816225f4d411a6dd \ + 
--hash=sha256:bcbbef6cc277f8f18f36959e305f10b1c620442d75addc79c21d7073ae581b50 \ + --hash=sha256:bed9cf85b296fa2b76fc2547b9cbb691a523864cebaee86304c43a7b4cb1b452 \ + --hash=sha256:c581258cf346b327c62c4f1cebd32691826306f6a41d8c4bec43b010dee08e75 \ + --hash=sha256:cdb701eec783d3ec86b7cd4c763adad8eaf6b46db37ee1c36e5e6c44b3fe1b5f \ + --hash=sha256:d0cb51f53b6aae3c36bafe41e86167c71af8370a039f542c43b0cce5ef24a68c \ + --hash=sha256:d1e5ca43013e8935f25457a4fca47e315780172c3e821b4b13a890668911c792 \ + --hash=sha256:d490142b0521ef22bc1085f061d922a2a6666175bb6b42e588ff95c0db6819b2 \ + --hash=sha256:d7b67e724bead13f18db6e1d10b6bbdc454af574d70efbb36f27d90387be1ca3 \ + --hash=sha256:d8cf876516548b5a1d6ac4745d8b554f5c07891d55da557925e5c13ff0b4e6ad \ + --hash=sha256:e3d1d2cc4882e8d6a1adf9d5927d7716f80617fc693385661caff21888972269 \ + --hash=sha256:e58b47f933aca74c6a60a79dcb21d5b9e47416256c795c2d58d55cec27f9551d \ + --hash=sha256:ea5f536e32ea8ec96086ee00d7a4a131ce583a1b18d130711707c10e69601cb2 \ + --hash=sha256:f295105c6bdbb05bd5e1b0cafbd78ff95036f5d3641e7949455a3f4e5e7c3109 \ + --hash=sha256:f4d158189eb2ecffea3a51edf6d25e110b3678ec47f1a40f2d541eafbd8f6250 \ + --hash=sha256:fb89f811e5efd18bab141afc3fea3de141c3f69f3fe9e898f710ae7fe3aab251 \ + --hash=sha256:ff88712338b01031910e8e61e7239aff3ce8869ee31a47df63cb38aadd591bea + # via mergekit (pyproject.toml) +sympy==1.12 \ + --hash=sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5 \ + --hash=sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8 + # via torch +tokenizers==0.19.1 \ + --hash=sha256:01d62812454c188306755c94755465505836fd616f75067abcae529c35edeb57 \ + --hash=sha256:02e81bf089ebf0e7f4df34fa0207519f07e66d8491d963618252f2e0729e0b46 \ + --hash=sha256:04ce49e82d100594715ac1b2ce87d1a36e61891a91de774755f743babcd0dd52 \ + --hash=sha256:07f9295349bbbcedae8cefdbcfa7f686aa420be8aca5d4f7d1ae6016c128c0c5 \ + --hash=sha256:08a44864e42fa6d7d76d7be4bec62c9982f6f6248b4aa42f7302aa01e0abfd26 \ + --hash=sha256:0b5ca92bfa717759c052e345770792d02d1f43b06f9e790ca0a1db62838816f3 \ + --hash=sha256:0b9394bd204842a2a1fd37fe29935353742be4a3460b6ccbaefa93f58a8df43d \ + --hash=sha256:0bcce02bf1ad9882345b34d5bd25ed4949a480cf0e656bbd468f4d8986f7a3f1 \ + --hash=sha256:0e64bfde9a723274e9a71630c3e9494ed7b4c0f76a1faacf7fe294cd26f7ae7c \ + --hash=sha256:10a707cc6c4b6b183ec5dbfc5c34f3064e18cf62b4a938cb41699e33a99e03c1 \ + --hash=sha256:16baac68651701364b0289979ecec728546133e8e8fe38f66fe48ad07996b88b \ + --hash=sha256:1de5bc8652252d9357a666e609cb1453d4f8e160eb1fb2830ee369dd658e8975 \ + --hash=sha256:1f0360cbea28ea99944ac089c00de7b2e3e1c58f479fb8613b6d8d511ce98267 \ + --hash=sha256:2e8a3dd055e515df7054378dc9d6fa8c8c34e1f32777fb9a01fea81496b3f9d3 \ + --hash=sha256:3174c76efd9d08f836bfccaca7cfec3f4d1c0a4cf3acbc7236ad577cc423c840 \ + --hash=sha256:35583cd46d16f07c054efd18b5d46af4a2f070a2dd0a47914e66f3ff5efb2b1e \ + --hash=sha256:39c1ec76ea1027438fafe16ecb0fb84795e62e9d643444c1090179e63808c69d \ + --hash=sha256:3b11853f17b54c2fe47742c56d8a33bf49ce31caf531e87ac0d7d13d327c9334 \ + --hash=sha256:427c4f0f3df9109314d4f75b8d1f65d9477033e67ffaec4bca53293d3aca286d \ + --hash=sha256:43350270bfc16b06ad3f6f07eab21f089adb835544417afda0f83256a8bf8b75 \ + --hash=sha256:453e4422efdfc9c6b6bf2eae00d5e323f263fff62b29a8c9cd526c5003f3f642 \ + --hash=sha256:4692ab92f91b87769d950ca14dbb61f8a9ef36a62f94bad6c82cc84a51f76f6a \ + --hash=sha256:4ad23d37d68cf00d54af184586d79b84075ada495e7c5c0f601f051b162112dc \ + 
--hash=sha256:4f3fefdc0446b1a1e6d81cd4c07088ac015665d2e812f6dbba4a06267d1a2c95 \ + --hash=sha256:56ae39d4036b753994476a1b935584071093b55c7a72e3b8288e68c313ca26e7 \ + --hash=sha256:5c88d1481f1882c2e53e6bb06491e474e420d9ac7bdff172610c4f9ad3898059 \ + --hash=sha256:61b7fe8886f2e104d4caf9218b157b106207e0f2a4905c9c7ac98890688aabeb \ + --hash=sha256:621d670e1b1c281a1c9698ed89451395d318802ff88d1fc1accff0867a06f153 \ + --hash=sha256:6258c2ef6f06259f70a682491c78561d492e885adeaf9f64f5389f78aa49a051 \ + --hash=sha256:6309271f57b397aa0aff0cbbe632ca9d70430839ca3178bf0f06f825924eca22 \ + --hash=sha256:638e43936cc8b2cbb9f9d8dde0fe5e7e30766a3318d2342999ae27f68fdc9bd6 \ + --hash=sha256:63c38f45d8f2a2ec0f3a20073cccb335b9f99f73b3c69483cd52ebc75369d8a1 \ + --hash=sha256:670b802d4d82bbbb832ddb0d41df7015b3e549714c0e77f9bed3e74d42400fbe \ + --hash=sha256:6852c5b2a853b8b0ddc5993cd4f33bfffdca4fcc5d52f89dd4b8eada99379285 \ + --hash=sha256:6b2da5c32ed869bebd990c9420df49813709e953674c0722ff471a116d97b22d \ + --hash=sha256:6c330c0eb815d212893c67a032e9dc1b38a803eccb32f3e8172c19cc69fbb439 \ + --hash=sha256:6f8a20266e695ec9d7a946a019c1d5ca4eddb6613d4f466888eee04f16eedb85 \ + --hash=sha256:706a37cc5332f85f26efbe2bdc9ef8a9b372b77e4645331a405073e4b3a8c1c6 \ + --hash=sha256:71e3ec71f0e78780851fef28c2a9babe20270404c921b756d7c532d280349214 \ + --hash=sha256:72791f9bb1ca78e3ae525d4782e85272c63faaef9940d92142aa3eb79f3407a3 \ + --hash=sha256:76951121890fea8330d3a0df9a954b3f2a37e3ec20e5b0530e9a0044ca2e11fe \ + --hash=sha256:78e769eb3b2c79687d9cb0f89ef77223e8e279b75c0a968e637ca7043a84463f \ + --hash=sha256:7c9d5b6c0e7a1e979bec10ff960fae925e947aab95619a6fdb4c1d8ff3708ce3 \ + --hash=sha256:7fb297edec6c6841ab2e4e8f357209519188e4a59b557ea4fafcf4691d1b4c98 \ + --hash=sha256:7ff898780a155ea053f5d934925f3902be2ed1f4d916461e1a93019cc7250837 \ + --hash=sha256:82c8b8063de6c0468f08e82c4e198763e7b97aabfe573fd4cf7b33930ca4df77 \ + --hash=sha256:85aa3ab4b03d5e99fdd31660872249df5e855334b6c333e0bc13032ff4469c4a \ + --hash=sha256:89183e55fb86e61d848ff83753f64cded119f5d6e1f553d14ffee3700d0a4a49 \ + --hash=sha256:8a6298bde623725ca31c9035a04bf2ef63208d266acd2bed8c2cb7d2b7d53ce6 \ + --hash=sha256:8b01afb7193d47439f091cd8f070a1ced347ad0f9144952a30a41836902fe09e \ + --hash=sha256:952078130b3d101e05ecfc7fc3640282d74ed26bcf691400f872563fca15ac97 \ + --hash=sha256:952b80dac1a6492170f8c2429bd11fcaa14377e097d12a1dbe0ef2fb2241e16c \ + --hash=sha256:9620b78e0b2d52ef07b0d428323fb34e8ea1219c5eac98c2596311f20f1f9266 \ + --hash=sha256:9ed240c56b4403e22b9584ee37d87b8bfa14865134e3e1c3fb4b2c42fafd3256 \ + --hash=sha256:a179856d1caee06577220ebcfa332af046d576fb73454b8f4d4b0ba8324423ea \ + --hash=sha256:a2b718f316b596f36e1dae097a7d5b91fc5b85e90bf08b01ff139bd8953b25af \ + --hash=sha256:ac11016d0a04aa6487b1513a3a36e7bee7eec0e5d30057c9c0408067345c48d2 \ + --hash=sha256:ad57d59341710b94a7d9dbea13f5c1e7d76fd8d9bcd944a7a6ab0b0da6e0cc66 \ + --hash=sha256:b07c538ba956843833fee1190cf769c60dc62e1cf934ed50d77d5502194d63b1 \ + --hash=sha256:b279ab506ec4445166ac476fb4d3cc383accde1ea152998509a94d82547c8e2a \ + --hash=sha256:b2edbc75744235eea94d595a8b70fe279dd42f3296f76d5a86dde1d46e35f574 \ + --hash=sha256:b342d2ce8fc8d00f376af068e3274e2e8649562e3bc6ae4a67784ded6b99428d \ + --hash=sha256:b4399b59d1af5645bcee2072a463318114c39b8547437a7c2d6a186a1b5a0e2d \ + --hash=sha256:b4c89aa46c269e4e70c4d4f9d6bc644fcc39bb409cb2a81227923404dd6f5227 \ + --hash=sha256:b70bfbe3a82d3e3fb2a5e9b22a39f8d1740c96c68b6ace0086b39074f08ab89a \ + 
--hash=sha256:b82931fa619dbad979c0ee8e54dd5278acc418209cc897e42fac041f5366d626 \ + --hash=sha256:bac0b0eb952412b0b196ca7a40e7dce4ed6f6926489313414010f2e6b9ec2adf \ + --hash=sha256:bb9dfe7dae85bc6119d705a76dc068c062b8b575abe3595e3c6276480e67e3f1 \ + --hash=sha256:bcd266ae85c3d39df2f7e7d0e07f6c41a55e9a3123bb11f854412952deacd828 \ + --hash=sha256:bea6f9947e9419c2fda21ae6c32871e3d398cba549b93f4a65a2d369662d9403 \ + --hash=sha256:c27b99889bd58b7e301468c0838c5ed75e60c66df0d4db80c08f43462f82e0d3 \ + --hash=sha256:c2a0d47a89b48d7daa241e004e71fb5a50533718897a4cd6235cb846d511a478 \ + --hash=sha256:c5c2ff13d157afe413bf7e25789879dd463e5a4abfb529a2d8f8473d8042e28f \ + --hash=sha256:c85cf76561fbd01e0d9ea2d1cbe711a65400092bc52b5242b16cfd22e51f0c58 \ + --hash=sha256:ca407133536f19bdec44b3da117ef0d12e43f6d4b56ac4c765f37eca501c7bda \ + --hash=sha256:cbf001afbbed111a79ca47d75941e9e5361297a87d186cbfc11ed45e30b5daba \ + --hash=sha256:ce05fde79d2bc2e46ac08aacbc142bead21614d937aac950be88dc79f9db9022 \ + --hash=sha256:d16ff18907f4909dca9b076b9c2d899114dd6abceeb074eca0c93e2353f943aa \ + --hash=sha256:d26194ef6c13302f446d39972aaa36a1dda6450bc8949f5eb4c27f51191375bd \ + --hash=sha256:d8c5d59d7b59885eab559d5bc082b2985555a54cda04dda4c65528d90ad252ad \ + --hash=sha256:d924204a3dbe50b75630bd16f821ebda6a5f729928df30f582fb5aade90c818a \ + --hash=sha256:dadc509cc8a9fe460bd274c0e16ac4184d0958117cf026e0ea8b32b438171594 \ + --hash=sha256:dd26e3afe8a7b61422df3176e06664503d3f5973b94f45d5c45987e1cb711876 \ + --hash=sha256:ddf672ed719b4ed82b51499100f5417d7d9f6fb05a65e232249268f35de5ed14 \ + --hash=sha256:dfedf31824ca4915b511b03441784ff640378191918264268e6923da48104acc \ + --hash=sha256:e28cab1582e0eec38b1f38c1c1fb2e56bce5dc180acb1724574fc5f47da2a4fe \ + --hash=sha256:e742d76ad84acbdb1a8e4694f915fe59ff6edc381c97d6dfdd054954e3478ad4 \ + --hash=sha256:e83a31c9cf181a0a3ef0abad2b5f6b43399faf5da7e696196ddd110d332519ee \ + --hash=sha256:e8d1ed93beda54bbd6131a2cb363a576eac746d5c26ba5b7556bc6f964425594 \ + --hash=sha256:e8ff5b90eabdcdaa19af697885f70fe0b714ce16709cf43d4952f1f85299e73a \ + --hash=sha256:ec11802450a2487cdf0e634b750a04cbdc1c4d066b97d94ce7dd2cb51ebb325b \ + --hash=sha256:ecb2651956eea2aa0a2d099434134b1b68f1c31f9a5084d6d53f08ed43d45ff2 \ + --hash=sha256:ed69af290c2b65169f0ba9034d1dc39a5db9459b32f1dd8b5f3f32a3fcf06eab \ + --hash=sha256:eddd5783a4a6309ce23432353cdb36220e25cbb779bfa9122320666508b44b88 \ + --hash=sha256:ee59e6680ed0fdbe6b724cf38bd70400a0c1dd623b07ac729087270caeac88e3 \ + --hash=sha256:f03727225feaf340ceeb7e00604825addef622d551cbd46b7b775ac834c1e1c4 \ + --hash=sha256:f3bbb7a0c5fcb692950b041ae11067ac54826204318922da754f908d95619fbc \ + --hash=sha256:f8a9c828277133af13f3859d1b6bf1c3cb6e9e1637df0e45312e6b7c2e622b1f \ + --hash=sha256:f97660f6c43efd3e0bfd3f2e3e5615bf215680bad6ee3d469df6454b8c6e8256 \ + --hash=sha256:f9939ca7e58c2758c01b40324a59c034ce0cebad18e0d4563a9b1beab3018243 + # via transformers +torch==2.3.0 \ + --hash=sha256:09c81c5859a5b819956c6925a405ef1cdda393c9d8a01ce3851453f699d3358c \ + --hash=sha256:1bf023aa20902586f614f7682fedfa463e773e26c58820b74158a72470259459 \ + --hash=sha256:20572f426965dd8a04e92a473d7e445fa579e09943cc0354f3e6fef6130ce061 \ + --hash=sha256:493d54ee2f9df100b5ce1d18c96dbb8d14908721f76351e908c9d2622773a788 \ + --hash=sha256:4fb27b35dbb32303c2927da86e27b54a92209ddfb7234afb1949ea2b3effffea \ + --hash=sha256:5515503a193781fd1b3f5c474e89c9dfa2faaa782b2795cc4a7ab7e67de923f6 \ + --hash=sha256:6ae9f64b09516baa4ef890af0672dc981c20b1f0d829ce115d4420a247e88fba \ + 
--hash=sha256:729804e97b7cf19ae9ab4181f91f5e612af07956f35c8b2c8e9d9f3596a8e877 \ + --hash=sha256:758ef938de87a2653bba74b91f703458c15569f1562bf4b6c63c62d9c5a0c1f5 \ + --hash=sha256:760f8bedff506ce9e6e103498f9b1e9e15809e008368594c3a66bf74a8a51380 \ + --hash=sha256:a306c87a3eead1ed47457822c01dfbd459fe2920f2d38cbdf90de18f23f72542 \ + --hash=sha256:b0de2bdc0486ea7b14fc47ff805172df44e421a7318b7c4d92ef589a75d27410 \ + --hash=sha256:bce43af735c3da16cc14c7de2be7ad038e2fbf75654c2e274e575c6c05772ace \ + --hash=sha256:cd0dc498b961ab19cb3f8dbf0c6c50e244f2f37dbfa05754ab44ea057c944ef9 \ + --hash=sha256:d24e328226d8e2af7cf80fcb1d2f1d108e0de32777fab4aaa2b37b9765d8be73 \ + --hash=sha256:d8ea5a465dbfd8501f33c937d1f693176c9aef9d1c1b0ca1d44ed7b0a18c52ac \ + --hash=sha256:dca986214267b34065a79000cee54232e62b41dff1ec2cab9abc3fc8b3dee0ad \ + --hash=sha256:e05f836559251e4096f3786ee99f4a8cbe67bc7fbedba8ad5e799681e47c5e80 \ + --hash=sha256:e65ba85ae292909cde0dde6369826d51165a3fc8823dc1854cd9432d7f79b932 \ + --hash=sha256:f9b98bf1a3c8af2d4c41f0bf1433920900896c446d1ddc128290ff146d1eb4bd + # via + # accelerate + # mergekit (pyproject.toml) + # peft +tqdm==4.66.4 \ + --hash=sha256:b75ca56b413b030bc3f00af51fd2c1a1a5eac6a0c1cca83cbb37a5c52abce644 \ + --hash=sha256:e4d936c9de8727928f3be6079590e97d9abfe8d39a590be678eb5919ffc186bb + # via + # huggingface-hub + # mergekit (pyproject.toml) + # peft + # transformers +transformers==4.41.1 \ + --hash=sha256:f0680e0b1a01067eccd11f62f0522409422c7d6f91d532fe0f50b136a406129d \ + --hash=sha256:fa859e4c66f0896633a3bf534e0d9a29a9a88478a49f94c5d8270537dc61cc42 + # via + # mergekit (pyproject.toml) + # peft +typing-extensions==4.12.0 \ + --hash=sha256:8cbcdc8606ebcb0d95453ad7dc5065e6237b6aa230a31e81d0f440c30fed5fd8 \ + --hash=sha256:b349c66bea9016ac22978d800cfff206d5f9816951f12a7d0ec5578b0a819594 + # via + # huggingface-hub + # mergekit (pyproject.toml) + # pydantic + # pydantic-core + # torch +urllib3==2.2.1 \ + --hash=sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d \ + --hash=sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19 + # via requests