Skip to content

Commit

Permalink
Fix overwriting of nested parameters in config by runtime parameters (#…
Browse files Browse the repository at this point in the history
…2378)

* Fix updating of nested params from CLI

Signed-off-by: Ankita Katiyar <[email protected]>

* Convert DictConfig to dict for proper merging

Signed-off-by: Ankita Katiyar <[email protected]>

* revert omegaconf

Signed-off-by: Ankita Katiyar <[email protected]>

* revert utils indent

Signed-off-by: Ankita Katiyar <[email protected]>

* Test for nested params with omegaconf

Signed-off-by: Ankita Katiyar <[email protected]>

* Add test for checking store does not contain DictConfig

Signed-off-by: Ankita Katiyar <[email protected]>

* docslinkcheck + move fn outside

Signed-off-by: Ankita Katiyar <[email protected]>

---------

Signed-off-by: Ankita Katiyar <[email protected]>
  • Loading branch information
ankatiyar committed Mar 2, 2023
1 parent 7d7f1dd commit 200ebdb
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 7 deletions.
2 changes: 1 addition & 1 deletion docs/source/development/commands_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ _This command will be deprecated from Kedro version 0.19.0._
kedro lint
```

Your project is linted with [`black`](https://github.com/psf/black), [`flake8`](https://gitlab.com/pycqa/flake8) and [`isort`](https://github.com/PyCQA/isort).
Your project is linted with [`black`](https://github.com/psf/black), [`flake8`](https://github.com/PyCQA/flake8) and [`isort`](https://github.com/PyCQA/isort).


#### Test your project
Expand Down
2 changes: 1 addition & 1 deletion docs/source/development/linting.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ consistent.

## Set up linting tools
There are a variety of linting tools available to use with your Kedro projects. This guide shows you how to use
[`black`](https://github.com/psf/black), [`flake8`](https://gitlab.com/pycqa/flake8), and
[`black`](https://github.com/psf/black), [`flake8`](https://github.com/PyCQA/flake8), and
[`isort`](https://github.com/PyCQA/isort) to lint your Kedro projects.
- **`black`** is a [PEP 8](https://peps.python.org/pep-0008/) compliant opinionated Python code formatter. `black` can
check for styling inconsistencies and reformat your files in place.
Expand Down
2 changes: 1 addition & 1 deletion kedro/framework/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def _split_params(ctx, param, value):
)
dot_list.append(item)
conf = OmegaConf.from_dotlist(dot_list)
return conf
return OmegaConf.to_container(conf)


def _split_load_versions(ctx, param, value):
Expand Down
8 changes: 5 additions & 3 deletions kedro/framework/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from urllib.parse import urlparse
from warnings import warn

from omegaconf import DictConfig
from pluggy import PluginManager

from kedro.config import ConfigLoader, MissingConfigException
Expand Down Expand Up @@ -154,7 +155,9 @@ def _update_nested_dict(old_dict: Dict[Any, Any], new_dict: Dict[Any, Any]) -> N
if key not in old_dict:
old_dict[key] = value
else:
if isinstance(old_dict[key], dict) and isinstance(value, dict):
if isinstance(old_dict[key], (dict, DictConfig)) and isinstance(
value, (dict, DictConfig)
):
_update_nested_dict(old_dict[key], value)
else:
old_dict[key] = value
Expand Down Expand Up @@ -322,8 +325,7 @@ def _add_param_to_feed_dict(param_name, param_value):
"""
key = f"params:{param_name}"
feed_dict[key] = param_value

if isinstance(param_value, dict):
if isinstance(param_value, (dict, DictConfig)):
for key, val in param_value.items():
_add_param_to_feed_dict(f"{param_name}.{key}", val)

Expand Down
7 changes: 6 additions & 1 deletion tests/framework/context/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest
import toml
import yaml
from omegaconf import OmegaConf
from pandas.util.testing import assert_frame_equal

from kedro import __version__ as kedro_version
Expand Down Expand Up @@ -163,7 +164,6 @@ def dummy_dataframe():
' --from-nodes "nodes3"'
)


expected_message_head = (
"There are 4 nodes that have not run.\n"
"You can resume the pipeline run by adding the following "
Expand Down Expand Up @@ -484,6 +484,11 @@ def test_validate_layers_error(layers, conflicting_datasets, mocker):
{"a": {"a.c": {"a.c.b": 4}}},
{"a": {"a.a": 1, "a.b": 2, "a.c": {"a.c.a": 3, "a.c.b": 4}}},
),
(
{"a": OmegaConf.create({"b": 1}), "x": 3},
{"a": {"c": 2}},
{"a": {"b": 1, "c": 2}, "x": 3},
),
],
)
def test_update_nested_dict(old_dict: Dict, new_dict: Dict, expected: Dict):
Expand Down
26 changes: 26 additions & 0 deletions tests/framework/session/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
import re
import subprocess
import textwrap
from collections.abc import Mapping
from pathlib import Path

import pytest
import toml
import yaml
from omegaconf import OmegaConf

from kedro import __version__ as kedro_version
from kedro.config import AbstractConfigLoader, ConfigLoader, OmegaConfigLoader
from kedro.framework.cli.utils import _split_params
from kedro.framework.context import KedroContext
from kedro.framework.project import (
ValidationError,
Expand Down Expand Up @@ -920,3 +923,26 @@ def test_setup_logging_using_omega_config_loader_class(
).as_posix()
actual_log_filepath = call_args["handlers"]["info_file_handler"]["filename"]
assert actual_log_filepath == expected_log_filepath


def get_all_values(mapping: Mapping):
for value in mapping.values():
yield value
if isinstance(value, Mapping):
yield from get_all_values(value)


@pytest.mark.parametrize("params", ["a=1,b.c=2", "a=1,b=2,c=3", ""])
def test_no_DictConfig_in_store(
params,
mock_package_name,
fake_project,
):
extra_params = _split_params(None, None, params)
session = KedroSession.create(
mock_package_name, fake_project, extra_params=extra_params
)

assert not any(
OmegaConf.is_config(value) for value in get_all_values(session._store)
)

0 comments on commit 200ebdb

Please sign in to comment.