Structured Config Style #8

Open · wants to merge 1 commit into base: datasets
20 changes: 9 additions & 11 deletions conf/dataset/default.yaml
@@ -1,17 +1,15 @@
# @package dataset
# cfg:
# torch data-loader specific arguments
cfg:
cfg:
batch_size: ${training.batch_size}
num_workers: ${training.num_workers}
dataroot: data

common_transform:
aug_transform:
pre_transform:
# common_transform:
# aug_transform:
# pre_transform:

val_transform: "${dataset.cfg.common_transform}"
test_transform: "${dataset.cfg.val_transform}"
train_transform:
- "${dataset.cfg.aug_transform}"
- "${dataset.cfg.common_transform}"
# val_transform: "${dataset.cfg.common_transform}"
# test_transform: "${dataset.cfg.val_transform}"
# train_transform:
# - "${dataset.cfg.aug_transform}"
# - "${dataset.cfg.common_transform}"
3 changes: 2 additions & 1 deletion conf/dataset/segmentation/s3dis/s3dis1x1.yaml
@@ -1,6 +1,7 @@
# @package dataset
defaults:
- dataset_s3dis
Member Author:
For ConfigStore we have to add a line to each of the defaults lists so that Hydra can associate the config file with its dataclass.
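Concretely, that extra entry only resolves because a node with the same name is registered in the ConfigStore under the dataset group (see test.py further down in this diff). A minimal, self-contained sketch of the association (the dataclass body here is a stand-in for the one defined in test.py):

from dataclasses import dataclass

from hydra.core.config_store import ConfigStore


@dataclass
class S3DISDataset:  # stand-in for the dataclass defined in test.py below
    fold: int = 6


cs = ConfigStore.instance()
# Registering the node under group="dataset", name="dataset_s3dis" is what lets the
# `- dataset_s3dis` entry in this file's defaults list attach the dataclass schema
# (and its runtime type checks) to the yaml config.
cs.store(group="dataset", name="dataset_s3dis", node=S3DISDataset)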

- segmentation/default
_target_: torch_points3d.dataset.s3dis1x1.s3dis_data_module
cfg:
fold: 5
fold : 5
6 changes: 6 additions & 0 deletions conf/test_config.yaml
@@ -0,0 +1,6 @@
defaults: # loads default configs
- base_config
- dataset: segmentation/s3dis/s3dis1x1
- training: default

pretty_print: True
2 changes: 2 additions & 0 deletions conf/training/default.yaml
@@ -1,3 +1,5 @@
defaults:
- base_trainer
lr: 5e-5

# read in dataset
74 changes: 74 additions & 0 deletions test.py
@@ -0,0 +1,74 @@
import hydra
Collaborator:
Can you move this file to the test directory, and call it test_config_store or something like that?

Member Author:
Sure, if you wanted to keep it. I just wrote it to demonstrate the use case. I think I'll just delete it and we can write a proper test file later.

from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf, DictConfig
from torch_points3d.trainer import LitTrainer
from torch_points3d.core.instantiator import HydraInstantiator, Instantiator
from dataclasses import dataclass
from hydra.core.config_store import ConfigStore
from typing import List, Any, Type
from omegaconf import MISSING, OmegaConf
from omegaconf._utils import is_structured_config

OmegaConf.register_new_resolver("get_filename", lambda x: x.split("/")[-1])


@dataclass
class TrainingDataConfig:
batch_size: int = 32
num_workers: int = 0
lr: float = MISSING

# We separate the dataset "cfg" from the actual dataset object
# so that we can pass the "cfg" into the dataset constructors as a DictConfig
# instead of as unwrapped parameters
Member Author:
Note this

@dataclass
class BaseDataConfig:
batch_size: int = 32
num_workers: int = 0
dataroot: str = "data"

@dataclass
class BaseDataset:
_target_: str
cfg: BaseDataConfig

@dataclass
class S3DISDataConfig(BaseDataConfig):
fold: int = 6

@dataclass
class S3DISDataset(BaseDataset):
cfg: S3DISDataConfig

@dataclass
class Config:
dataset: Any
Member Author:
I couldn't get the typing to work on this for some reason. You should be able to do Type[BaseDataset] but it gave me an error...
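As far as I know, OmegaConf structured configs only accept a limited set of field types (primitives, Enums, nested dataclasses, Any, Optional, and typed containers of those), so a Type[BaseDataset] annotation gets rejected when the node is stored. Annotating the field with the base dataclass itself is the usual typed alternative to Any; an untested sketch:

@dataclass
class Config:
    # BaseDataset (rather than Type[BaseDataset] or Any) keeps the schema checks;
    # the defaults list can still merge in the S3DISDataset subtype node.
    dataset: BaseDataset = MISSING
    training: TrainingDataConfig = MISSING
    pretty_print: bool = False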

training: TrainingDataConfig
pretty_print: bool = False

def show(x):
print(f"type: {type(x).__name__}, value: {repr(x)}")

cs = ConfigStore.instance()
cs.store(name="base_config", node=Config)
cs.store(group="dataset", name="dataset_s3dis", node=S3DISDataset)
Member Author:
This is where we define the dataset_s3dis node that we referenced in the config's defaults list.

cs.store(group="training", name="base_trainer", node=TrainingDataConfig)

@hydra.main(config_path="conf", config_name="test_config")
def main(cfg: DictConfig):
OmegaConf.set_struct(cfg, False) # This allows getattr and hasattr methods to function correctly
if cfg.get("pretty_print"):
print(OmegaConf.to_yaml(cfg, resolve=True))

dset = cfg.get("dataset")
show(dset)
show(dset.cfg)
dset_cfg = dset.cfg
# for some reason the cfg object will lose its typing information if hydra passes it to the target class
# so we pass it manually ourselves and keep the typing info
delattr(dset, "cfg")
hydra.utils.instantiate(dset, dset_cfg)


if __name__ == "__main__":
main()
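To make the cfg-as-DictConfig comment near the top of this file concrete, here is a hedged, standalone sketch (UnwrappedDataset and WrappedDataset are hypothetical classes, not part of this PR) contrasting the two calling styles hydra.utils.instantiate ends up using. With the nested cfg node, the constructor receives a single DictConfig it can store and forward, rather than a flat set of keyword arguments.

import hydra.utils
from omegaconf import DictConfig, OmegaConf


class UnwrappedDataset:
    # receives each config field as a separate keyword argument
    def __init__(self, batch_size: int = 32, num_workers: int = 0) -> None:
        self.batch_size, self.num_workers = batch_size, num_workers


class WrappedDataset:
    # receives the whole node as one DictConfig and keeps it around
    def __init__(self, cfg: DictConfig) -> None:
        self.cfg = cfg


flat = OmegaConf.create(
    {"_target_": "__main__.UnwrappedDataset", "batch_size": 8, "num_workers": 2}
)
nested = OmegaConf.create(
    {"_target_": "__main__.WrappedDataset", "cfg": {"batch_size": 8, "num_workers": 2}}
)

hydra.utils.instantiate(flat)    # -> UnwrappedDataset(batch_size=8, num_workers=2)
hydra.utils.instantiate(nested)  # -> WrappedDataset(cfg=DictConfig({...}))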
43 changes: 24 additions & 19 deletions torch_points3d/dataset/s3dis1x1.py
@@ -1,5 +1,5 @@
from typing import Any, Callable, Dict, Optional, Sequence
from omegaconf import MISSING
from omegaconf import MISSING, DictConfig
from dataclasses import dataclass

import hydra.utils
@@ -16,24 +16,29 @@ class S3DISDataConfig(PointCloudDataConfig):
num_workers: int = 0
fold: int = 6

def show(x):
print(f"type: {type(x).__name__}, value: {repr(x)}")

class s3dis_data_module(PointCloudDataModule):
def __init__(self, cfg: S3DISDataConfig = S3DISDataConfig()) -> None:
def __init__(self, cfg: DictConfig) -> None:
super().__init__(cfg)

self.ds = {
"train": S3DIS1x1(
self.cfg.dataroot,
test_area=self.cfg.fold,
train=True,
pre_transform=self.cfg.pre_transform,
transform=self.cfg.train_transform,
),
"test": S3DIS1x1(
self.cfg.dataroot,
test_area=self.cfg.fold,
train=False,
pre_transform=self.cfg.pre_transform,
transform=self.cfg.train_transform,
),
}
show(cfg)
cfg.num_workers = "aj"
Member Author (@CCInc, Jul 21, 2021):
This will throw an error, because the cfg is still duck-typed to S3DISDataConfig and Hydra is enforcing the runtime checks on it, which will be very helpful.
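For reference, a minimal standalone sketch of the runtime check this relies on (only omegaconf is assumed; the exact error message varies across versions):

from dataclasses import dataclass

from omegaconf import OmegaConf
from omegaconf.errors import ValidationError


@dataclass
class S3DISDataConfig:
    batch_size: int = 32
    num_workers: int = 0
    fold: int = 6


cfg = OmegaConf.structured(S3DISDataConfig)
try:
    cfg.num_workers = "aj"  # same assignment as in the diff above
except ValidationError as err:
    print(err)  # e.g. Value 'aj' could not be converted to Integer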

show(cfg)
# print("pre_transform: ", self.cfg.pre_transform)
# self.ds = {
# "train": S3DIS1x1(
# self.cfg.dataroot,
# test_area=self.cfg.fold,
# train=True,
# pre_transform=self.cfg.pre_transform,
# transform=self.cfg.train_transform,
# ),
# "test": S3DIS1x1(
# self.cfg.dataroot,
# test_area=self.cfg.fold,
# train=False,
# pre_transform=self.cfg.pre_transform,
# transform=self.cfg.train_transform,
# ),
# }