[CodeCamp2023-325] Find the proper learning rate #1318
base: main
Changes from 33 commits
@@ -0,0 +1,23 @@
# Find the Optimal Learning Rate

## Install external dependencies

First, you should install `nevergrad` for tuning.

```bash
pip install nevergrad
```

## Run the example

Single-device training:

```bash
python examples/tune/find_lr.py
```

Distributed data parallel tuning:

```bash
torchrun --nnodes 1 --nproc_per_node 8 examples/tune/find_lr.py --launcher pytorch
```
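The full example script is shown in the next file. Its core is a single `Runner.from_tuning` call, sketched below for orientation; the `runner_cfg` placeholder stands for an ordinary MMEngine runner config dict like the one built in the script, and all other values mirror the example.

```python
from mmengine.runner import Runner

# Sketch: tune the learning rate over the continuous range [1e-5, 1e-3] with
# nevergrad (16 trials, 2 epochs each), minimizing the monitored 'loss', then
# train with the resulting runner.
runner = Runner.from_tuning(
    runner_cfg=runner_cfg,  # ordinary Runner config dict (see the script below)
    hparam_spec={
        'optim_wrapper.optimizer.lr': {
            'type': 'continuous',
            'lower': 1e-5,
            'upper': 1e-3,
        }
    },
    monitor='loss',
    rule='less',
    num_trials=16,
    tuning_epoch=2,
    searcher_cfg=dict(type='NevergradSearcher'),
)
runner.train()
```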
@@ -0,0 +1,147 @@
import argparse
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.registry import DATASETS, METRICS, MODELS
from mmengine.runner import Runner


class ToyModel(BaseModel):
    """A three-layer linear model used as the tuning target."""

    def __init__(self, data_preprocessor=None):
        super().__init__(data_preprocessor=data_preprocessor)
        self.linear1 = nn.Linear(2, 32)
        self.linear2 = nn.Linear(32, 64)
        self.linear3 = nn.Linear(64, 1)

    def forward(self, inputs, data_samples=None, mode='tensor'):
        if isinstance(inputs, list):
            inputs = torch.stack(inputs)
        if isinstance(data_samples, list):
            data_samples = torch.stack(data_samples)
        outputs = self.linear1(inputs)
        outputs = self.linear2(outputs)
        outputs = self.linear3(outputs)

        if mode == 'tensor':
            return outputs
        elif mode == 'loss':
            loss = ((data_samples - outputs)**2).mean()
            outputs = dict(loss=loss)
            return outputs
        elif mode == 'predict':
            return outputs


class ToyDataset(Dataset):
    """Synthetic linear-regression data: y = 3 * x1 + 4 * x2 + noise."""
    METAINFO = dict()  # type: ignore
    num_samples = 100
    data = torch.rand(num_samples, 2) * 10
    label = 3 * data[:, 0] + 4 * data[:, 1] + torch.randn(num_samples) * 0.1

    @property
    def metainfo(self):
        return self.METAINFO

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return dict(inputs=self.data[index], data_samples=self.label[index])


class ToyMetric(BaseMetric):
    """Mean squared error between labels and predictions."""

    def __init__(self, collect_device='cpu'):
        super().__init__(collect_device=collect_device)
        self.results = []

    def process(self, data_batch, predictions):
        true_values = data_batch['data_samples']
        sqe = [(t - p)**2 for t, p in zip(true_values, predictions)]
        self.results.extend(sqe)

    def compute_metrics(self, results=None):
        mse = torch.tensor(self.results).mean().item()
        return dict(mse=mse)


def parse_args():
    parser = argparse.ArgumentParser(description='Distributed Tuning')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)

    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    # Register the toy components so they can be built from the config below.
    MODELS.register_module(module=ToyModel, force=True)
    METRICS.register_module(module=ToyMetric, force=True)
    DATASETS.register_module(module=ToyDataset, force=True)

    temp_dir = tempfile.TemporaryDirectory()

    runner_cfg = dict(
        work_dir=temp_dir.name,
        model=dict(type='ToyModel'),
        train_dataloader=dict(
            dataset=dict(type='ToyDataset'),
            sampler=dict(type='DefaultSampler', shuffle=True),
            batch_size=3,
            num_workers=0),
        val_dataloader=dict(
            dataset=dict(type='ToyDataset'),
            sampler=dict(type='DefaultSampler', shuffle=False),
            batch_size=3,
            num_workers=0),
        val_evaluator=[dict(type='ToyMetric')],
        test_dataloader=dict(
            dataset=dict(type='ToyDataset'),
            sampler=dict(type='DefaultSampler', shuffle=False),
            batch_size=3,
            num_workers=0),
        test_evaluator=[dict(type='ToyMetric')],
        optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.1)),
        train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=1),
        val_cfg=dict(),
        test_cfg=dict(),
        launcher=args.launcher,
        default_hooks=dict(logger=dict(type='LoggerHook', interval=1)),
        custom_hooks=[],
        env_cfg=dict(dist_cfg=dict(backend='nccl')),
        experiment_name='test1')

    # Tune the learning rate over [1e-5, 1e-3] with nevergrad: 16 trials of
    # 2 epochs each, keeping the configuration with the lowest monitored loss.
    runner = Runner.from_tuning(
        runner_cfg=runner_cfg,
        hparam_spec={
            'optim_wrapper.optimizer.lr': {
                'type': 'continuous',
                'lower': 1e-5,
                'upper': 1e-3
            }
        },
        monitor='loss',
        rule='less',
        num_trials=16,
        tuning_epoch=2,
        searcher_cfg=dict(type='NevergradSearcher'),
    )
    runner.train()

    temp_dir.cleanup()


if __name__ == '__main__':
    main()
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .searchers import *  # noqa F403
from .tuner import Tuner

__all__ = ['Tuner']
@@ -0,0 +1,151 @@
# Copyright (c) OpenMMLab. All rights reserved.

from typing import Callable, Dict, List, Optional, Sequence, Union

from mmengine.hooks import Hook

DATA_BATCH = Optional[Union[dict, tuple, list]]


class ReportingHook(Hook):
    """Auxiliary hook to report the score to the tuner.

    The monitored scores are collected into a bounded list, the scoreboard,
    and an aggregate of them (e.g. the latest value or the mean) is reported
    to the tuner. If a tuning limit is specified, this hook will mark the
    loop to stop.

    Args:
        monitor (str): The monitored metric key to report.
        tuning_iter (int, optional): The iteration limit to stop tuning.
            Defaults to None.
        tuning_epoch (int, optional): The epoch limit to stop tuning.
            Defaults to None.
        report_op (str, optional): The operation to report the score.
            Options are 'latest', 'mean', 'min', 'max'. Defaults to 'latest'.
        max_scoreboard_len (int, optional): The maximum length of the
            scoreboard. Defaults to 1024.
    """

    report_op_supported: Dict[str, Callable[[List[float]], float]] = {
        'latest': lambda x: x[-1],
        'mean': lambda x: sum(x) / len(x),
        'max': max,
        'min': min
    }

    def __init__(self,
                 monitor: str,
                 tuning_iter: Optional[int] = None,
                 tuning_epoch: Optional[int] = None,
                 report_op: str = 'latest',
                 max_scoreboard_len: int = 1024):
        assert report_op in self.report_op_supported, \
            f'report_op {report_op} is not supported'

        if tuning_iter is not None and tuning_epoch is not None:
            raise ValueError(
                'tuning_iter and tuning_epoch cannot be set at the same time')
        self.report_op = report_op
        self.tuning_iter = tuning_iter
        self.tuning_epoch = tuning_epoch

        self.monitor = monitor
        self.max_scoreboard_len = max_scoreboard_len
        self.scoreboard: List[float] = []

    def _append_score(self, score: float):
        """Append the score to the scoreboard."""
        self.scoreboard.append(score)
        if len(self.scoreboard) > self.max_scoreboard_len:
            self.scoreboard.pop(0)

    def _should_stop(self, runner):
        """Check if the training should be stopped.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.tuning_iter is not None:
            if runner.iter + 1 >= self.tuning_iter:
                return True
        elif self.tuning_epoch is not None:
            if runner.epoch + 1 >= self.tuning_epoch:
                return True
        else:
            return False

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[Union[dict, Sequence]] = None,
                         mode: str = 'train') -> None:
        """Record the score after each iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
            outputs (dict, optional): Outputs from model.
        """
        tag, _ = runner.log_processor.get_log_after_iter(
            runner, batch_idx, 'train')
        score = tag.get(self.monitor, None)
(Review thread on the line above)
Reviewer: Suggest adding a prefix to the monitored key. We also need to check that the monitored value is a number.
Author: Following this advice, I have enhanced the monitoring process by specifying prefixes, and added logic to verify that the monitored values are numerical to prevent potential errors.
        if score is not None:
            self._append_score(score)

        if self._should_stop(runner):
            runner.train_loop.stop_training = True

    def after_train_epoch(self, runner) -> None:
        """Mark the loop to stop if the tuning epoch limit has been reached.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self._should_stop(runner):
            runner.train_loop.stop_training = True

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """Record the score after each validation epoch.

        Args:
            runner (Runner): The runner of the validation process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on the validation dataset. The keys are the names of
                the metrics, and the values are corresponding results.
        """
        if metrics is None:
            return
        score = metrics.get(self.monitor, None)
        if score is not None:
            self._append_score(score)

    def report_score(self) -> Optional[float]:
        """Aggregate the scores in the scoreboard.

        Returns:
            Optional[float]: The aggregated score, or None if the scoreboard
                is empty.
        """
        if not self.scoreboard:
            score = None
        else:
            operation = self.report_op_supported[self.report_op]
            score = operation(self.scoreboard)
        return score

    @classmethod
    def register_report_op(cls, name: str,
                           func: Callable[[List[float]], float]):
        """Register a new report operation.

        Args:
            name (str): The name of the report operation.
            func (Callable[[List[float]], float]): The function to aggregate
                the scores.
        """
        cls.report_op_supported[name] = func

    def clear(self):
        """Clear the scoreboard."""
        self.scoreboard.clear()
(Review thread on this file)
Reviewer: `scoreboard` is a new concept for users. We need to introduce it here.
Author: To clarify the newly introduced concept of the scoreboard, I have incorporated additional comments in the relevant section to guide users regarding its purpose and usage.
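To make the scoreboard and report-op mechanics concrete, here is a small, purely illustrative sketch that drives `ReportingHook` by hand; in a real run the runner's hook calls feed it. The import path is an assumption, since this excerpt does not show where the module lives, and only methods defined in the class above are used.

```python
# Hypothetical usage sketch for ReportingHook; the import path is assumed.
from mmengine.tune._report_hook import ReportingHook

# Register a custom aggregation: report the mean of the last five scores.
ReportingHook.register_report_op(
    'mean_last_5', lambda scores: sum(scores[-5:]) / len(scores[-5:]))

hook = ReportingHook(monitor='loss', report_op='mean_last_5', tuning_epoch=2)

# Normally after_train_iter/after_val_epoch fill the scoreboard from the
# runner's logs and metrics; here scores are appended by hand for illustration.
for loss in [0.9, 0.7, 0.6, 0.55, 0.5, 0.48]:
    hook._append_score(loss)

print(hook.report_score())  # mean of the last five appended losses
hook.clear()                # empty the scoreboard between trials
```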