diff --git a/mergekit/evo/actors.py b/mergekit/evo/actors.py index e107efe7..0f9f42fc 100644 --- a/mergekit/evo/actors.py +++ b/mergekit/evo/actors.py @@ -62,6 +62,7 @@ def __init__( vllm: bool = False, batch_size: Optional[int] = None, task_manager: Optional[lm_eval.tasks.TaskManager] = None, + quantization_config: Optional[transformers.BitsAndBytesConfig] = None, ): self.config = config self.genome = genome @@ -72,6 +73,7 @@ def __init__( self.vllm = vllm self.batch_size = batch_size self.task_manager = task_manager + self.quantization_config = quantization_config if config.shuffle: monkeypatch_lmeval_shuffle() @@ -105,6 +107,9 @@ def evaluate_genotype( logging.error("Model merge failed") return {"score": None, "results": None} + kwargs = {} + if self.quantization_config is not None: + kwargs["quantization_config"] = self.quantization_config logging.info(f"Model merged to {merged_path}") return evaluate_model( merged_path, @@ -114,6 +119,7 @@ def evaluate_genotype( vllm=self.vllm, batch_size=self.batch_size, task_manager=self.task_manager, + **kwargs, ) diff --git a/mergekit/evo/helpers.py b/mergekit/evo/helpers.py index f87829d5..3ab557c3 100644 --- a/mergekit/evo/helpers.py +++ b/mergekit/evo/helpers.py @@ -67,6 +67,7 @@ def evaluate_model( vllm: bool, batch_size: Optional[int] = None, task_manager: Optional[lm_eval.tasks.TaskManager] = None, + **kwargs, ) -> dict: # monkeypatch_tqdm() monkeypatch_lmeval_vllm() @@ -74,6 +75,7 @@ def evaluate_model( model_args = { "pretrained": merged_path, "dtype": "bfloat16", + **kwargs, } if vllm: model_args["gpu_memory_utilization"] = 0.8 diff --git a/mergekit/evo/strategy.py b/mergekit/evo/strategy.py index 2f4c581c..c434dd1d 100644 --- a/mergekit/evo/strategy.py +++ b/mergekit/evo/strategy.py @@ -25,6 +25,7 @@ import ray.util.queue import ray.util.scheduling_strategies import torch +import transformers from mergekit.evo.actors import InMemoryMergeEvaluator, OnDiskMergeEvaluator from mergekit.evo.config import EvolMergeConfiguration @@ -43,6 +44,7 @@ def __init__( batch_size: Optional[int] = None, task_search_path: Union[str, List[str], None] = None, model_storage_path: Optional[str] = None, + quantization_config: Optional[transformers.BitsAndBytesConfig] = None, ): self.config = config self.genome = genome @@ -51,6 +53,7 @@ def __init__( self.batch_size = batch_size self.task_manager = lm_eval.tasks.TaskManager(include_path=task_search_path) self.model_storage_path = model_storage_path + self.quantization_config = quantization_config if self.model_storage_path: os.makedirs(self.model_storage_path, exist_ok=True) @@ -91,6 +94,7 @@ def __init__( vllm=vllm, batch_size=self.batch_size, task_manager=self.task_manager, + quantization_config=self.quantization_config, ) for _ in range(self.num_gpus) ] @@ -120,6 +124,7 @@ def __init__( batch_size: Optional[int] = None, task_manager: Optional[lm_eval.tasks.TaskManager] = None, model_storage_path: Optional[str] = None, + quantization_config: Optional[transformers.BitsAndBytesConfig] = None, ): self.config = config self.genome = genome @@ -130,6 +135,7 @@ def __init__( self.batch_size = batch_size self.task_manager = task_manager self.model_storage_path = model_storage_path + self.quantization_config = quantization_config self._shutdown = False async def evaluate_genotype(self, genotype: np.ndarray): @@ -159,6 +165,9 @@ async def process_queue(self): while merged and len(evaluating) < self.num_gpus: future_result, merged_path = merged.pop() + kwargs = {} + if self.quantization_config is not None: + kwargs["quantization_config"] = self.quantization_config evaluating[ evaluate_model_ray.remote( merged_path, @@ -168,6 +177,7 @@ async def process_queue(self): vllm=self.vllm, batch_size=self.batch_size, task_manager=self.task_manager, + **kwargs, ) ] = future_result @@ -222,6 +232,8 @@ def __init__( vllm=vllm, num_gpus=self.num_gpus, task_manager=self.task_manager, + batch_size=self.batch_size, + quantization_config=self.quantization_config, ) self.actor.process_queue.remote() @@ -242,6 +254,7 @@ def evaluate_genotype_serial( vllm: bool = False, batch_size: Optional[int] = None, task_manager: Optional[lm_eval.tasks.TaskManager] = None, + quantization_config: Optional[transformers.BitsAndBytesConfig] = None, ): pg = ray.util.placement_group([{"CPU": 1, "GPU": 1}], strategy="STRICT_PACK") strat = ray.util.scheduling_strategies.PlacementGroupSchedulingStrategy( @@ -252,6 +265,9 @@ def evaluate_genotype_serial( ) if not merged_path: return {"score": None, "results": None} + kwargs = {} + if quantization_config is not None: + kwargs["quantization_config"] = quantization_config res = ray.get( evaluate_model_ray.options(scheduling_strategy=strat).remote( merged_path, @@ -261,6 +277,7 @@ def evaluate_genotype_serial( vllm=vllm, batch_size=batch_size, task_manager=task_manager, + **kwargs, ) ) ray.util.remove_placement_group(pg) @@ -292,6 +309,7 @@ def evaluate_genotypes(self, genotypes: List[np.ndarray]) -> List[dict]: vllm=self.vllm, batch_size=self.batch_size, task_manager=self.task_manager, + quantization_config=self.quantization_config, ) for x in genotypes ] diff --git a/mergekit/scripts/evolve.py b/mergekit/scripts/evolve.py index 7c8c1a09..b749fdc9 100644 --- a/mergekit/scripts/evolve.py +++ b/mergekit/scripts/evolve.py @@ -113,6 +113,18 @@ default=None, help="Maximum time to run the optimization in seconds", ) +@click.option( + "--load-in-8bit", + is_flag=True, + default=False, + help="Evaluate models at 8-bit precision", +) +@click.option( + "--load-in-4bit", + is_flag=True, + default=False, + help="Evaluate models at 4-bit precision", +) def main( genome_config_path: str, max_fevals: int, @@ -135,6 +147,8 @@ def main( save_final_model: bool, reshard: bool, timeout: Optional[float], + load_in_8bit: bool, + load_in_4bit: bool, ): config = EvolMergeConfiguration.model_validate( yaml.safe_load(open(genome_config_path, "r", encoding="utf-8")) @@ -142,6 +156,29 @@ def main( check_for_naughty_config(config, allow=allow_benchmark_tasks) + if load_in_4bit and load_in_8bit: + raise ValueError("Cannot load models in both 4-bit and 8-bit") + + if load_in_4bit or load_in_8bit: + if vllm: + raise ValueError("Cannot use vLLM with 4-bit or 8-bit models") + if in_memory: + raise ValueError("Cannot use in-memory mode with 4-bit or 8-bit models") + try: + import bitsandbytes + except ImportError: + raise RuntimeError("bitsandbytes is not installed") + + bnb_config = transformers.BitsAndBytesConfig( + load_in_8bit=load_in_8bit, + load_in_4bit=load_in_4bit, + bnb_4bit_compute_dtype="bfloat16", + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + ) + else: + bnb_config = None + if use_wandb: if not wandb: raise RuntimeError("wandb is not installed") @@ -227,6 +264,7 @@ def main( model_storage_path=os.path.join(storage_path, "merged"), batch_size=batch_size, task_search_path=task_search_path, + quantization_config=bnb_config, ) x0 = genome.initial_genotype(random=config.random_init).view(-1).numpy()