Skip to content

Commit

Permalink
readme: update training step docs
Browse files Browse the repository at this point in the history
  • Loading branch information
raehik committed Nov 9, 2023
1 parent 1078c72 commit e2a9cb2
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 27 deletions.
76 changes: 49 additions & 27 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,29 +159,37 @@ functions at [`step/data/lib.py`](src/gz21_ocean_momentum/step/data/lib.py). See
the CLI script for example usage.

#### Training
The [`trainScript.py`](src/gz21_ocean_momentum/trainScript.py) script runs the
model training step. You may configure various training parameters through
command-line arguments, such as number of training epochs, loss functions, and
training data. (You will want to select the output from a data processing step
for the latter.)
[cli-train]: src/gz21_ocean_momentum/trainScript.py

The [`trainScript.py`][cli-train] script runs the model training step. You may
configure various training parameters through command-line arguments, such as
number of training epochs, loss functions, and training data. (You will want to
select the output from a data processing step for the latter.)

MLflow call example:

```
mlflow run . --experiment-name <name> -e train --env-manager=local \
-P run_id=<run id> \
-P learning_rate=0/5e-4/15/5e-5/30/5e-6 -P n_epochs=200 -P weight_decay=0.00 -P train_split=0.8 \
-P test_split=0.85 -P model_module_name=models.models1 -P model_cls_name=FullyCNN -P batchsize=4 \
-P subdomains_file=examples/cli-configs/training-subdomains-paper.yaml \
-P learning_rate=0/5e-4/15/5e-5/30/5e-6 -P weight_decay=0.00 \
-P n_epochs=200 -P batchsize=4 \
-P train_split=0.8 -P test_split=0.85 \
-P model_module_name=models.models1 -P model_cls_name=FullyCNN \
-P transformation_cls_name=SoftPlusTransform -P submodel=transform3 \
-P loss_cls_name=HeteroskedasticGaussianLossV2
```

Relevant parameters:

* `exp_id`: id of the experiment containing the run that generated the forcing
data.
* `run_id`: id of the run that generated the forcing data that will be used for
training.
* `run_id`: MLflow run ID of the run that generated the forcing data that will
be used for training.
* `subdomains_file`: path to YAML file storing a list of subdomains to select
from the forcing data, which are then used for training. (Note that at
runtime, domains are be truncated to the size of the smallest domain in terms
of number of points.)
* `train_split`: use `0->N` percent of the dataset for training
* `test_split`: use `N->100` percent of the dataset for testing
* `loss_cls_name`: name of the class that defines the loss. This class should be
defined in train/losses.py in order for the script to find it. Currently, the
main available options are:
Expand All @@ -192,18 +200,32 @@ Relevant parameters:
NN used
* `model_cls_name`: name of the class defining the NN used, should be defined in
the module specified by `model_module_name`
* `train_split`: use `0->N` percent of the dataset for training
* `test_split`: use `N->100` percent of the dataset for testing

Another important way to modify the way the script runs consists in modifying
the domains used for training. These are defined in
[`training_subdomains.yaml`](training_subdomains.yaml) in terms of their
coordinates. Note that at run time domains will be truncated to the size of the
smallest domain in terms of number of points.
You may also call this script directly instead of going through `mlflow run`. In
such cases, you may replace `--run-id` with `--forcing-data-path`. See
[`trainScript`][cli-train] and [`MLproject`](MLproject) for more details.

##### Subdomains
The `subdomains_file` format is a list of bounding boxes, each defined using
four floats:

```yaml
- lat-min: 35
lat-max: 50
long-min: -50
long-max: -20
- lat-min: -40
lat-max: -25
long-min: -180
long-max: -162
# - ...
```

`lat-min` must be smaller than `lat-max`, likewise for `long-min`.

*Note:* Ensure that the spatial subdomains defined in `training_subdomains.yaml`
are contained in the domain of the forcing data you use. If they aren't, you may
get a Python error along the lines of:
*Note:* Ensure that the subdomains you use are contained in the domain of the
forcing data you use. If they aren't, you may get a confusing Python error along
the lines of:

```
RuntimeError: Calculated padded input size per channel: <smaller than 5 x 5>.
Expand Down Expand Up @@ -243,7 +265,7 @@ The inference step should then start.
### Jupyter Notebooks
The [examples/jupyter-notebooks](examples/jupyter-notebooks/) folder stores
notebooks developed during early project development, some of which were used to
generate figures used in the 2021 paper. See the readme in the folder for
generate figures used in the 2021 paper. See the readme in the above folder for
details.

### Dev Branch
Expand All @@ -253,11 +275,11 @@ use through a command line interface for the data step, and the training step
is in progress. Further work is needed for the inference step, and to adapt the Jupyter
notebooks.

## Data on Huggingface
There is GZ21 Ocean Momentum data available on [Huggingface](https://huggingface.co/):
- [the output of the data step](https://huggingface.co/datasets/M2LInES/gfdl-cmip26-gz21-ocean-forcing)
and
- [the trained model](https://huggingface.co/M2LInES/gz21-ocean-momentum)
## Data on HuggingFace
There is GZ21 Ocean Momentum data available on [HuggingFace](https://huggingface.co/)

* [the output of the data step][datasets/M2LInES/gz21-forcing-cm26] and
* [the trained model](https://huggingface.co/M2LInES/gz21-ocean-momentum).

## Contributing
We are not currently accepting contributions outside of the M2LInES and ICCS
Expand Down
4 changes: 4 additions & 0 deletions examples/cli-configs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,7 @@
## General tips
* If the data step (forcing generation) is taking too long, lower `ntimes`. On a
consumer machine, for testing, 100 is good enough. (4000 will take ages.)

## Details
### `training-subdomains-paper.yaml`
Four spatial bounding boxes, used for training in the original paper.
File renamed without changes.

0 comments on commit e2a9cb2

Please sign in to comment.