Fix quantization & reduce memory usage
leng-yue committed May 7, 2024
1 parent 1291546 · commit 691b3bb
Showing 3 changed files with 38 additions and 60 deletions.
tools/llama/generate.py (8 changes: 4 additions & 4 deletions)
@@ -367,7 +367,7 @@ def load_model(

     if "int8" in str(checkpoint_path):
         logger.info("Using int8 weight-only quantization!")
-        from quantize import WeightOnlyInt8QuantHandler
+        from .quantize import WeightOnlyInt8QuantHandler

         simple_quantizer = WeightOnlyInt8QuantHandler(model)
         model = simple_quantizer.convert_for_runtime()
@@ -377,7 +377,7 @@ def load_model(
         path_comps = checkpoint_path.name.split(".")
         assert path_comps[-2].startswith("g")
         groupsize = int(path_comps[-2][1:])
-        from quantize import WeightOnlyInt4QuantHandler
+        from .quantize import WeightOnlyInt4QuantHandler

         simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
         model = simple_quantizer.convert_for_runtime()
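Both hunks make the same fix: a bare "from quantize import ..." only resolves if tools/llama happens to be on sys.path, while the package-relative "from .quantize import ..." works whenever generate.py is imported as part of the tools.llama package. For reference, here is a minimal sketch of the idea behind the weight-only int8 path, assuming symmetric per-output-channel scales; the real WeightOnlyInt8QuantHandler converts the whole model via convert_for_runtime(), so treat this as an illustration only:

    import torch

    def int8_weight_only(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Symmetric per-output-channel quantization: w ~ q * scale[:, None].
        scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
        return q, scale.squeeze(1)

    w = torch.randn(16, 32)
    q, scale = int8_weight_only(w)
    # Dequantize on the fly at inference: y = x @ (q.float() * scale[:, None]).T
    print((w - q.float() * scale[:, None]).abs().max())  # small reconstruction error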
@@ -669,9 +669,9 @@ def worker():
 @click.option(
     "--checkpoint-path",
     type=click.Path(path_type=Path, exists=True),
-    default="results/text2semantic_400m_finetune/step_000002000.pth",
+    default="checkpoints/text2semantic-sft-medium-v1-4k.pth",
 )
-@click.option("--config-name", type=str, default="dual_ar_8_codebook_small")
+@click.option("--config-name", type=str, default="dual_ar_2_codebook_medium")
 @click.option("--tokenizer", type=str, default="fishaudio/fish-speech-1")
 @click.option("--compile/--no-compile", default=False)
 @click.option("--seed", type=int, default=42)
tools/llama/quantize.py (78 changes: 25 additions & 53 deletions)
@@ -6,11 +6,12 @@
 import time
 from pathlib import Path

+import click
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from fish_speech.models.text2semantic.llama import ModelArgs, Transformer, find_multiple
+from .generate import load_model

 ##### Quantization Primitives ######

@@ -414,11 +415,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
     )


+@click.command()
+@click.option(
+    "--checkpoint-path",
+    type=click.Path(path_type=Path, exists=True),
+    default="checkpoints/text2semantic-sft-medium-v1-4k.pth",
+)
+@click.option("--config-name", type=str, default="dual_ar_2_codebook_medium")
+@click.option(
+    "--mode", type=str, default="int8", help="type of quantization to perform"
+)
+@click.option(
+    "--groupsize", type=int, default=128, help="Group size for int4 quantization."
+)
 def quantize(
-    checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
-    mode: str = "int8",
-    # following arguments only available when setting int4 quantization.
-    groupsize: int = 128,
+    checkpoint_path: Path, config_name: str, mode: str, groupsize: int
 ) -> None:
     assert checkpoint_path.is_file(), checkpoint_path
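With the decorators above, click owns argument parsing: each @click.option is matched to a function parameter by name, which is why the hand-rolled argparse block at the end of the file can go (see the final hunk below). A minimal sketch of the pattern, with hypothetical names:

    import click
    from pathlib import Path

    @click.command()
    @click.option("--checkpoint-path", type=click.Path(path_type=Path), default="model.pth")
    @click.option("--mode", type=str, default="int8")
    def main(checkpoint_path: Path, mode: str) -> None:
        # click maps --checkpoint-path onto the checkpoint_path parameter.
        click.echo(f"would quantize {checkpoint_path} in mode={mode}")

    if __name__ == "__main__":
        main()  # click parses sys.argv itself; no arguments are passed here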

@@ -428,31 +439,14 @@ def quantize(
     print("Loading model ...")
     t0 = time.time()

-    with torch.device("meta"):
-        model = Transformer(
-            ModelArgs(
-                max_seq_len=4096,
-                vocab_size=36408,
-                n_layer=24,
-                n_head=16,
-                dim=1024,
-                rope_base=10000,
-                norm_eps=1e-5,
-                num_codebooks=4,  # single codebook
-                codebook_size=168,  # codebook size 160 + 2 special tokens
-            )
-        )
-
-    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
-    if "state_dict" in checkpoint:
-        checkpoint = checkpoint["state_dict"]
-    checkpoint = {
-        k.replace("model.", ""): v
-        for k, v in checkpoint.items()
-        if k.startswith("model.")
-    }
-    model.load_state_dict(checkpoint, assign=True)
-    model = model.to(dtype=precision, device=device)
+    model, _ = load_model(
+        config_name,
+        checkpoint_path=checkpoint_path,
+        device=device,
+        precision=precision,
+        compile=False,
+        max_length=2048,
+    )

     if mode == "int8":
         print(
@@ -490,26 +484,4 @@

 if __name__ == "__main__":
-    import argparse
-
-    parser = argparse.ArgumentParser(description="Quantize a model.")
-    parser.add_argument(
-        "--checkpoint_path",
-        type=Path,
-        default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
-        help="Path to the model checkpoint to be quantized.",
-    )
-    parser.add_argument(
-        "--mode",
-        "-q",
-        type=str,
-        default="int8",
-        choices=["int8", "int4"],
-        help="type of quantization to perform",
-    )
-    parser.add_argument(
-        "--groupsize", type=int, default=32, help="Group size for int4 quantization."
-    )
-
-    args = parser.parse_args()
-    quantize(args.checkpoint_path, args.mode, args.groupsize)
+    quantize()
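Note that the groupsize default effectively changes from 32 (old argparse flag) to 128 (new click option). In int4 mode, groupsize controls how many consecutive input weights share one scale. A minimal sketch of group-wise symmetric int4 quantization, assuming groups run along the input dimension; the real WeightOnlyInt4QuantHandler's storage format is more involved, and this sketch keeps an int8 container for clarity:

    import torch

    def int4_groupwise(w: torch.Tensor, groupsize: int = 128):
        out_f, in_f = w.shape
        assert in_f % groupsize == 0, "input dim must be divisible by groupsize"
        g = w.reshape(out_f, in_f // groupsize, groupsize)
        # Each group of `groupsize` weights shares one scale; int4 range is [-8, 7].
        scale = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0
        q = torch.clamp(torch.round(g / scale), -8, 7).to(torch.int8)
        return q.reshape(out_f, in_f), scale.squeeze(-1)

Also, because quantize.py now imports from .generate relatively, running it by file path (python tools/llama/quantize.py) would raise a relative-import error; it presumably has to be invoked as a module, e.g. python -m tools.llama.quantize, assuming tools is importable as a package.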
tools/webui.py (12 changes: 9 additions & 3 deletions)
@@ -138,9 +138,15 @@ def inference(

     # VQGAN Inference
     feature_lengths = torch.tensor([result.shape[1]], device=vqgan_model.device)
-    fake_audios = vqgan_model.decode(
-        indices=result[None], feature_lengths=feature_lengths, return_audios=True
-    )[0, 0]
+
+    with torch.autocast(
+        device_type=feature_lengths.device.type, dtype=args.precision
+    ):
+        fake_audios = vqgan_model.decode(
+            indices=result[None],
+            feature_lengths=feature_lengths,
+            return_audios=True,
+        )[0, 0]
     fake_audios = fake_audios.float().cpu().numpy()

     if streaming:
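The decode call is now wrapped in torch.autocast, so the decoder runs in reduced precision (args.precision), which is presumably the memory reduction the commit title refers to; the output is cast back to float32 before the numpy conversion. A self-contained sketch of the same pattern, where the decoder is a stand-in rather than the project's VQGAN API:

    import torch
    import torch.nn as nn

    decoder = nn.Linear(64, 64)       # stand-in for vqgan_model.decode
    codes = torch.randn(1, 64)

    device_type = codes.device.type   # "cpu" or "cuda"; the diff derives it from feature_lengths
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        audio = decoder(codes)        # computed in reduced precision
    audio = audio.float().cpu().numpy()  # back to float32 for downstream consumers
    print(audio.dtype)                # float32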
