-
Notifications
You must be signed in to change notification settings - Fork 5.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP][LoRA] Implement hot-swapping of LoRA #9453
base: main
Are you sure you want to change the base?
Conversation
This PR adds the possibility to hot-swap LoRA adapters. It is WIP. Description As of now, users can already load multiple LoRA adapters. They can offload existing adapters or they can unload them (i.e. delete them). However, they cannot "hotswap" adapters yet, i.e. substitute the weights from one LoRA adapter with the weights of another, without the need to create a separate LoRA adapter. Generally, hot-swapping may not appear not super useful but when the model is compiled, it is necessary to prevent recompilation. See huggingface#9279 for more context. Caveats To hot-swap a LoRA adapter for another, these two adapters should target exactly the same layers and the "hyper-parameters" of the two adapters should be identical. For instance, the LoRA alpha has to be the same: Given that we keep the alpha from the first adapter, the LoRA scaling would be incorrect for the second adapter otherwise. Theoretically, we could override the scaling dict with the alpha values derived from the second adapter's config, but changing the dict will trigger a guard for recompilation, defeating the main purpose of the feature. I also found that compilation flags can have an impact on whether this works or not. E.g. when passing "reduce-overhead", there will be errors of the type: > input name: arg861_1. data pointer changed from 139647332027392 to 139647331054592 I don't know enough about compilation to determine whether this is problematic or not. Current state This is obviously WIP right now to collect feedback and discuss which direction to take this. If this PR turns out to be useful, the hot-swapping functions will be added to PEFT itself and can be imported here (or there is a separate copy in diffusers to avoid the need for a min PEFT version to use this feature). Moreover, more tests need to be added to better cover this feature, although we don't necessarily need tests for the hot-swapping functionality itself, since those tests will be added to PEFT. Furthermore, as of now, this is only implemented for the unet. Other pipeline components have yet to implement this feature. Finally, it should be properly documented. I would like to collect feedback on the current state of the PR before putting more time into finalizing it.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
# values as well, but that's not implemented yet, and it would trigger a re-compilation if the model is compiled. | ||
|
||
# TODO: This is a very rough check at the moment and there are probably better ways than to error out | ||
config_keys_to_check = ["lora_alpha", "use_rslora", "lora_dropout", "alpha_pattern", "use_dora"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we fetch this by inspecting the LoraConfig
class init, instead? If the target_modules
differ that would make hotswap incompatible as well no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, I don't think we should compare target_modules
. First of all, target_modules
can be a list of str but also a regex. We can't really know if the two amount to the same result or not. Second, I think it's more robust to check the actually adapted modules.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense. Could you help me understand how / where we're accomplishing:
Second, I think it's more robust to check the actually adapted modules.
|
||
if hotswap: | ||
_check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config) | ||
_hotswap_adapter_from_state_dict(self, state_dict, adapter_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens when there's no adapter already loaded? Should we check for that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a check for that:
I think this should be sufficient or should something else be checked on top?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be okay.
# TODO: kinda slow, should it get a slow marker? | ||
env = {"TORCH_LOGS": "guards,recompiles"} | ||
here = os.path.dirname(__file__) | ||
file_name = os.path.join(here, "run_compiled_model_hotswap.py") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To make the process faster, we could:
- Use a small UNet with the config from
def prepare_init_args_and_inputs_for_common(self): - Inject LoRAs into it with this config:
def get_unet_lora_config(): - Check for recompilation.
If we decide to proceed this way, then we could do all of this inside the test case instead of having to do via a Python file under a capture logger context manager. But I am okay with having it done via a file as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please give me some guidance here? I tried to make the suggested changes (see new committed file) but load_lora_weights
is not defined on the unet
, so I would need to create a pipe still? Also, if I want to create the LoRA adapter in the test instead of loading an existing one, I would probably need to call save_lora_weights
and store them in a temp folder, is that the right way?
we could do all of this inside the test case instead of having to do via a Python file
That would be preferable but I'm not sure how we can capture the torch logs with TORCH_LOGS=guards,recompiles
without shelling out. I don't think that something like caplog
or capsys
is sufficient, when I tried them, they were empty.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would something like this work?
unet.load_attn_procs( |
pass | ||
|
||
def test_hotswapping_compiled_model_does_not_trigger_recompilation(self): | ||
# TODO: kinda slow, should it get a slow marker? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes:
@slow
@require_torch_2
@require_torch_gpu
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for working on this. I left some comments.
This PR adds the possibility to hot-swap LoRA adapters. It is WIP.
Description
As of now, users can already load multiple LoRA adapters. They can offload existing adapters or they can unload them (i.e. delete them). However, they cannot "hotswap" adapters yet, i.e. substitute the weights from one LoRA adapter with the weights of another, without the need to create a separate LoRA adapter.
Generally, hot-swapping may not appear not super useful but when the model is compiled, it is necessary to prevent recompilation. See #9279 for more context.
Caveats
To hot-swap a LoRA adapter for another, these two adapters should target exactly the same layers and the "hyper-parameters" of the two adapters should be identical. For instance, the LoRA alpha has to be the same: Given that we keep the alpha from the first adapter, the LoRA scaling would be incorrect for the second adapter otherwise.
Theoretically, we could override the scaling dict with the alpha values derived from the second adapter's config, but changing the dict will trigger a guard for recompilation, defeating the main purpose of the feature.
I also found that compilation flags can have an impact on whether this works or not. E.g. when passing "reduce-overhead", there will be errors of the type:
I don't know enough about compilation to determine whether this is problematic or not.
Current state
This is obviously WIP right now to collect feedback and discuss which direction to take this. If this PR turns out to be useful, the hot-swapping functions will be added to PEFT itself and can be imported here (or there is a separate copy in diffusers to avoid the need for a min PEFT version to use this feature).
Moreover, more tests need to be added to better cover this feature, although we don't necessarily need tests for the hot-swapping functionality itself, since those tests will be added to PEFT.
Furthermore, as of now, this is only implemented for the unet. Other pipeline components have yet to implement this feature.
Finally, it should be properly documented.
I would like to collect feedback on the current state of the PR before putting more time into finalizing it.
What does this PR do?
Fixes # (issue)
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.