Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ADD LivePortrait #1506

Open
kyakuno opened this issue Jul 11, 2024 · 6 comments
Open

ADD LivePortrait #1506

kyakuno opened this issue Jul 11, 2024 · 6 comments
Assignees

Comments

@kyakuno
Copy link
Collaborator

kyakuno commented Jul 11, 2024

https://github.com/KwaiVGI/LivePortrait

@kyakuno
Copy link
Collaborator Author

kyakuno commented Jul 11, 2024

Test
https://huggingface.co/spaces/KwaiVGI/LivePortrait

1_0zpPqtPX6oi7LkAV8u75Bg--d6_trim_concat.mp4

@kyakuno
Copy link
Collaborator Author

kyakuno commented Jul 12, 2024

@ooe1123 GroundingDINOとSAMの対応、ありがとうございます。SAMの後、可能であればこちらをお願いできると嬉しいです。

@kyakuno
Copy link
Collaborator Author

kyakuno commented Jul 13, 2024

入力画像とリファレンス画像の顔のキーポイントを取得、入力画像のキーポイントをリファレンス画像のキーポイントに近づけるようにAIで補正、変形したキーポイントと入力画像をワープモジュールに入れて画像変換する。顔全体、目、リップで独立してキーポイントの補正をしている。

@kyakuno
Copy link
Collaborator Author

kyakuno commented Jul 13, 2024

ビデオモードで入力した画像をリファレンスとして、リアルタイムに変形したい。

@yuananf
Copy link

yuananf commented Jul 23, 2024

@ooe1123
Copy link
Contributor

ooe1123 commented Jul 25, 2024

appearance_feature_extractor.onnx

〇 src/live_portrait_wrapper.py

# Original upstream implementation (elided excerpt from src/live_portrait_wrapper.py):
# extract_feature_3d runs the appearance feature extractor network on input
# tensor x under no_grad + autocast. Shown here as the "before" of the ONNX
# export patch that follows.
class LivePortraitWrapper(object):
    ...
    def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
        ...
        with torch.no_grad():
            with torch.autocast(...):
                # Forward pass producing the 3D appearance feature volume.
                # NOTE(review): x is presumably a preprocessed source-image
                # tensor — shape not shown in this excerpt.
                feature_3d = self.appearance_feature_extractor(x)

# Patched version ("after"): instead of running inference, export
# appearance_feature_extractor to ONNX using the live input x as the tracing
# example, then terminate the process. The `if 1:` guard marks this as a
# one-shot debugging/export hack, not production code.
class LivePortraitWrapper(object):
    ...
    def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
        ...
        with torch.no_grad():
            with torch.autocast(...):
                if 1:
                    print("------>")
                    # Single input "x", single output "f_s" (the 3D feature).
                    torch.onnx.export(
                        self.appearance_feature_extractor, x, 'appearance_feature_extractor.onnx',
                        input_names=["x"],
                        output_names=["f_s"],
                        verbose=False, opset_version=17
                    )
                    print("<------")
                    # Stop immediately after export; the rest of the pipeline
                    # is intentionally skipped.
                    exit()

motion_extractor.onnx

〇 src/live_portrait_wrapper.py

# Original upstream implementation (elided excerpt): get_kp_info runs the
# motion extractor, which returns a dict of keypoint/pose components
# (pitch/yaw/roll/t/exp/scale/kp — see the export patch below).
class LivePortraitWrapper(object):
    ...
    def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
        ...
        with torch.no_grad():
            with torch.autocast(...):
                # kp_info is a dict of tensors describing head pose and
                # expression keypoints.
                kp_info = self.motion_extractor(x)

# Patched version: run the motion extractor once, then wrap it in a small
# nn.Module (Exp) whose forward flattens the returned dict into a fixed tuple,
# because torch.onnx.export needs tensor outputs rather than a dict.
# Exports to motion_extractor.onnx and exits.
class LivePortraitWrapper(object):
    ...
    def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
        ...
        with torch.no_grad():
            with torch.autocast(...):
                kp_info = self.motion_extractor(x)

            if 1:
                # Adapter module: converts the dict output into an ordered
                # tuple so each entry becomes a named ONNX graph output.
                class Exp(torch.nn.Module):
                    def __init__(self, motion_extractor):
                        super().__init__()
                        self.motion_extractor = motion_extractor

                    def forward(self, x):
                        kp_info = self.motion_extractor(x)
                        # Order here must match output_names below.
                        return kp_info["pitch"], kp_info["yaw"], kp_info["roll"], kp_info["t"], kp_info["exp"], kp_info["scale"], kp_info["kp"]

                print("------>")
                model = Exp(self.motion_extractor)
                torch.onnx.export(
                    model, x, 'motion_extractor.onnx',
                    input_names=["x"],
                    output_names=["pitch", "yaw", "roll", "t", "exp", "scale", "kp"],
                    verbose=False, opset_version=17
                )
                print("<------")
                # One-shot export hack: terminate after writing the model.
                exit()

