Skip to content

Commit

Permalink
name in config
Browse files Browse the repository at this point in the history
  • Loading branch information
martvanrijthoven committed Jun 24, 2024
1 parent d9eb8da commit 741fc39
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 32 deletions.
7 changes: 6 additions & 1 deletion dicfg/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
_REFERENCE_MAP_SYMBOL = ":"
_REFERENCE_ATTRIBUTE_SYMBOL = "."
_OBJECT_KEY = "*object"
_TEMPLATE_KEY = "*template"


class _ObjectFactory:
Expand All @@ -32,7 +33,7 @@ def _build(self, config: dict):
@_build.register
def _build_dict(self, config: dict):
for key, value in config.items():
if _dont_build(value):
if _dont_build(value) or _is_template(key):
config[key] = value
elif _is_object(value):
config[key] = self._build_object(value)
Expand Down Expand Up @@ -105,6 +106,10 @@ def _is_object(value):
return isinstance(value, dict) and _OBJECT_KEY in value


def _is_template(key):
return key == _TEMPLATE_KEY


def _dont_build(value):
if isinstance(value, dict):
return value.pop("!build", False)
Expand Down
75 changes: 44 additions & 31 deletions dicfg/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
from collections import defaultdict
from copy import deepcopy
from functools import partial, singledispatch
from functools import partial, singledispatchmethod
from pathlib import Path
from typing import List, Union

Expand All @@ -12,12 +12,12 @@
from dicfg.config import merge


def open_json_config(config_path):
def open_json_config(config_path) -> dict:
with open(str(config_path), encoding="utf8") as file:
return json.load(file)


def open_yaml_config(config_path):
def open_yaml_config(config_path) -> dict:
with open(str(config_path), encoding="utf8") as file:
return yaml.load(file, Loader=yaml.SafeLoader)

Expand All @@ -37,8 +37,7 @@ class ConfigReader:
"""ConfigReader
Args:
name (str): Name of config. Used as a reference in user configs and cli settings.
main_config_path (Union[str, Path], optional): Path to main config. Defaults to "./configs/config.yml".
main_config_path (Union[str, Path], optional): Path to main config.
presets_folder_name (str, optional): Presets folder. Defaults to 'presets'.
default_key (str, optional): Default context key. Defaults to "default".
context_keys (tuple, optional): Addtional context keys. Defaults to ().
Expand All @@ -47,14 +46,17 @@ class ConfigReader:

def __init__(
self,
name: str,
main_config_path: Union[str, Path] = "./configs/config.yml",
main_config_path: Union[str, Path],
presets_folder_name: str = "presets",
default_key: str = "default",
context_keys: tuple = (),
search_paths: tuple = (),
):
self._name = name

self._config = None
self._name = None
self._version = None

self._main_config_path = Path(main_config_path)

if not self._main_config_path.exists():
Expand Down Expand Up @@ -95,7 +97,9 @@ def read(
user_config_search_path, self._search_paths
)

self_config = self._read(self._main_config_path)
self._config = self._read(self._main_config_path)
self._name = self._config.pop('NAME')
self._version = self._config.pop('VERSION', "N/A")
user_config = self._read_user_config(user_config)

arg_preset_configs = self._read_presets(presets)
Expand All @@ -105,11 +109,15 @@ def read(

cli_config = self._read_cli(sys.argv[1:])

configs = (self_config, *preset_configs, user_config, cli_config)
configs = (self._config, *preset_configs, user_config, cli_config)
configs = self._fuse_configs(configs, self._context_keys, search_paths)

return merge(*configs).cast()

def save(self, output_folder):
output_path = ...
...

def _set_search_paths(self, user_config_search_path, search_paths):
return (
Path(),
Expand Down Expand Up @@ -150,13 +158,38 @@ def _fuse_configs(self, configs, context_keys, search_paths):
return tuple(map(fuse_config, configs))

def _fuse_config(self, config: dict, context_keys: tuple, search_paths):
config = _include_configs(config, search_paths)
config = self._include_configs(config, search_paths)
fused_config = deepcopy(
{key: deepcopy(config.get("default", {})) for key in context_keys}
)
return merge(fused_config, config)


@singledispatchmethod
def _include_configs(self, config, search_paths):
return config


@_include_configs.register
def _include_configs_str(self, config: str, search_paths):
if Path(config).suffix in _FILE_READERS:
config_path = _search_config(config, search_paths)
open_config = _FILE_READERS[Path(config_path).suffix](config_path)
print(open_config, type(open_config))
open_config = open_config.get(self._name, open_config)
print(open_config)
return self._include_configs(open_config, search_paths)
return config


@_include_configs.register
def _include_configs_dict(self, config: dict, search_paths):
for key, value in config.items():
config[key] = self._include_configs(value, search_paths)
return config



def _create_dict_from_keys(keys: list, value) -> dict:
dictionary = defaultdict(dict)
if len(keys) <= 1:
Expand All @@ -179,23 +212,3 @@ def _search_config(config_name: Union[str, Path], search_paths: tuple) -> Path:
return config_path
raise ConfigNotFoundError(config_name)


@singledispatch
def _include_configs(config, search_paths):
return config


@_include_configs.register
def _include_configs_str(config: str, search_paths):
if Path(config).suffix in _FILE_READERS:
config_path = _search_config(config, search_paths)
open_config = _FILE_READERS[Path(config_path).suffix](config_path)
return _include_configs(open_config, search_paths)
return config


@_include_configs.register
def _include_configs_dict(config: dict, search_paths):
for key, value in config.items():
config[key] = _include_configs(value, search_paths)
return config
4 changes: 4 additions & 0 deletions tests/configs/config.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
NAME: testconfig
VERSION:

default:
items: items.yml
test:
test2: [1, 2, 3]
test3: test
Expand Down
9 changes: 9 additions & 0 deletions tests/configs/items.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"*template": &items
"*object": __main__.Item
a: 10

testconfig:
- <<: *items
a: 20
- <<: *items
a: 30

0 comments on commit 741fc39

Please sign in to comment.