-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Structured Config Style #8
base: datasets
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,15 @@ | ||
# @package dataset | ||
# cfg: | ||
# torch data-loader specific arguments | ||
cfg: | ||
cfg: | ||
batch_size: ${training.batch_size} | ||
num_workers: ${training.num_workers} | ||
dataroot: data | ||
|
||
common_transform: | ||
aug_transform: | ||
pre_transform: | ||
# common_transform: | ||
# aug_transform: | ||
# pre_transform: | ||
|
||
val_transform: "${dataset.cfg.common_transform}" | ||
test_transform: "${dataset.cfg.val_transform}" | ||
train_transform: | ||
- "${dataset.cfg.aug_transform}" | ||
- "${dataset.cfg.common_transform}" | ||
# val_transform: "${dataset.cfg.common_transform}" | ||
# test_transform: "${dataset.cfg.val_transform}" | ||
# train_transform: | ||
# - "${dataset.cfg.aug_transform}" | ||
# - "${dataset.cfg.common_transform}" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
# @package dataset | ||
defaults: | ||
- dataset_s3dis | ||
- segmentation/default | ||
_target_: torch_points3d.dataset.s3dis1x1.s3dis_data_module | ||
cfg: | ||
fold: 5 | ||
fold : 5 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
defaults: # loads default configs | ||
- base_config | ||
- dataset: segmentation/s3dis/s3dis1x1 | ||
- training: default | ||
|
||
pretty_print: True |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
defaults: | ||
- base_trainer | ||
lr: 5e-5 | ||
|
||
# read in dataset | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import hydra | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you move this file to the test directory? And call this file There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sure, if you wanted to keep it. I just wrote it to demonstrate the use case. I think I'll just delete it and we can write a proper test file later. |
||
from hydra.core.global_hydra import GlobalHydra | ||
from omegaconf import OmegaConf, DictConfig | ||
from torch_points3d.trainer import LitTrainer | ||
from torch_points3d.core.instantiator import HydraInstantiator, Instantiator | ||
from dataclasses import dataclass | ||
from hydra.core.config_store import ConfigStore | ||
from typing import List, Any, Type | ||
from omegaconf import MISSING, OmegaConf | ||
from omegaconf._utils import is_structured_config | ||
|
||
OmegaConf.register_new_resolver("get_filename", lambda x: x.split("/")[-1]) | ||
|
||
|
||
@dataclass | ||
class TrainingDataConfig: | ||
batch_size: int = 32 | ||
num_workers: int = 0 | ||
lr: float = MISSING | ||
|
||
# We separate the dataset "cfg" from the actual dataset object | ||
# so that we can pass the "cfg" into the dataset constructors as a DictConfig | ||
# instead of as unwrapped parameters | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note this |
||
@dataclass | ||
class BaseDataConfig: | ||
batch_size: int = 32 | ||
num_workers: int = 0 | ||
dataroot: str = "data" | ||
|
||
@dataclass | ||
class BaseDataset: | ||
_target_: str | ||
cfg: BaseDataConfig | ||
|
||
@dataclass | ||
class S3DISDataConfig(BaseDataConfig): | ||
fold: int = 6 | ||
|
||
@dataclass | ||
class S3DISDataset(BaseDataset): | ||
cfg: S3DISDataConfig | ||
|
||
@dataclass | ||
class Config: | ||
dataset: Any | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I couldn't get the typing to work on this for some reason. You should be able to do |
||
training: TrainingDataConfig | ||
pretty_print: bool = False | ||
|
||
def show(x): | ||
print(f"type: {type(x).__name__}, value: {repr(x)}") | ||
|
||
cs = ConfigStore.instance() | ||
cs.store(name="base_config", node=Config) | ||
cs.store(group="dataset", name="dataset_s3dis", node=S3DISDataset) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. where we define the |
||
cs.store(group="training", name="base_trainer", node=TrainingDataConfig) | ||
|
||
@hydra.main(config_path="conf", config_name="test_config") | ||
def main(cfg: DictConfig): | ||
OmegaConf.set_struct(cfg, False) # This allows getattr and hasattr methods to function correctly | ||
if cfg.get("pretty_print"): | ||
print(OmegaConf.to_yaml(cfg, resolve=True)) | ||
|
||
dset = cfg.get("dataset") | ||
show(dset) | ||
show(dset.cfg) | ||
dset_cfg = dset.cfg | ||
# for some reason the cfg object will lose its typing information if hydra passes it to the target class | ||
# so we pass it manually ourselves and keep the typing info | ||
delattr(dset, "cfg") | ||
hydra.utils.instantiate(dset, dset_cfg) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
from typing import Any, Callable, Dict, Optional, Sequence | ||
from omegaconf import MISSING | ||
from omegaconf import MISSING, DictConfig | ||
from dataclasses import dataclass | ||
|
||
import hydra.utils | ||
|
@@ -16,24 +16,29 @@ class S3DISDataConfig(PointCloudDataConfig): | |
num_workers: int = 0 | ||
fold: int = 6 | ||
|
||
def show(x): | ||
print(f"type: {type(x).__name__}, value: {repr(x)}") | ||
|
||
class s3dis_data_module(PointCloudDataModule): | ||
def __init__(self, cfg: S3DISDataConfig = S3DISDataConfig()) -> None: | ||
def __init__(self, cfg: DictConfig) -> None: | ||
super().__init__(cfg) | ||
|
||
self.ds = { | ||
"train": S3DIS1x1( | ||
self.cfg.dataroot, | ||
test_area=self.cfg.fold, | ||
train=True, | ||
pre_transform=self.cfg.pre_transform, | ||
transform=self.cfg.train_transform, | ||
), | ||
"test": S3DIS1x1( | ||
self.cfg.dataroot, | ||
test_area=self.cfg.fold, | ||
train=False, | ||
pre_transform=self.cfg.pre_transform, | ||
transform=self.cfg.train_transform, | ||
), | ||
} | ||
show(cfg) | ||
cfg.num_workers = "aj" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will throw an error, because the |
||
show(cfg) | ||
# print("pre_transform: ", self.cfg.pre_transform) | ||
# self.ds = { | ||
# "train": S3DIS1x1( | ||
# self.cfg.dataroot, | ||
# test_area=self.cfg.fold, | ||
# train=True, | ||
# pre_transform=self.cfg.pre_transform, | ||
# transform=self.cfg.train_transform, | ||
# ), | ||
# "test": S3DIS1x1( | ||
# self.cfg.dataroot, | ||
# test_area=self.cfg.fold, | ||
# train=False, | ||
# pre_transform=self.cfg.pre_transform, | ||
# transform=self.cfg.train_transform, | ||
# ), | ||
# } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For ConfigStore, we have to add a line to all the defaults lists so that it can associate the config file with the dataclass.