ADD SAM2 (tflite) #88

Open
kyakuno opened this issue Aug 21, 2024 · 32 comments

kyakuno commented Aug 21, 2024

Considering conversion with ai-edge-torch.
https://github.com/facebookresearch/segment-anything-2
https://medium.com/axinc/ai-edge-torch%E3%81%A7pytorch%E3%81%8B%E3%82%89tflite%E3%81%AB%E5%A4%89%E6%8F%9B%E3%81%99%E3%82%8B-376be7dc5619
If that proves difficult, consider PyTorch -> Keras instead, as in:
https://github.com/tirthasheshpatel/segment_anything_keras


kyakuno commented Aug 21, 2024

The image_encoder can be converted to ONNX, but converting it to tflite with ai-edge-torch on CUDA fails with the following error:

ValueError: Cannot view a tensor with shape torch.Size([1024, 4, 4, 288]) and strides (4608, 4, 1, 16) as a tensor with shape (16384, 288)!

While executing %view_38 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%view_37, [16384, 288]), kwargs = {})
Original traceback:
  File "/mnt/c/Users/kyakuno/Desktop/segment-anything-2/sam2/modeling/sam2_base.py", line 196, in forward
    backbone_out = self.forward_image(input_image)
  File "/mnt/c/Users/kyakuno/Desktop/segment-anything-2/sam2/modeling/sam2_base.py", line 485, in forward_image
    backbone_out = self.image_encoder(img_batch)
  File "/home/kyakuno/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/c/Users/kyakuno/Desktop/segment-anything-2/sam2/modeling/backbones/image_encoder.py", line 31, in forward
    features, pos = self.neck(self.trunk(sample))
  File "/home/kyakuno/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/c/Users/kyakuno/Desktop/segment-anything-2/sam2/modeling/backbones/hieradet.py", line 284, in forward
    x = blk(x)
  File "/home/kyakuno/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/c/Users/kyakuno/Desktop/segment-anything-2/sam2/modeling/backbones/hieradet.py", line 147, in forward
    x = self.attn(x)
  File "/home/kyakuno/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/c/Users/kyakuno/Desktop/segment-anything-2/sam2/modeling/backbones/hieradet.py", line 77, in forward
    x = self.proj(x)


kyakuno commented Aug 21, 2024

segment_anything_keras has an issue requesting SAM2 support, but it has not been implemented.
tirthasheshpatel/segment_anything_keras#4


kyakuno commented Aug 21, 2024

Running ai-edge-torch in CPU mode changes the error:

<unknown>:0: error: failed while converting: 'main':
Some ops are not supported by the native TFLite runtime, you can enable TF kernels fallback using TF Select. See instructions: https://www.tensorflow.org/lite/guide/ops_select
TF Select ops: Relu6
Details:
        tf.Relu6(tensor<256x1xi64>) -> (tensor<256x1xi64>)
        tf.Relu6(tensor<256xi64>) -> (tensor<256xi64>)
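These int64 Relu6 ops most likely come from a clamp on integer index tensors somewhere in the graph; that is an assumption, but a minimal PyTorch sketch (not SAM2's actual code) shows how an integer clamp keeps its dtype, which would surface exactly as a `tf.Relu6` on an `i64` tensor that the native TFLite runtime cannot run:

```python
import torch

# A clamp to [0, 6] can lower to a Relu6-style op; on an int64 tensor
# the result stays int64, which only the TF Select (Flex) fallback supports.
idx = torch.tensor([-3, 2, 9], dtype=torch.int64)
clamped = torch.clamp(idx, 0, 6)
assert clamped.dtype == torch.int64   # still i64 -> Flex fallback needed
assert clamped.tolist() == [0, 2, 6]

# Casting to float before the clamp keeps the op in the builtin set.
clamped_f = torch.clamp(idx.float(), 0.0, 6.0)
assert clamped_f.dtype == torch.float32
```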


kyakuno commented Aug 21, 2024

Flags can be passed to the tflite converter via _ai_edge_converter_flags.
https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/README.md


kyakuno commented Aug 21, 2024

Enabling Flex as below made the ImageEncoder export itself succeed:

import ai_edge_torch
import tensorflow as tf

sample_inputs = (input_image,)
tfl_converter_flags = {'target_spec': {'supported_ops': [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]}}
edge_model = ai_edge_torch.convert(self.model, sample_inputs, _ai_edge_converter_flags=tfl_converter_flags)
edge_model.export("image_encoder.tflite")
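For reference, the exported model can be smoke-tested with the TFLite Python interpreter; the full TF pip package bundles the Flex delegate that the SELECT_TF_OPS fallback ops require. A hedged sketch (the model path and the all-zeros dummy input are assumptions):

```python
import os
import numpy as np
import tensorflow as tf

MODEL_PATH = "image_encoder.tflite"  # produced by the export above

if os.path.exists(MODEL_PATH):
    interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
    interpreter.allocate_tensors()
    inp = interpreter.get_input_details()[0]
    out = interpreter.get_output_details()[0]
    x = np.zeros(inp["shape"], dtype=np.float32)  # dummy image batch
    interpreter.set_tensor(inp["index"], x)
    interpreter.invoke()
    print(interpreter.get_tensor(out["index"]).shape)
```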


kyakuno commented Aug 21, 2024

Relu6 becomes FlexRelu6.

(screenshot: converted graph showing FlexRelu6, 2024-08-21)


kyakuno commented Aug 21, 2024

Paper on SegmentAnything int8 quantization. It reports that the activation distribution has two peaks, which requires special handling.
https://openaccess.thecvf.com/content/CVPR2024/papers/Lv_PTQ4SAM_Post-Training_Quantization_for_Segment_Anything_CVPR_2024_paper.pdf


kyakuno commented Aug 21, 2024

Quantizing with Flex enabled in ai-edge-torch produces a mixed-precision graph:
Conv runs in int8, while all other operators run in float.


kyakuno commented Aug 21, 2024

There are signs that full int8 quantization is not supported yet.
google-ai-edge/ai-edge-torch#150


kyakuno commented Aug 21, 2024

The list of operators targeted for quantization is here:
https://github.com/google-ai-edge/ai-edge-torch/blob/main/ai_edge_torch/quantize/pt2e_quantizer.py


kyakuno commented Aug 23, 2024

The PromptEncoder fails with the following error:

RuntimeError: This model contains ops not capturable by Pytorch/XLA: aten::nonzero

The export failure is caused by the following logic in _embed_points in prompt_encoder.py;
in ONNX this part becomes a Where op.

        point_embedding[labels == -1] = self.not_a_point_embed.weight
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        point_embedding[labels == 2] += self.point_embeddings[2].weight
        point_embedding[labels == 3] += self.point_embeddings[3].weight

Rewriting it as follows makes it exportable:

        labels = labels.int()
        table = torch.zeros((5, self.point_embeddings[0].weight.shape[1]))
        table[0] = self.not_a_point_embed.weight
        table[1] = self.point_embeddings[0].weight
        table[2] = self.point_embeddings[1].weight
        table[3] = self.point_embeddings[2].weight
        table[4] = self.point_embeddings[3].weight
        for i in range(labels.shape[0]):
            point_embedding[i] = point_embedding[i] + table[labels[i] + 1]
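The same table lookup can also be written as a single vectorized gather instead of the per-row Python loop, which should export to a TFLite Gather; a standalone sketch with dummy dimensions (embed_dim=4 and the random table are placeholders, not SAM2's real weights):

```python
import torch

embed_dim = 4                      # dummy; SAM2 uses the model's embed dim
table = torch.randn(5, embed_dim)  # row 0 = not_a_point, rows 1..4 = labels 0..3
labels = torch.tensor([-1, 0, 2], dtype=torch.int64)
point_embedding = torch.zeros(labels.shape[0], embed_dim)

# Shift labels by 1 so -1 maps to row 0, then gather all rows at once.
vectorized = point_embedding + table[(labels + 1).long()]

# Equivalent to the per-row loop above.
looped = point_embedding.clone()
for i in range(labels.shape[0]):
    looped[i] = looped[i] + table[labels[i] + 1]
assert torch.allclose(vectorized, looped)
```

Note this mirrors the loop's semantics (always adding), not the original in-place `=`/`+=` branching.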


kyakuno commented Aug 23, 2024

Tried running the ImageEncoder with tflite inference. The ImageEncoder exports correctly.

(inference output image)

@kyakuno kyakuno self-assigned this Aug 23, 2024

kyakuno commented Aug 23, 2024

The MaskDecoder fails with the following error:

  File "/home/kyakuno/.local/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 221, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_function ConstantVariable(int: 512) [ConstantVariable()] {}

from user code:
   File "/mnt/c/Users/kyakuno/Desktop/segment-anything-2/sam2/modeling/sam/mask_decoder.py", line 137, in forward
    masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
  File "/mnt/c/Users/kyakuno/Desktop/segment-anything-2/sam2/modeling/sam/mask_decoder.py", line 198, in predict_masks
    sparse_prompt_embeddings.size(0), -1, -1

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

I0000 00:00:1724378905.868815    9151 cpu_client.cc:470] TfrtCpuClient destroyed.


kyakuno commented Aug 23, 2024

Replacing size() with shape makes this point pass.
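A minimal illustration of the workaround; the names are from mask_decoder.py, but the shapes here are dummies:

```python
import torch

sparse_prompt_embeddings = torch.randn(2, 3, 8)  # dummy (B, N, C)
output_tokens = torch.randn(5, 8)                # dummy token table

# Before: output_tokens.unsqueeze(0).expand(
#             sparse_prompt_embeddings.size(0), -1, -1)
# After:  read the batch dimension via .shape[0] instead of .size(0)
tokens = output_tokens.unsqueeze(0).expand(
    sparse_prompt_embeddings.shape[0], -1, -1
)
assert tokens.shape == (2, 5, 8)
```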


kyakuno commented Aug 23, 2024

Trying to calibrate for quantization fails with the following error:

torch.histogram: input tensor and hist tensor should have the same dtype, but got input long int and hist float

pytorch/pytorch#74420
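The mismatch can be reproduced directly with torch.histogram, and casting the input to float first avoids it; a minimal sketch:

```python
import torch

x_int = torch.arange(10, dtype=torch.int64)

# torch.histogram requires the input and hist tensors to share a dtype,
# so an int64 activation hitting a float32 histogram raises a RuntimeError.
try:
    torch.histogram(x_int, bins=4)
    raised = False
except RuntimeError:
    raised = True
assert raised

# Casting to float32 first lets calibration proceed.
hist, edges = torch.histogram(x_int.float(), bins=4)
assert hist.sum().item() == 10
```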


kyakuno commented Aug 23, 2024

The problem is not the graph inputs: int64 tensors appear in the middle of the graph, and calibration cannot handle them.


kyakuno commented Aug 23, 2024

Adding the code below to reset_histogram in torch/ao/quantization/observer.py makes it pass:

    def reset_histogram(self, x: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor) -> None:
        self.min_val.resize_(min_val.shape)
        self.min_val.copy_(min_val)
        self.max_val.resize_(max_val.shape)
        self.max_val.copy_(max_val)
        assert (
            min_val.numel() == 1 and max_val.numel() == 1
        ), "histogram min/max values must be scalar."
        if x.dtype != torch.float32:  # added
            x = x.float()  # added
        torch.histc(
            x, self.bins, min=min_val, max=max_val, out=self.histogram  # type: ignore[arg-type]
        )


kyakuno commented Aug 23, 2024

Output of the int8 ImageEncoder. It looks roughly right.

(inference output image)


kyakuno commented Aug 23, 2024

Versions used:
torch 2.4.0
ai-edge-torch 0.2.0


kyakuno commented Aug 23, 2024

The PromptEncoder cannot be quantized because label is int64.


kyakuno commented Aug 23, 2024

Since the PromptEncoder is computationally light, running it in float seems acceptable.


kyakuno commented Aug 23, 2024

Quantization result for the MaskDecoder (ImageEncoder kept in float).

(inference output image)


kyakuno commented Aug 23, 2024

Perhaps because ai-edge-torch targets LLMs, quite a few tensors remain float.
It feels like the quantization tool will need some modification.


kyakuno commented Aug 23, 2024

In the ImageEncoder's attention, the weights are int8 but are converted to float before the matrix multiplication.

(screenshot: attention subgraph, 2024-08-23)


kyakuno commented Aug 26, 2024

Including pos in the ImageEncoder's outputs causes an error during quantization.
Setting is_dynamic=True and using dynamic quantization makes it pass, but then all operations run in float.


kyakuno commented Aug 26, 2024

Output of the ImageEncoder with dynamic quantization. The output is clean.

(inference output image)


kyakuno commented Sep 5, 2024

The ImageEncoder, PromptEncoder, and MaskDecoder have all been successfully converted to tflite (float).

The MemoryAttention conversion will be handled in:
axinc-ai/ailia-models#1514

@kyakuno kyakuno changed the title ADD SAM2 ADD SAM2 (tflite) Sep 5, 2024

kyakuno commented Sep 11, 2024

Trying to make the MaskDecoder dynamic-shape fails with the following error:

 File "/home/kyakuno/.local/lib/python3.10/site-packages/torch_xla/experimental/unbounded_dynamism_export.py", line 115, in decompose_dynamic_shape_select
    assert symbolic_dims[
AssertionError: Selected dim cannot be symbolic

The tflite version will therefore support static shapes only.
