TypeError: QDQQuantizer.__init__() missing 1 required positional argument: 'mode' #2019

Status: Open · 2 of 4 tasks
blacker521 opened this issue Sep 10, 2024 · 4 comments
Labels: bug (Something isn't working)
blacker521 commented Sep 10, 2024

System Info

optimum  1.22.0.dev0

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction (minimal, reproducible, runnable)

from functools import partial
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTQuantizer, ORTModelForSequenceClassification
from optimum.onnxruntime.configuration import AutoQuantizationConfig, AutoCalibrationConfig

model_id = "distilbert-base-uncased-finetuned-sst-2-english"

# onnx_model = ORTModelForSequenceClassification.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained("/mnt/nas_new/plms/bce-embedding-base_v1")
quantizer = ORTQuantizer.from_pretrained("./fp32")
qconfig = AutoQuantizationConfig.tensorrt(per_channel=True)

def preprocess_fn(ex, tokenizer):
    return tokenizer(ex["sentence"])

calibration_dataset = quantizer.get_calibration_dataset(
    "glue",
    dataset_config_name="sst2",
    preprocess_function=partial(preprocess_fn, tokenizer=tokenizer),
    num_samples=50,
    dataset_split="train",
)

calibration_config = AutoCalibrationConfig.minmax(calibration_dataset)

ranges = quantizer.fit(
    dataset=calibration_dataset,
    calibration_config=calibration_config,
    operators_to_quantize=qconfig.operators_to_quantize,
)

model_quantized_path = quantizer.quantize(
    save_dir="./output_opt",
    calibration_tensors_range=ranges,
    quantization_config=qconfig,
)

Expected behavior

The quantized model should be saved to the specified directory.

blacker521 added the bug (Something isn't working) label on Sep 10, 2024
IlyasMoutawwakil (Member) commented:
Hi, I exported the model (with the CLI and opset > 13), ran your script, and it worked as expected.
Please report which onnxruntime version you are using; the mode argument is no longer part of the QDQQuantizer signature: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/qdq_quantizer.py#L129

from functools import partial

from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForSequenceClassification, ORTQuantizer
from optimum.onnxruntime.configuration import AutoCalibrationConfig, AutoQuantizationConfig


model_id = "distilbert-base-uncased-finetuned-sst-2-english"

tokenizer = AutoTokenizer.from_pretrained(model_id)
onnx_model = ORTModelForSequenceClassification.from_pretrained("./onnx_model")
quantizer = ORTQuantizer.from_pretrained(onnx_model)

qconfig = AutoQuantizationConfig.tensorrt(per_channel=True)


def preprocess_fn(ex, tokenizer):
    return tokenizer(ex["sentence"])


calibration_dataset = quantizer.get_calibration_dataset(
    "glue",
    dataset_config_name="sst2",
    preprocess_function=partial(preprocess_fn, tokenizer=tokenizer),
    num_samples=50,
    dataset_split="train",
)

calibration_config = AutoCalibrationConfig.minmax(calibration_dataset)

ranges = quantizer.fit(
    dataset=calibration_dataset,
    calibration_config=calibration_config,
    operators_to_quantize=qconfig.operators_to_quantize,
)

model_quantized_path = quantizer.quantize(
    save_dir="./output_opt",
    calibration_tensors_range=ranges,
    quantization_config=qconfig,
)
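Whether an installed onnxruntime still expects the removed argument can be checked from Python with inspect. A stdlib-only sketch, using hypothetical stand-in functions rather than the real QDQQuantizer.__init__ (the commented lines show how one might apply it to a real install):

```python
import inspect

# Hypothetical stand-ins for an old and a new __init__ signature;
# these are illustrations, not onnxruntime's actual parameter lists.
def old_style_init(self, model, per_channel, reduce_range, mode, static):
    pass

def new_style_init(self, model, per_channel, reduce_range, static):
    pass

def requires_mode(init_fn):
    """True if the callable has a required (no-default) 'mode' parameter."""
    params = inspect.signature(init_fn).parameters
    return "mode" in params and params["mode"].default is inspect.Parameter.empty

# Against a real install one would inspect the class itself, e.g.:
#   from onnxruntime.quantization.qdq_quantizer import QDQQuantizer
#   requires_mode(QDQQuantizer.__init__)
```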

blacker521 (Author) commented:

I installed it the official way:
python -m pip install git+https://github.com/huggingface/optimum.git
python -m pip install optimum[onnxruntime]@git+https://github.com/huggingface/optimum.git
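Installing optimum from git does not pin onnxruntime, so it is worth confirming which versions actually resolved. A small stdlib sketch that reports the installed versions without importing the packages:

```python
from importlib import metadata

def installed_version(dist_name):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None

for dist in ("optimum", "onnxruntime", "onnx"):
    print(dist, installed_version(dist))
```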

blacker521 (Author) commented:

Package Version


aiohappyeyeballs 2.4.0
aiohttp 3.10.5
aiosignal 1.3.1
async-timeout 4.0.3
attrs 24.2.0
certifi 2024.8.30
charset-normalizer 3.3.2
coloredlogs 15.0.1
datasets 3.0.0
dill 0.3.8
evaluate 0.4.3
filelock 3.16.0
flatbuffers 24.3.25
frozenlist 1.4.1
fsspec 2024.6.1
huggingface-hub 0.24.6
humanfriendly 10.0
idna 3.8
Jinja2 3.1.4
MarkupSafe 2.1.5
mpmath 1.3.0
multidict 6.1.0
multiprocess 0.70.16
networkx 3.1
numpy 1.24.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.20.5
nvidia-nvjitlink-cu12 12.6.68
nvidia-nvtx-cu12 12.1.105
onnx 1.16.2
onnxruntime 1.16.3
optimum 1.23.0.dev0
packaging 24.1
pandas 2.0.3
pip 24.2
protobuf 5.28.1
pyarrow 17.0.0
python-dateutil 2.9.0.post0
pytz 2024.2
PyYAML 6.0.2
regex 2024.9.11
requests 2.32.3
safetensors 0.4.5
sentencepiece 0.2.0
setuptools 72.1.0
six 1.16.0
sympy 1.13.2
tokenizers 0.19.1
torch 2.4.1
tqdm 4.66.5
transformers 4.44.2
triton 3.0.0
typing_extensions 4.12.2
tzdata 2024.1
urllib3 2.2.2
wheel 0.43.0
xxhash 3.5.0
yarl 1.11.1

IlyasMoutawwakil (Member) commented Sep 12, 2024:

Yeah, we should probably pin a minimum onnxruntime version. Please update onnxruntime and onnx with pip install onnxruntime onnx -U
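Until a minimum version is pinned in the package metadata, a runtime guard can catch the mismatch early. A sketch, assuming (not confirming) 1.17.0 as the minimum that has the new QDQQuantizer signature:

```python
# Hedged sketch of a runtime version guard; the 1.17.0 minimum is an
# assumption for illustration, not a confirmed cutoff.
MINIMUM_ORT = (1, 17, 0)

def parse_version(v):
    """Parse leading numeric components; stops at suffixes like '.dev0'."""
    parts = []
    for piece in v.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)

def check_onnxruntime(version_string):
    """Raise if the given onnxruntime version predates the assumed minimum."""
    if parse_version(version_string) < MINIMUM_ORT:
        raise ImportError(
            f"onnxruntime {version_string} is too old; "
            f"please run: pip install onnxruntime onnx -U"
        )
```

With the reporter's environment (onnxruntime 1.16.3), this guard would fail fast with an actionable message instead of the TypeError deep inside quantization.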
