
ADD gfpgan #86

Open
kyakuno opened this issue Jun 25, 2024 · 14 comments


kyakuno commented Jun 25, 2024

ONNX -> tflite -> quantize
https://github.com/axinc-ai/ailia-models/tree/master/generative_adversarial_networks/gfpgan

kyakuno self-assigned this Aug 13, 2024

kyakuno commented Aug 13, 2024

Conversion:

wget https://storage.googleapis.com/ailia-models/gfpgan/GFPGANv1.3.onnx
onnx2tf -i GFPGANv1.3.onnx


kyakuno commented Aug 13, 2024

tensorflow-macos is stuck at the old version 2.12, and the call below raises an error there, so it is better to run the conversion on Linux.

tf.random.set_seed(0)


kyakuno commented Aug 13, 2024

ERROR: The trace log is below.
Traceback (most recent call last):
  File "/home/kyakuno/.local/lib/python3.10/site-packages/onnx2tf/utils/common_functions.py", line 312, in print_wrapper_func
    result = func(*args, **kwargs)
  File "/home/kyakuno/.local/lib/python3.10/site-packages/onnx2tf/utils/common_functions.py", line 385, in inverted_operation_enable_disable_wrapper_func
    result = func(*args, **kwargs)
  File "/home/kyakuno/.local/lib/python3.10/site-packages/onnx2tf/utils/common_functions.py", line 55, in get_replacement_parameter_wrapper_func
    func(*args, **kwargs)
  File "/home/kyakuno/.local/lib/python3.10/site-packages/onnx2tf/ops/Expand.py", line 222, in make_node
    min_abs_err_perm_2: List[int] = [idx for idx, val in enumerate(input_tensor_shape)]
  File "/home/kyakuno/.local/lib/python3.10/site-packages/tf_keras/src/engine/keras_tensor.py", line 418, in __iter__
    raise TypeError(
TypeError: Cannot iterate over a Tensor with unknown first dimension.

ERROR: input_onnx_file_path: GFPGANv1.3.onnx
ERROR: onnx_op_name: Expand_215
ERROR: Read this and deal with it. https://github.com/PINTO0309/onnx2tf#parameter-replacement
ERROR: Alternatively, if the input OP has a dynamic dimension, use the -b or -ois option to rewrite it to a static shape and try again.
ERROR: If the input OP of ONNX before conversion is NHWC or an irregular channel arrangement other than NCHW, use the -kt or -kat option.
ERROR: Also, for models that include NonMaxSuppression in the post-processing, try the -onwdt option.
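The last suggestions in this log concern dynamic shapes. If the unknown first dimension comes from a dynamic batch axis on the model input (not verified here), pinning the input to a static shape with onnx2tf's `-ois` option may get past Expand_215. The sketch below only assembles the command line, assuming the `name:dim0,...,dimN` format that onnx2tf documents for `-ois`. The input name `input` is a hypothetical placeholder; check the real name (for example in Netron) before running it. The 1x3x512x512 shape matches the sample input used for the GFPGAN conversion later in this thread.

```python
import shlex
from typing import List, Sequence


def build_onnx2tf_cmd(onnx_path: str, input_name: str, shape: Sequence[int]) -> List[str]:
    """Build an onnx2tf command that overwrites one input with a static shape.

    input_name is a placeholder: the real ONNX input name must be looked up first.
    """
    dims = ",".join(str(d) for d in shape)
    return ["onnx2tf", "-i", onnx_path, "-ois", f"{input_name}:{dims}"]


cmd = build_onnx2tf_cmd("GFPGANv1.3.onnx", "input", (1, 3, 512, 512))
print(shlex.join(cmd))
# To actually run it: subprocess.run(cmd, check=True)
```

onnx2tf also has a `-b` option for fixing only the batch size, as the log mentions; `-ois` is used here because it pins every dimension at once.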


kyakuno commented Aug 13, 2024


kyakuno commented Aug 13, 2024

Test: quantizing resnet18

# normal (float) conversion
import torch
import torchvision
import ai_edge_torch

resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
sample_inputs = (torch.randn(1, 3, 224, 224),)
edge_model = ai_edge_torch.convert(resnet18.eval(), sample_inputs)
edge_model.export("resnet18.tflite")

# quantize (static int8 via PT2E)
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization import quantize_pt2e
from ai_edge_torch.quantize import pt2e_quantizer
from ai_edge_torch.quantize import quant_config

quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
    pt2e_quantizer.get_symmetric_quantization_config()
)
model = capture_pre_autograd_graph(resnet18, sample_inputs)
model = quantize_pt2e.prepare_pt2e(model, quantizer)
# Calibration: run data through the prepared model so its observers record value ranges.
# Random input is only a smoke test; use representative images for real calibration.
model(*sample_inputs)
model = quantize_pt2e.convert_pt2e(model, fold_quantize=False)

# Converted without quant_config, for comparison only (not exported).
without_quantizer = ai_edge_torch.convert(model, sample_inputs)
with_quantizer = ai_edge_torch.convert(
    model,
    sample_inputs,
    quant_config=quant_config.QuantConfig(pt2e_quantizer=quantizer),
)
with_quantizer.export("resnet18_quantized.tflite")
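`get_symmetric_quantization_config()` asks for symmetric quantization. As a simplified per-tensor illustration of what that means numerically (the real quantizer may use per-channel weight scales and different integer ranges for activations), the largest magnitude is mapped to 127 with a zero point of 0, and every value is rounded to the nearest multiple of the scale:

```python
import numpy as np


def symmetric_scale(x: np.ndarray, qmax: int = 127) -> float:
    """Per-tensor symmetric scale: the largest |x| maps to qmax, zero point is 0."""
    amax = float(np.max(np.abs(x)))
    return amax / qmax if amax > 0 else 1.0


def quantize(x: np.ndarray, scale: float, qmax: int = 127) -> np.ndarray:
    # Divide by the step size, round to the nearest integer, clip to [-qmax, qmax].
    return np.clip(np.round(x / scale), -qmax, qmax).astype(np.int8)


def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    # Map int8 codes back to floats; the result is x up to at most scale / 2.
    return q.astype(np.float32) * scale


w = np.array([-0.5, -0.1, 0.0, 0.25, 1.0], dtype=np.float32)
scale = symmetric_scale(w)
q = quantize(w, scale)
err = float(np.max(np.abs(dequantize(q, scale) - w)))
print(q, err)
```

The rounding error per value is bounded by half a step, which is why the scale (and therefore the observed value range) matters so much for accuracy.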


kyakuno commented Aug 13, 2024

Converting GFPGAN

# normal (float) conversion
import os

import torch
import ai_edge_torch
from gfpgan import GFPGANer

model_name = "GFPGANv1.3"
model_path = os.path.join('experiments/pretrained_models', model_name + '.pth')
upscale = 2
arch = 'clean'
channel_multiplier = 2
restorer = GFPGANer(
    model_path=model_path,
    upscale=upscale,
    arch=arch,
    channel_multiplier=channel_multiplier,
    bg_upsampler=None,
    device="cpu")
model = restorer.gfpgan
model.eval()

# 512x512 face crop; in the real pipeline the input is normalized to [-1, 1]
sample_inputs = (torch.randn(1, 3, 512, 512),)
edge_model = ai_edge_torch.convert(model, sample_inputs)
edge_model.export("gfpgan_float.tflite")

# quantize (static int8 via PT2E)
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization import quantize_pt2e
from ai_edge_torch.quantize import pt2e_quantizer
from ai_edge_torch.quantize import quant_config

quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
    pt2e_quantizer.get_symmetric_quantization_config()
)
model = capture_pre_autograd_graph(model, sample_inputs)
model = quantize_pt2e.prepare_pt2e(model, quantizer)
# Calibration: random input is a placeholder; use preprocessed face crops for real calibration.
model(*sample_inputs)
model = quantize_pt2e.convert_pt2e(model, fold_quantize=False)

# Converted without quant_config, for comparison only (not exported).
without_quantizer = ai_edge_torch.convert(model, sample_inputs)
with_quantizer = ai_edge_torch.convert(
    model,
    sample_inputs,
    quant_config=quant_config.QuantConfig(pt2e_quantizer=quantizer),
)
with_quantizer.export("gfpgan_int8.tflite")
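To judge how much int8 quantization degrades GFPGAN, one option is to run the same input through the float and int8 models and compare the outputs by PSNR. A minimal NumPy helper, assuming both outputs are in GFPGAN's [-1, 1] output range (hence `data_range=2.0`):

```python
import numpy as np


def psnr(reference: np.ndarray, test: np.ndarray, data_range: float = 2.0) -> float:
    """PSNR in dB between two outputs; data_range=2.0 assumes values in [-1, 1]."""
    diff = reference.astype(np.float64) - test.astype(np.float64)
    mse = float(np.mean(diff ** 2))
    if mse == 0.0:
        return float("inf")  # identical outputs
    return 10.0 * float(np.log10(data_range ** 2 / mse))
```

Assuming the converted ai-edge-torch models can be called like the PyTorch module (e.g. `edge_model(*sample_inputs)`), the two outputs could come from `edge_model` and `with_quantizer` on the same input; if the model returns a tuple, compare only the image output.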


kyakuno commented Aug 13, 2024

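Normalization used in GFPGAN's preprocessing (pixels are scaled to [0, 1], then mapped to [-1, 1]):

from basicsr.utils import img2tensor
from torchvision.transforms.functional import normalize

# cropped_face: aligned face crop as an HxWx3 BGR uint8 array (OpenCV layout)
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)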
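When running the converted .tflite model, the same preprocessing has to be reproduced without basicsr/torchvision. A NumPy sketch that mirrors `img2tensor(..., bgr2rgb=True)` plus `normalize(mean=0.5, std=0.5)`, assuming an HxWx3 BGR uint8 crop and an NCHW model input; `postprocess` is an assumed inverse that maps a [-1, 1] output back to a BGR uint8 image:

```python
import numpy as np


def preprocess(bgr_uint8: np.ndarray) -> np.ndarray:
    """HxWx3 BGR uint8 face crop -> 1x3xHxW float32 RGB in [-1, 1]."""
    rgb = bgr_uint8[:, :, ::-1].astype(np.float32) / 255.0  # BGR -> RGB, scale to [0, 1]
    rgb = (rgb - 0.5) / 0.5  # same as normalize(mean=0.5, std=0.5)
    return np.ascontiguousarray(rgb.transpose(2, 0, 1)[None])  # HWC -> NCHW with batch dim


def postprocess(nchw: np.ndarray) -> np.ndarray:
    """1x3xHxW float output in [-1, 1] -> HxWx3 BGR uint8."""
    rgb = np.clip(nchw[0].transpose(1, 2, 0), -1.0, 1.0)  # CHW -> HWC, clamp range
    rgb = (rgb + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    return (rgb[:, :, ::-1] * 255.0).round().astype(np.uint8)  # RGB -> BGR, to uint8
```

If the TFLite model produced by onnx2tf expects NHWC instead, the transpose in `preprocess` would be dropped.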


kyakuno commented Aug 13, 2024

With tensorflow 2.12.0, loading the converted model fails with the following error:

ValueError: Op builtin_code out of range: 204. Are you using old TFLite binary with newer model?Registration failed.


kyakuno commented Aug 13, 2024


kyakuno commented Aug 13, 2024

This op appears to have been added for JAX (StableHLO):
https://github.com/openxla/stablehlo


kyakuno commented Aug 13, 2024

The make_noise in the file below looks like the relevant part:
https://github.com/TencentARC/GFPGAN/blob/master/gfpgan/archs/stylegan2_clean_arch.py


kyakuno commented Aug 13, 2024

The noise is generated with new_empty(...).normal_(), so replacing it with torch.randn looks like a reasonable fix.


kyakuno commented Aug 13, 2024

Make the model return a single output:

gfpganv1_clean_arch.py

def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
->
def forward(self, x, return_latents=False, return_rgb=False, randomize_noise=True, **kwargs):

stylegan2_clean_arch.py

Attempt to get rid of STABLEHLO_RNG_BIT_GENERATOR (this had no effect):

noise = out.new_empty(b, 1, h, w).normal_()
->
noise = torch.randn(b, 1, h, w, device=out.device)
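Swapping new_empty(...).normal_() for torch.randn still samples inside forward(), which is presumably why it made no difference: either way the exporter has to emit a random-number op. One way around that is to move sampling out of the graph, drawing the noise once outside the model and passing it in. The NumPy sketch below illustrates the idea; `inject_noise` is a hypothetical stand-in for a StyleGAN2-style noise-injection step (out = image + weight * noise), not GFPGAN's actual code.

```python
from typing import Optional

import numpy as np


def inject_noise(image: np.ndarray, weight: float,
                 noise: Optional[np.ndarray] = None,
                 rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Hypothetical noise injection: image is (B, C, H, W), noise is (B, 1, H, W).

    When noise is None it is sampled here; in the PyTorch original, this is the
    step that turns into an RNG op when the model is traced and exported.
    """
    if noise is None:
        rng = rng if rng is not None else np.random.default_rng()
        b, _, h, w = image.shape
        noise = rng.standard_normal((b, 1, h, w)).astype(image.dtype)
    return image + weight * noise  # noise broadcasts over the channel axis


# Export-friendly variant: sample once outside the model and pass the noise in,
# so every call is deterministic and no random op is needed inside the graph.
image = np.ones((1, 4, 8, 8), dtype=np.float32)
fixed_noise = np.random.default_rng(0).standard_normal((1, 1, 8, 8)).astype(np.float32)
out_a = inject_noise(image, 0.1, noise=fixed_noise)
out_b = inject_noise(image, 0.1, noise=fixed_noise)
print(np.array_equal(out_a, out_b))
```

The `randomize_noise` argument in the forward() signature above may offer this switch without patching the architecture files; check how stylegan2_clean_arch.py handles it before relying on it.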


kyakuno commented Aug 13, 2024

Calibration has to be run between prepare_pt2e and convert_pt2e.
tensorflow/tflite-support#980
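Calibration belongs between those two calls because prepare_pt2e inserts observers that only record value ranges when data flows through them, and convert_pt2e turns those ranges into quantization parameters. A simplified NumPy sketch of that mechanism, with a hypothetical min/max observer and a toy "prepared model" (the observers PT2E actually inserts, and their ranges, come from the quantizer config and may differ):

```python
from typing import Callable, Iterable

import numpy as np


class MinMaxObserver:
    """Simplified observer: records the running min/max of every tensor it sees."""

    def __init__(self) -> None:
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def __call__(self, x: np.ndarray) -> np.ndarray:
        self.min_val = min(self.min_val, float(x.min()))
        self.max_val = max(self.max_val, float(x.max()))
        return x  # observers pass data through unchanged

    def symmetric_scale(self, qmax: int = 127) -> float:
        # Without calibration there is no range to derive a scale from.
        if self.max_val < self.min_val:
            raise ValueError("no data observed; run calibration first")
        amax = max(abs(self.min_val), abs(self.max_val))
        return amax / qmax if amax > 0 else 1.0


def calibrate(prepared: Callable[[np.ndarray], object],
              batches: Iterable[np.ndarray]) -> int:
    """Feed representative batches through the prepared model; returns the count."""
    count = 0
    for batch in batches:
        prepared(batch)
        count += 1
    return count


obs = MinMaxObserver()


def toy_prepared(x: np.ndarray) -> np.ndarray:
    # Toy prepared model: one observer in front of a linear op.
    return 2.0 * obs(x)


n = calibrate(toy_prepared, [np.array([-1.0, 0.5]), np.array([0.2, 3.0])])
print(n, obs.min_val, obs.max_val, obs.symmetric_scale())
```

In the real flow, `prepared` would be the model returned by prepare_pt2e (wrapped, for example, as `lambda x: model(torch.from_numpy(x))`) and `batches` would be preprocessed face crops rather than random tensors.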