stitching.onnx

〇 src/live_portrait_wrapper.py

# Original upstream implementation (elided excerpt): the stitching retargeting
# sub-module predicts a delta from the concatenated source/driving keypoint
# features.
class LivePortraitWrapper(object):
    ...
    def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        ...
        with torch.no_grad():
            # "feat_stiching" (sic — upstream variable name) is built earlier
            # in the elided code from kp_source/kp_driving.
            delta = self.stitching_retargeting_module['stitching'](feat_stiching)

# Patched version: export the 'stitching' sub-module to stitching.onnx,
# tracing with the live feat_stiching tensor, then exit. The original
# delta computation is deliberately replaced by the export + exit.
class LivePortraitWrapper(object):
    ...
    def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        ...
        with torch.no_grad():
            if 1:
                print("------>")
                torch.onnx.export(
                    self.stitching_retargeting_module['stitching'], feat_stiching, 'stitching.onnx',
                    input_names=["x"],
                    output_names=["out"],
                    verbose=False, opset_version=17
                )
                print("<------")
                # One-shot export hack: terminate after writing the model.
                exit()

warping_module.onnx

〇 src/live_portrait_wrapper.py

# Original upstream implementation (elided excerpt): warp_decode feeds the 3D
# appearance feature plus source/driving keypoints into the warping module,
# which returns a dict (includes "out", "occlusion_map", "deformation" —
# see the export patch below).
class LivePortraitWrapper(object):
    ...
    def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        ...
        with torch.no_grad():
            with torch.autocast(...):
                ...
                ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)

# Patched version: wrap warping_module in an adapter (Exp) that flattens its
# dict output to a tuple, then export under the same autocast settings the
# real inference uses (half precision when flag_use_half_precision is set).
# NOTE(review): this export uses opset 20 while the other modules use 17 —
# presumably required by an op in the warping module; confirm before unifying.
class LivePortraitWrapper(object):
    ...
    def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        ...
        with torch.no_grad():
            if 1:
                # Adapter: dict -> ordered tuple so outputs get stable names.
                class Exp(torch.nn.Module):
                    def __init__(self, warping_module):
                        super().__init__()
                        self.warping_module = warping_module

                    def forward(self, feature_3d, kp_source, kp_driving):
                        ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
                        # Order must match output_names below.
                        return ret_dct["out"], ret_dct["occlusion_map"], ret_dct["deformation"]

                # self.device[:4] maps e.g. "cuda:0" -> "cuda" for autocast.
                with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
                    print("------>")
                    model = Exp(self.warping_module)
                    # Multi-input trace example: positional tuple of the three
                    # live tensors.
                    x = (feature_3d, kp_source, kp_driving)
                    torch.onnx.export(
                        model, x, 'warping_module.onnx',
                        input_names=["feature_3d", "kp_source", "kp_driving"],
                        output_names=["out", "occlusion_map", "deformation"],
                        verbose=False, opset_version=20
                    )
                    print("<------")
                    # One-shot export hack: terminate after writing the model.
                    exit()

spade_generator.onnx

〇 src/live_portrait_wrapper.py

# Original upstream implementation (elided excerpt): after warping, the SPADE
# generator decodes the warped feature map ("out") into the final image.
# Shown as the "before" of the spade_generator export patch below.
class LivePortraitWrapper(object):
    ...
    def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        ...
        with torch.no_grad():
            with torch.autocast(...):
                ...
                ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)

                # decode
                ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])

# Patched version: export spade_generator to spade_generator.onnx using the
# real warped feature as the trace input. Both module and input are moved to
# CPU and the input cast to float32 — presumably to avoid tracing fp16/CUDA
# tensors into the graph (TODO confirm). Exits after export.
class LivePortraitWrapper(object):
    ...
    def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        ...
        with torch.no_grad():
            with torch.autocast(...):
                ...
                # Run the warping module normally to obtain the feature map
                # that will serve as the export trace input.
                ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)

                if 1:
                    print("------>")
                    # NOTE(review): .cpu() mutates self.spade_generator in
                    # place (moves its parameters off the GPU) — harmless here
                    # only because the process exits right after.
                    torch.onnx.export(
                        self.spade_generator.cpu(), ret_dct['out'].cpu().type(torch.float32), 'spade_generator.onnx',
                        input_names=["feature"],
                        output_names=["out"],
                        verbose=False, opset_version=17
                    )
                    print("<------")
                    # One-shot export hack: terminate after writing the model.
                    exit()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

No branches or pull requests

3 participants