[CodeCamp2023-325] Find the proper learning rate #1318

Open · wants to merge 42 commits into base: main
Commits (42)
67e9dc2 Init tuner for finding best lr (yhna940, Aug 10, 2023)
b714913 Apply lint (yhna940, Aug 11, 2023)
a923847 Add ex for tuning (yhna940, Aug 11, 2023)
4c9ef09 Refactor to rpc (yhna940, Aug 17, 2023)
3580dd8 Apply lint (yhna940, Aug 17, 2023)
55364e0 Add logger to tune (yhna940, Aug 18, 2023)
882271a Fix searcher init args (yhna940, Aug 18, 2023)
6285928 Apply lint (yhna940, Aug 18, 2023)
0431eb0 Fix typo (yhna940, Aug 18, 2023)
b5985fb Fix minor (yhna940, Aug 21, 2023)
4b5a249 Fix rpc init (yhna940, Aug 21, 2023)
6846aba Fix env for rpc (yhna940, Aug 22, 2023)
a320ee1 fix rpc device map (yhna940, Aug 23, 2023)
ccb8f07 Del rpc (yhna940, Aug 23, 2023)
bae8605 Fix examples (yhna940, Aug 23, 2023)
18fd768 Fix minor (yhna940, Aug 23, 2023)
71b4b2a Fix typo (yhna940, Aug 23, 2023)
fecfacb Split seachers (yhna940, Aug 28, 2023)
010a3f1 Comment the tuner (yhna940, Aug 28, 2023)
69e62a7 Rename solver of nevergrad (yhna940, Aug 28, 2023)
23d1f97 Comment the report hook (yhna940, Aug 28, 2023)
482a9e5 Comment the searchers (yhna940, Aug 28, 2023)
04b46a3 Add readme for tune (yhna940, Aug 28, 2023)
3418ddc Add error logging (yhna940, Aug 29, 2023)
92ad439 Add unittest for tune (yhna940, Aug 30, 2023)
308ece3 Apply lint (yhna940, Aug 30, 2023)
0767f52 Add random searcher (yhna940, Aug 30, 2023)
70d91e4 Fix unittest bug (yhna940, Aug 30, 2023)
1e12211 Fix tuner unittest (yhna940, Aug 31, 2023)
c4a7e04 Add tuning interface for runner (yhna940, Aug 31, 2023)
4d71002 Fix minor (yhna940, Aug 31, 2023)
cfc3f6a Refactor report op (yhna940, Sep 1, 2023)
3488ae1 Fix report bug (yhna940, Sep 1, 2023)
c0d8e45 Update mmengine/tune/_report_hook.py (yhna940, Sep 9, 2023)
62d6777 Merge branch 'open-mmlab:main' into feature/hyper-naive (yhna940, Sep 9, 2023)
27bb08b Fix comment for report hook (yhna940, Sep 9, 2023)
afb5af2 Specify phase in monitor (yhna940, Sep 9, 2023)
0cdf020 Fix comment on tuner for monitor (yhna940, Sep 9, 2023)
48e2abc Apply reduce operation to score in trial (yhna940, Sep 9, 2023)
eb6b387 Fix comment on tuner (yhna940, Sep 9, 2023)
16d5186 Enhance safe trial during tune (yhna940, Sep 9, 2023)
8f5ee32 Fix unittest bug (yhna940, Sep 9, 2023)
23 changes: 23 additions & 0 deletions examples/tune/README.md
@@ -0,0 +1,23 @@
# Find the Optimal Learning Rate

## Install external dependencies

First, install `nevergrad`, which provides the search algorithms used by the `NevergradSearcher` in this example.

```bash
pip install nevergrad
```
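
Optionally, confirm the dependency is importable before launching a run; a quick sanity check (not part of the example itself):

```python
# Quick sanity check that nevergrad is importable in the current environment.
import nevergrad
print(getattr(nevergrad, '__version__', 'installed'))
```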

## Run the example

Single-device tuning

```bash
python examples/tune/find_lr.py
```

Distributed data-parallel tuning (8 GPUs on a single node)

```bash
torchrun --nnodes 1 --nproc_per_node 8 examples/tune/find_lr.py --launcher pytorch
```
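
`nevergrad` is only needed by the `NevergradSearcher` used in this example; if it is unavailable, the searcher can be swapped for the built-in random searcher (the default of `Runner.from_tuning`). A minimal, hypothetical tweak to the `searcher_cfg` passed in `examples/tune/find_lr.py`:

```python
# Hypothetical fallback (not part of the example): use the random searcher
# when nevergrad is not installed.
try:
    import nevergrad  # noqa: F401
    searcher_cfg = dict(type='NevergradSearcher')
except ImportError:
    searcher_cfg = dict(type='RandomSearcher')
```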
147 changes: 147 additions & 0 deletions examples/tune/find_lr.py
@@ -0,0 +1,147 @@
import argparse
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.registry import DATASETS, METRICS, MODELS
from mmengine.runner import Runner


class ToyModel(BaseModel):

def __init__(self, data_preprocessor=None):
super().__init__(data_preprocessor=data_preprocessor)
self.linear1 = nn.Linear(2, 32)
self.linear2 = nn.Linear(32, 64)
self.linear3 = nn.Linear(64, 1)

def forward(self, inputs, data_samples=None, mode='tensor'):
if isinstance(inputs, list):
inputs = torch.stack(inputs)
if isinstance(data_samples, list):
data_samples = torch.stack(data_samples)
outputs = self.linear1(inputs)
outputs = self.linear2(outputs)
outputs = self.linear3(outputs)

if mode == 'tensor':
return outputs
elif mode == 'loss':
loss = ((data_samples - outputs)**2).mean()
outputs = dict(loss=loss)
return outputs
elif mode == 'predict':
return outputs


class ToyDataset(Dataset):
METAINFO = dict() # type: ignore
num_samples = 100
data = torch.rand(num_samples, 2) * 10
label = 3 * data[:, 0] + 4 * data[:, 1] + torch.randn(num_samples) * 0.1

@property
def metainfo(self):
return self.METAINFO

def __len__(self):
return len(self.data)

def __getitem__(self, index):
return dict(inputs=self.data[index], data_samples=self.label[index])


class ToyMetric(BaseMetric):

def __init__(self, collect_device='cpu'):
super().__init__(collect_device=collect_device)
self.results = []

def process(self, data_batch, predictions):
true_values = data_batch['data_samples']
sqe = [(t - p)**2 for t, p in zip(true_values, predictions)]
self.results.extend(sqe)

    def compute_metrics(self, results=None):
        # Prefer the results gathered by the evaluator (matters when running
        # distributed); fall back to the local buffer when called directly.
        results = self.results if results is None else results
        mse = torch.tensor(results).mean().item()
        return dict(mse=mse)


def parse_args():
parser = argparse.ArgumentParser(description='Distributed Tuning')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)

args = parser.parse_args()
return args


def main():
args = parse_args()

MODELS.register_module(module=ToyModel, force=True)
METRICS.register_module(module=ToyMetric, force=True)
DATASETS.register_module(module=ToyDataset, force=True)

temp_dir = tempfile.TemporaryDirectory()

runner_cfg = dict(
work_dir=temp_dir.name,
model=dict(type='ToyModel'),
train_dataloader=dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
batch_size=3,
num_workers=0),
val_dataloader=dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=3,
num_workers=0),
val_evaluator=[dict(type='ToyMetric')],
test_dataloader=dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=3,
num_workers=0),
test_evaluator=[dict(type='ToyMetric')],
optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.1)),
train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=1),
val_cfg=dict(),
test_cfg=dict(),
launcher=args.launcher,
default_hooks=dict(logger=dict(type='LoggerHook', interval=1)),
custom_hooks=[],
env_cfg=dict(dist_cfg=dict(backend='nccl')),
experiment_name='test1')

runner = Runner.from_tuning(
runner_cfg=runner_cfg,
hparam_spec={
'optim_wrapper.optimizer.lr': {
'type': 'continuous',
'lower': 1e-5,
'upper': 1e-3
}
},
monitor='train/loss',
rule='less',
num_trials=16,
tuning_epoch=2,
searcher_cfg=dict(type='NevergradSearcher'),
)
runner.train()

temp_dir.cleanup()


if __name__ == '__main__':
main()
55 changes: 55 additions & 0 deletions mmengine/runner/runner.py
@@ -37,6 +37,7 @@
HOOKS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS,
MODELS, OPTIM_WRAPPERS, PARAM_SCHEDULERS,
RUNNERS, VISUALIZERS, DefaultScope)
from mmengine.tune import Tuner
from mmengine.utils import apply_to, digit_version, get_git_hash, is_seq_of
from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
set_multi_processing)
@@ -475,6 +476,60 @@ def from_cfg(cls, cfg: ConfigType) -> 'Runner':

return runner

@classmethod
def from_tuning(
cls,
runner_cfg: ConfigType,
hparam_spec: Dict,
monitor: str,
rule: str,
num_trials: int,
tuning_iter: Optional[int] = None,
tuning_epoch: Optional[int] = None,
report_op: str = 'latest',
searcher_cfg: Dict = dict(type='RandomSearcher')
) -> 'Runner':
"""Build a runner from tuning.

Args:
            runner_cfg (ConfigType): A config used for building the runner.
                See :meth:`__init__` for the supported keys of ``runner_cfg``.
            hparam_spec (Dict): A dict describing the hyperparameters to be
                tuned and their search spaces.
            monitor (str): The name of the metric to be monitored.
            rule (str): The comparison rule used to decide the best value of
                the monitored metric, e.g. ``'less'`` when a lower value is
                better.
num_trials (int): The maximum number of trials for tuning.
tuning_iter (Optional[int]): The maximum iterations for each trial.
If specified, tuning stops after reaching this limit.
Default is None, indicating no specific iteration limit.
tuning_epoch (Optional[int]): The maximum epochs for each trial.
If specified, tuning stops after reaching this number
of epochs. Default is None, indicating no epoch limit.
report_op (str):
Operation mode for metric reporting. Default is 'latest'.
searcher_cfg (Dict): Configuration for the searcher.
Default is `dict(type='RandomSearcher')`.

Returns:
            Runner: A runner built from the ``runner_cfg`` into which the
                tuned hyperparameters have been injected.
"""

runner_cfg = copy.deepcopy(runner_cfg)
tuner = Tuner(
runner_cfg=runner_cfg,
hparam_spec=hparam_spec,
monitor=monitor,
rule=rule,
num_trials=num_trials,
tuning_iter=tuning_iter,
tuning_epoch=tuning_epoch,
report_op=report_op,
searcher_cfg=searcher_cfg)
hparam = tuner.tune()['hparam']
assert isinstance(hparam, dict), 'hparam should be a dict'
for k, v in hparam.items():
Tuner.inject_config(runner_cfg, k, v)
return cls.from_cfg(runner_cfg)

@property
def experiment_name(self):
"""str: Name of experiment."""
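For reference, the same flow that `from_tuning` wraps can be driven by hand: build a `Tuner`, run the trials, and inject the winning hyperparameters back into the config before constructing the final runner. A minimal sketch, assuming a `runner_cfg` dict like the one built in `examples/tune/find_lr.py`:

```python
# Sketch (not from the PR) of driving the tuner directly instead of going
# through Runner.from_tuning; the arguments mirror the call shown above.
import copy

from mmengine.runner import Runner
from mmengine.tune import Tuner

cfg = copy.deepcopy(runner_cfg)   # runner_cfg: a valid Runner config dict
tuner = Tuner(
    runner_cfg=cfg,
    hparam_spec={
        'optim_wrapper.optimizer.lr':
        dict(type='continuous', lower=1e-5, upper=1e-3)
    },
    monitor='train/loss',
    rule='less',                   # lower loss is better
    num_trials=16,
    tuning_iter=None,              # no per-trial iteration cap
    tuning_epoch=2,                # cap each trial at two epochs
    report_op='latest',
    searcher_cfg=dict(type='NevergradSearcher'))
result = tuner.tune()              # dict containing the best 'hparam'
for key, value in result['hparam'].items():
    Tuner.inject_config(cfg, key, value)  # write tuned values into the cfg
runner = Runner.from_cfg(cfg)
runner.train()
```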
5 changes: 5 additions & 0 deletions mmengine/tune/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .searchers import * # noqa F403
from .tuner import Tuner

__all__ = ['Tuner']