
[WIP][LoRA] Implement hot-swapping of LoRA #9453

Draft · wants to merge 2 commits into base: main

Conversation

BenjaminBossan (Member):

This PR adds the possibility to hot-swap LoRA adapters. It is WIP.

Description

As of now, users can already load multiple LoRA adapters. They can offload existing adapters or unload them (i.e. delete them). However, they cannot "hot-swap" adapters yet, i.e. replace the weights of one loaded LoRA adapter with the weights of another without creating a separate LoRA adapter.

Generally, hot-swapping may not appear particularly useful on its own, but when the model is compiled it is necessary to prevent recompilation. See #9279 for more context.
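For illustration, here is a rough usage sketch of what hot-swapping is meant to enable; the `hotswap` argument and the paths below are placeholders for the direction this PR explores, not a finalized API:

```python
# Hypothetical usage sketch: load a first LoRA, compile the unet, then hot-swap
# a second LoRA into the same adapter slot so the compiled graph can be reused.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora_1", adapter_name="default")  # placeholder path
pipe.unet = torch.compile(pipe.unet)
image = pipe("a photo of a cat").images[0]

# Replace the weights of "default" in place instead of loading a second adapter;
# the `hotswap` flag is illustrative of this PR's direction.
pipe.load_lora_weights("path/to/lora_2", adapter_name="default", hotswap=True)
image = pipe("a photo of a cat").images[0]  # ideally, no recompilation here
```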

Caveats

To hot-swap one LoRA adapter for another, the two adapters should target exactly the same layers and the "hyper-parameters" of the two adapters should be identical. For instance, the LoRA alpha has to be the same: since we keep the alpha from the first adapter, the LoRA scaling would otherwise be incorrect for the second adapter.

Theoretically, we could override the scaling dict with the alpha values derived from the second adapter's config, but changing the dict would trigger a guard and thus a recompilation, defeating the main purpose of this feature.
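For context, PEFT stores the scaling as `scaling[adapter_name] = lora_alpha / r`, and hot-swapping boils down to copying new values into the existing LoRA tensors. A minimal sketch of that idea (not the PR's actual implementation; the helper name is made up):

```python
# Sketch: swap LoRA weights in place so tensor identities, shapes and the
# scaling dict stay untouched and no compile guards fire. Because
# scaling[adapter] = lora_alpha / r is kept as-is, both adapters must share
# the same alpha (and rank).
import torch

def hotswap_lora_weights_(lora_layer, new_lora_A, new_lora_B, adapter_name="default"):
    with torch.no_grad():
        lora_layer.lora_A[adapter_name].weight.copy_(new_lora_A)
        lora_layer.lora_B[adapter_name].weight.copy_(new_lora_B)
```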

I also found that compilation flags can have an impact on whether this works or not. E.g. when passing mode="reduce-overhead", there are errors of the type:

> input name: arg861_1. data pointer changed from 139647332027392 to 139647331054592

I don't know enough about compilation to determine whether this is problematic or not.
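"reduce-overhead" refers to the `mode` argument of `torch.compile`, which enables CUDA graphs; my reading of the error above, stated as an assumption, is that graph replay expects stable input data pointers:

```python
import torch

# Default mode: compatible with swapping weights in place.
compiled_unet = torch.compile(unet)  # `unet` stands in for the pipeline's unet

# "reduce-overhead" uses CUDA graphs under the hood; a replayed graph expects
# stable data pointers, which is presumably what the error above complains about.
compiled_unet_cg = torch.compile(unet, mode="reduce-overhead")
```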

Current state

This is obviously a WIP right now, intended to collect feedback and discuss which direction to take this. If this PR turns out to be useful, the hot-swapping functions will be added to PEFT itself and can be imported here (or a separate copy will live in diffusers to avoid requiring a minimum PEFT version for this feature).

Moreover, more tests need to be added to better cover this feature, although we don't necessarily need tests for the hot-swapping functionality itself, since those tests will be added to PEFT.

Furthermore, as of now, this is only implemented for the unet; other pipeline components do not support it yet.

Finally, it should be properly documented.

I would like to collect feedback on the current state of the PR before putting more time into finalizing it.



```python
# values as well, but that's not implemented yet, and it would trigger a re-compilation if the model is compiled.

# TODO: This is a very rough check at the moment and there are probably better ways than to error out
config_keys_to_check = ["lora_alpha", "use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]
```
Member:

Should we fetch this by inspecting the LoraConfig class init instead? If the target_modules differ, that would make hotswap incompatible as well, no?
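A sketch of what deriving the keys from the config class could look like (assuming it is acceptable to compare every field; PEFT configs are dataclasses):

```python
# Hypothetical sketch: derive the keys to compare from LoraConfig itself
# instead of hard-coding the list.
from dataclasses import fields

from peft import LoraConfig

config_keys_to_check = [field.name for field in fields(LoraConfig)]
```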

Member Author:

Hmm, I don't think we should compare target_modules. First of all, target_modules can be a list of str but also a regex. We can't really know if the two amount to the same result or not. Second, I think it's more robust to check the actually adapted modules.
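For illustration, checking the actually adapted modules could look something like this (a sketch; the helper names are made up and not part of the PR):

```python
# Hypothetical sketch: compare the modules that actually carry LoRA weights,
# derived from the model on one side and from the incoming state dict on the other.
from peft.tuners.lora import LoraLayer

def adapted_module_names_from_model(model):
    return {name for name, module in model.named_modules() if isinstance(module, LoraLayer)}

def adapted_module_names_from_state_dict(state_dict):
    # e.g. "foo.to_q.lora_A.weight" -> "foo.to_q"
    return {key.rsplit(".lora_", 1)[0] for key in state_dict if ".lora_" in key}
```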

Member:

Makes sense. Could you help me understand how / where we're accomplishing:

> Second, I think it's more robust to check the actually adapted modules.


```python
if hotswap:
    # Check that the incoming adapter is compatible with the one already
    # loaded (same targeted layers, same hyper-parameters).
    _check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
    # Swap the weights of the existing adapter in place with the new state dict.
    _hotswap_adapter_from_state_dict(self, state_dict, adapter_name)
```
Member:

What happens when there's no adapter already loaded? Should we check for that?

Member Author:

There is a check for that:

https://github.com/huggingface/diffusers/pull/9453/files#diff-5fdf1b6715cd3459bda0d19623f2c5fb9a578221f0ee7509f043dc334fb92dc8R308

I think this should be sufficient or should something else be checked on top?

Member:

Should be okay.

```python
# TODO: kinda slow, should it get a slow marker?
# Run the compiled-model hotswap check in a separate process so that the
# TORCH_LOGS output (guards, recompiles) can be captured.
env = {"TORCH_LOGS": "guards,recompiles"}
here = os.path.dirname(__file__)
file_name = os.path.join(here, "run_compiled_model_hotswap.py")
```
Member:

To make the process faster, we could:

1. Use a small UNet with the config from `prepare_init_args_and_inputs_for_common()`.
2. Inject LoRAs into it with this config:
3. Check for recompilation.

If we decide to proceed this way, we could do all of this inside the test case instead of having to do it via a Python file under a log-capturing context manager. But I am okay with having it done via a file as well.
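A rough sketch of that idea (the small UNet config and LoRA settings below are illustrative approximations, not the exact values from the test suite):

```python
# Sketch: tiny UNet, LoRA injected via add_adapter, then compile and check
# for recompilation. Not the PR's test, just an illustration of the suggestion.
import torch
from peft import LoraConfig

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)
lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
unet.add_adapter(lora_config)
unet = torch.compile(unet)
# ... run a forward pass, hot-swap the second adapter's weights in place,
# run again, and assert that no recompilation was triggered ...
```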

Member Author:

Could you please give me some guidance here? I tried to make the suggested changes (see the newly committed file), but load_lora_weights is not defined on the unet, so I would still need to create a pipe? Also, if I want to create the LoRA adapter in the test instead of loading an existing one, I would probably need to call save_lora_weights and store the weights in a temp folder -- is that the right way?

> we could do all of this inside the test case instead of having to do via a Python file

That would be preferable, but I'm not sure how we can capture the torch logs with TORCH_LOGS=guards,recompiles without shelling out. I don't think that something like caplog or capsys is sufficient; when I tried them, they were empty.
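One possible way to avoid shelling out, stated purely as an assumption (it relies on `torch._dynamo` config rather than TORCH_LOGS, so it may not capture exactly the same information):

```python
# Hypothetical sketch: ask dynamo to raise as soon as it would recompile,
# instead of parsing the TORCH_LOGS output of a subprocess.
import torch

torch._dynamo.reset()
torch._dynamo.config.error_on_recompile = True

compiled_unet = torch.compile(unet)  # `unet` and `dummy_inputs` from the test setup
compiled_unet(**dummy_inputs)        # first call compiles
# ... hot-swap the second adapter's weights in place here ...
compiled_unet(**dummy_inputs)        # raises if hot-swapping triggered a recompile
```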

Member:

Would something like this work?

unet.load_attn_procs(

```python
    pass

def test_hotswapping_compiled_model_does_not_trigger_recompilation(self):
    # TODO: kinda slow, should it get a slow marker?
```
Member:

Yes:

  1. @slow
  2. @require_torch_2
  3. @require_torch_gpu
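Applied to the test, that would look roughly like this (the import location is an assumption):

```python
# Sketch of the suggested markers applied to the test method.
from diffusers.utils.testing_utils import require_torch_2, require_torch_gpu, slow

@slow
@require_torch_2
@require_torch_gpu
def test_hotswapping_compiled_model_does_not_trigger_recompilation(self):
    ...
```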

Member Author:

Added.

@sayakpaul (Member) left a comment:

Thanks a lot for working on this. I left some comments.
