Fine Tuning a Trained Model #205

GaetanoPrudente opened this issue Aug 10, 2021 · 1 comment

Is it possible to fine-tune an already trained model on a different dataset (augmented data)?
There is no documentation about this, or about the options needed to do it.


ksv87 commented Aug 27, 2024

In train.py, add the load_ckpt function from the YOLOX repo: https://github.com/Megvii-BaseDetection/YOLOX/blob/f00a798c8bf59f43ab557a2f3d566afa831c8887/yolox/utils/checkpoint.py#L11

def load_ckpt(model, ckpt, logger):
    # Copy only the checkpoint tensors whose names and shapes match the model;
    # anything missing or mismatched is skipped (with a warning) and keeps its
    # current initialization, which is what allows fine-tuning on a dataset
    # with a different number of classes.
    model_state_dict = model.state_dict()
    load_dict = {}
    for key_model, v in model_state_dict.items():
        if key_model not in ckpt:
            logger.warning(
                "{} is not in the ckpt. Please double check and see if this is desired.".format(
                    key_model
                )
            )
            continue
        v_ckpt = ckpt[key_model]
        if v.shape != v_ckpt.shape:
            logger.warning(
                "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
                    key_model, v_ckpt.shape, key_model, v.shape
                )
            )
            continue
        load_dict[key_model] = v_ckpt

    model.load_state_dict(load_dict, strict=False)
    return model
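
For illustration, here is a minimal standalone sketch of what load_ckpt does when the new dataset has a different number of classes (the layer shapes below are made up, not taken from the SSD repo): tensors whose shapes match are copied from the checkpoint, while the mismatched head is skipped and keeps its fresh initialization, so it is learned during fine-tuning.

import logging
import torch.nn as nn

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("finetune")

# Hypothetical example: a pretrained 21-class head (e.g. VOC) vs. a new 11-class head.
pretrained = nn.Sequential(nn.Conv2d(3, 16, 3), nn.Conv2d(16, 21, 1))
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.Conv2d(16, 11, 1))

ckpt = pretrained.state_dict()          # stands in for torch.load(...)["model"]
model = load_ckpt(model, ckpt, logger)  # "0.*" tensors copied; "1.*" skipped with a shape warning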

Then in SSD/train.py, lines 26 to 29 at commit 68dc0a2 currently read:

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    lr = cfg.SOLVER.LR * args.num_gpus  # scale by num gpus

change them to:

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    if args.finetuning_file is not None:
        ftc = torch.load(args.finetuning_file)["model"]
        model = load_ckpt(model, ftc, logger)
        logger.info(f"Loaded for finetuning {args.finetuning_file}")

    lr = cfg.SOLVER.LR * args.num_gpus  # scale by num gpus
    optimizer = make_optimizer(cfg, model, lr)
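
If you are not sure what the pretrained file contains, a quick check like the sketch below (assuming the weights sit under the "model" key, as the snippet above does with torch.load(args.finetuning_file)["model"]) shows the available keys and tensor shapes before fine-tuning:

import torch

ckpt = torch.load("efficient_net_b3_ssd300_voc0712.pth", map_location="cpu")
print(list(ckpt.keys()))                  # expect at least a "model" entry
state = ckpt["model"]
for name, tensor in list(state.items())[:5]:
    print(name, tuple(tensor.shape))      # first few parameter names and shapes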

And in SSD/train.py, lines 50 to 57 at commit 68dc0a2 currently read:

    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)

add the new --finetuning-file argument between them:

    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "--finetuning-file",
        default=None,
        metavar="FILE",
        help="path to model for finetuning",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)

Then run:

python train.py --config-file configs\efficient_net_b3_ssd300_graff.yaml --finetuning-file efficient_net_b3_ssd300_voc0712.pth
