Skip to content

Commit

Permalink
Instantiation mode "partial" to "callable". Return the _target_
Browse files Browse the repository at this point in the history
… component as-is when in `_mode_="callable"` and no kwargs are specified (#7413)

### Description

A `_target_` component with `_mode_="partial"` will still be wrapped in
`functools.partial` even when no kwargs are passed:
`functool.partial(component)`. In such cases, the component can just be
returned as-is.

If you agree with this, I will add tests for it. Thank you!

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Ibrahim Hadzic <[email protected]>
  • Loading branch information
ibro45 committed Feb 6, 2024
1 parent 718e4be commit ec2cc83
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 14 deletions.
7 changes: 4 additions & 3 deletions docs/source/config_syntax.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,10 @@ _Description:_ `_requires_`, `_disabled_`, `_desc_`, and `_mode_` are optional k
- `_mode_` specifies the operating mode when the component is instantiated or the callable is called.
it currently supports the following values:
- `"default"` (default) -- return the return value of ``_target_(**kwargs)``
- `"partial"` -- return a partial function of ``functools.partial(_target_, **kwargs)`` (this is often
useful when some portion of the full set of arguments are supplied to the ``_target_``, and the user wants to
call it with additional arguments later).
- `"callable"` -- return a callable, either as ``_target_`` itself or, if ``kwargs`` are provided, as a
partial function of ``functools.partial(_target_, **kwargs)``. Useful for defining a class or function
that will be instantied or called later. User can pre-define some arguments to the ``_target_`` and call
it with additional arguments later.
- `"debug"` -- execute with debug prompt and return the return value of ``pdb.runcall(_target_, **kwargs)``,
see also [`pdb.runcall`](https://docs.python.org/3/library/pdb.html#pdb.runcall).

Expand Down
2 changes: 1 addition & 1 deletion monai/bundle/config_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ class ConfigComponent(ConfigItem, Instantiable):
- ``"_mode_"`` (optional): operating mode for invoking the callable ``component`` defined by ``"_target_"``:
- ``"default"``: returns ``component(**kwargs)``
- ``"partial"``: returns ``functools.partial(component, **kwargs)``
- ``"callable"``: returns ``component`` or, if ``kwargs`` are provided, ``functools.partial(component, **kwargs)``
- ``"debug"``: returns ``pdb.runcall(component, **kwargs)``
Other fields in the config content are input arguments to the python module.
Expand Down
2 changes: 1 addition & 1 deletion monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ class CompInitMode(StrEnum):
"""

DEFAULT = "default"
PARTIAL = "partial"
CALLABLE = "callable"
DEBUG = "debug"


Expand Down
11 changes: 7 additions & 4 deletions monai/utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,14 @@ def instantiate(__path: str, __mode: str, **kwargs: Any) -> Any:
Args:
__path: if a string is provided, it's interpreted as the full path of the target class or function component.
If a callable is provided, ``__path(**kwargs)`` or ``functools.partial(__path, **kwargs)`` will be returned.
If a callable is provided, ``__path(**kwargs)`` will be invoked and returned for ``__mode="default"``.
For ``__mode="callable"``, the callable will be returned as ``__path`` or, if ``kwargs`` are provided,
as ``functools.partial(__path, **kwargs)`` for future invoking.
__mode: the operating mode for invoking the (callable) ``component`` represented by ``__path``:
- ``"default"``: returns ``component(**kwargs)``
- ``"partial"``: returns ``functools.partial(component, **kwargs)``
- ``"callable"``: returns ``component`` or, if ``kwargs`` are provided, ``functools.partial(component, **kwargs)``
- ``"debug"``: returns ``pdb.runcall(component, **kwargs)``
kwargs: keyword arguments to the callable represented by ``__path``.
Expand All @@ -259,8 +262,8 @@ def instantiate(__path: str, __mode: str, **kwargs: Any) -> Any:
return component
if m == CompInitMode.DEFAULT:
return component(**kwargs)
if m == CompInitMode.PARTIAL:
return partial(component, **kwargs)
if m == CompInitMode.CALLABLE:
return partial(component, **kwargs) if kwargs else component
if m == CompInitMode.DEBUG:
warnings.warn(
f"\n\npdb: instantiating component={component}, mode={m}\n"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_config_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
TEST_CASE_5 = [{"_target_": "LoadImaged", "_disabled_": "true", "keys": ["image"]}, dict]
# test non-monai modules and excludes
TEST_CASE_6 = [{"_target_": "torch.optim.Adam", "params": torch.nn.PReLU().parameters(), "lr": 1e-4}, torch.optim.Adam]
TEST_CASE_7 = [{"_target_": "decollate_batch", "detach": True, "pad": True, "_mode_": "partial"}, partial]
TEST_CASE_7 = [{"_target_": "decollate_batch", "detach": True, "pad": True, "_mode_": "callable"}, partial]
# test args contains "name" field
TEST_CASE_8 = [
{"_target_": "RandTorchVisiond", "keys": "image", "name": "ColorJitter", "brightness": 0.25},
Expand Down
6 changes: 2 additions & 4 deletions tests/test_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def case_pdb_inst(sarg=None):


class TestClass:

@staticmethod
def compute(a, b, func=lambda x, y: x + y):
return func(a, b)
Expand Down Expand Up @@ -127,7 +126,6 @@ def __call__(self, a, b):


class TestConfigParser(unittest.TestCase):

def test_config_content(self):
test_config = {"preprocessing": [{"_target_": "LoadImage"}], "dataset": {"_target_": "Dataset"}}
parser = ConfigParser(config=test_config)
Expand Down Expand Up @@ -183,7 +181,7 @@ def test_function(self, config):
parser = ConfigParser(config=config, globals={"TestClass": TestClass})
for id in config:
if id in ("compute", "cls_compute"):
parser[f"{id}#_mode_"] = "partial"
parser[f"{id}#_mode_"] = "callable"
func = parser.get_parsed_content(id=id)
self.assertTrue(id in parser.ref_resolver.resolved_content)
if id == "error_func":
Expand Down Expand Up @@ -279,7 +277,7 @@ def test_lambda_reference(self):

def test_non_str_target(self):
configs = {
"fwd": {"_target_": "[email protected]", "x": "$torch.rand(1, 3, 256, 256)", "_mode_": "partial"},
"fwd": {"_target_": "[email protected]", "x": "$torch.rand(1, 3, 256, 256)", "_mode_": "callable"},
"model": {"_target_": "monai.networks.nets.resnet.resnet18", "pretrained": False, "spatial_dims": 2},
}
self.assertTrue(callable(ConfigParser(config=configs).fwd))
Expand Down

0 comments on commit ec2cc83

Please sign in to comment.