
[examples] add train flux-controlnet scripts in example. #9324

Open · 39 commits into base: main

Conversation

PromeAIpro

What does this PR do?

This PR adds Flux ControlNet training scripts to the examples; we tested them on an A100-SXM4-80GB.

Using this training script, you can customize the number of transformer layers by setting `--num_double_layers=4 --num_single_layers=0`. With this setting, the GPU memory demand is about 60 GB at batch size 2 and 512 resolution.
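For reference, a launch command along these lines should work. The layer flags are from this PR; the model path, dataset name, and remaining flags are placeholders modeled on the other diffusers ControlNet examples, so adjust them to your setup:

```shell
# Hypothetical invocation; dataset and output paths are placeholders.
accelerate launch train_controlnet_flux.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
  --dataset_name="YOUR_DATASET" \
  --output_dir="flux-controlnet" \
  --num_double_layers=4 \
  --num_single_layers=0 \
  --resolution=512 \
  --train_batch_size=2
```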

Discussed in #9085.

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@yiyixuxu
Collaborator

yiyixuxu commented Sep 4, 2024

@haofanwang @wangqixun
would you be willing to give this a review if you have time?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@linjiapro
Contributor

linjiapro commented Sep 11, 2024

@PromeAIpro

Can we have some sample training results (such as images) from this script attached in the doc, or anywhere suitable?

@PromeAIpro
Author

PromeAIpro commented Sep 13, 2024

Here are some training results from a lineart ControlNet.

| input | output | prompt |
| --- | --- | --- |
| ComfyUI_temp_egnkb_00001_ | ComfyUI_00027_ | cute anime girl with massive fluffy fennec ears and a big fluffy tail blonde messy long hair blue eyes wearing a maid outfit with a long black gold leaf pattern dress and a white apron mouth open holding a fancy black forest cake with candles on top in the kitchen of an old dark Victorian mansion lit by candlelight with a bright window to the foggy forest and very expensive stuff everywhere |
| ComfyUI_temp_znagh_00001_ | ComfyUI_temp_cufps_00002_ | a busy urban intersection during daytime. The sky is partly cloudy with a mix of blue and white clouds. There are multiple traffic lights, and vehicles are seen waiting at the red signals. Several businesses and shops are visible on the side, with signboards and advertisements. The road is wide, and there are pedestrian crossings. Overall, it appears to be a typical day in a bustling city. |

First train at 512 resolution, then fine-tune at 1024 resolution.

* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases.
* `validation_image`, `validation_prompt`, and `validation_steps` allow the script to do a few validation inference runs, so we can qualitatively check whether training is progressing as expected.

Our experiments were conducted on a single 40GB A100 GPU.
Member

Wow, 40GB A100 seems doable.

Author

I'm sorry, this is the 80 GB A100 (I wrote it wrong). I did a lot of extra work to get it to train with ZeRO-3 on the 40 GB A100, but I don't think that setup is suitable for everyone.

Member

Not at all. I think it would still be nice to include the changes you had to make in the form of notes in the README. Does that work?

Author

I'll see if I can add it later.

Author

@sayakpaul We added a tutorial on configuring deepspeed in the readme.

Contributor

There are some tricks to lower GPU memory usage:

  1. gradient_checkpointing
  2. bf16 or fp16
  3. batch size 1, with gradient_accumulation_steps above 1

With 1, 2, and 3, can training be kept under 40 GB?
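In flag form, assuming this script exposes the same options as the other diffusers ControlNet training scripts (which is not guaranteed), the three tricks would look like:

```shell
# 1. recompute activations instead of storing them (gradient checkpointing)
# 2. half-precision training (bf16)
# 3. small batch, gradients accumulated to keep the effective batch size
accelerate launch train_controlnet_flux.py \
  --gradient_checkpointing \
  --mixed_precision="bf16" \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4
```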

Author

PromeAIpro commented Sep 14, 2024

In my experience, DeepSpeed ZeRO-3 must be used. @linjiapro, your settings cost about 70 GB at 1024 resolution with batch size 1, or at 512 resolution with batch size 3.


Sorry to bother you, but have you ever tried caching the text-encoder and VAE latents to run with lower GPU memory? @PromeAIpro @linjiapro

Author

Caching the text encoder is already available in this script (saving about 10 GB of GPU memory on T5). For caching VAE latents, you can check how to use DeepSpeed in the README, which includes VAE caching.
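The memory saving comes from a simple precompute-and-free pattern: encode every prompt once, then release the encoder before training starts. A minimal sketch, with a dummy encoder standing in for T5 (all names here are illustrative, not the script's):

```python
def cache_embeddings(prompts, encoder):
    """Encode all prompts once so the encoder can be freed afterwards."""
    return {p: encoder(p) for p in prompts}

# Dummy encoder standing in for the T5 text encoder.
dummy_encoder = lambda text: [float(len(text))]

prompts = ["a cat", "a busy urban intersection"]
cache = cache_embeddings(prompts, dummy_encoder)
del dummy_encoder  # the real script frees the T5 encoder here (~10 GB)

# The training loop then reads embeddings from the cache
# instead of re-encoding prompts every step.
assert cache["a cat"] == [5.0]
```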

Member

@sayakpaul sayakpaul left a comment

Hi, thanks for your PR. I just left some initial comments. LMK what you think.

Member

@sayakpaul sayakpaul left a comment

Thanks! Appreciate your hard work here. Left some more comments.

Review comments were left on: examples/controlnet/README_flux.md, src/diffusers/pipelines/flux/pipeline_flux_controlnet.py, and examples/controlnet/train_controlnet_flux.py.
@sayakpaul
Member

Can we fix the code quality issues with `make quality && make style`?

Member

@sayakpaul sayakpaul left a comment

Left some additional minor comments but I see existing comments are yet to be addressed. Let me know when you would like another round of review.

Review comments were left on: examples/controlnet/README_flux.md and examples/controlnet/train_controlnet_flux.py.
@Laidawang

@sayakpaul hey, I think I have fixed all the issues, time to start a new review.

Comment on lines +1257 to +1270

```python
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
    weighting_scheme=args.weighting_scheme,
    batch_size=bsz,
    logit_mean=args.logit_mean,
    logit_std=args.logit_std,
    mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device)

# Add noise according to flow matching.
sigmas = get_sigmas(timesteps, n_dim=pixel_latents.ndim, dtype=pixel_latents.dtype)
noisy_model_input = (1.0 - sigmas) * pixel_latents + sigmas * noise
```
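The final line of the quoted diff is plain linear interpolation between data and noise. A self-contained sketch of just that step, using scalars in place of tensors:

```python
def flow_matching_interpolate(latent, noise, sigma):
    # sigma = 0 returns the clean latent; sigma = 1 returns pure noise.
    return (1.0 - sigma) * latent + sigma * noise

assert flow_matching_interpolate(2.0, 6.0, 0.0) == 2.0
assert flow_matching_interpolate(2.0, 6.0, 1.0) == 6.0
assert flow_matching_interpolate(2.0, 6.0, 0.5) == 4.0
```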
Member

I thought we were using a different timestep sampling procedure, and I suggested having that as the default. Are we not doing that anymore?

Author

PromeAIpro commented Sep 19, 2024

Do you mean to set the original sampling scheme as the default?
[screenshot]
For the weighting scheme I just copied from here.

Member

Yeah I meant to keep the sigmoid sampling as your default and let users configure it as we do in the other scripts.

Author

Could you please write it down briefly? I'm not sure how to edit it. It seems to me that if you use logit_normal, you should be using sigmoid?
[screenshot]

Author

Do we just need to change weighting_scheme from the default value to logit_normal?
[screenshot]

Member

Okay. But it depends on an std and mean. IIRC your scheme did torch.randn() and applied sigmoid right?

Author

PromeAIpro commented Sep 19, 2024

Yes, this used torch.randn() at first, but given the examples you provided, I think this may be a better solution for us.
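For context on why the two approaches agree: squashing a normal sample through a sigmoid is, by definition, a logit-normal sample, so "torch.randn() then sigmoid" and `weighting_scheme="logit_normal"` draw from the same distribution. A dependency-free sketch of that sampling step (function name is illustrative):

```python
import math
import random

def sample_u(batch_size, logit_mean=0.0, logit_std=1.0):
    # sigmoid(N(logit_mean, logit_std)) is a logit-normal sample in (0, 1),
    # matching the logit_normal branch of the weighting scheme.
    return [1.0 / (1.0 + math.exp(-random.gauss(logit_mean, logit_std)))
            for _ in range(batch_size)]

u = sample_u(8)
assert all(0.0 < x < 1.0 for x in u)
```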

Member

@sayakpaul sayakpaul left a comment

Left some comments, but here are my main concerns:

  • Why remove the previous timestep computation scheme?
  • Let's provide a reasonable ControlNet checkpoint derived from your experiments.

LMK if anything is unclear.

@PromeAIpro PromeAIpro closed this Sep 19, 2024
@sayakpaul
Member

@PromeAIpro, we didn't have to close this PR. Is there anything we could do to revive it? We would very much like to do that. Please let us know.

@sayakpaul sayakpaul reopened this Sep 19, 2024
@PromeAIpro
Author

> @PromeAIpro, we didn't have to close this PR. Is there anything we could do to revive it? We would very much like to do that. Please let us know.

Sorry, I did it by mistake.

Member

@sayakpaul sayakpaul left a comment

Thanks. I think this is looking good. Some minor comments.

Also, we would need to add tests like in https://github.com/huggingface/diffusers/blob/main/examples/controlnet/test_controlnet.py.

@yiyixuxu could you review the changes made to the ControlNet pipeline?

A review comment was left on examples/controlnet/README_flux.md.

@PromeAIpro
Author

> Thanks. I think this is looking good. Some minor comments.
>
> Also, we would need to add tests like in https://github.com/huggingface/diffusers/blob/main/examples/controlnet/test_controlnet.py.
>
> @yiyixuxu could you review the changes made to the ControlNet pipeline?

Added a test in test_controlnet.
