diff --git a/skypy/pipeline/_config.py b/skypy/pipeline/_config.py index 93d0892d6..6de372ba9 100644 --- a/skypy/pipeline/_config.py +++ b/skypy/pipeline/_config.py @@ -81,6 +81,10 @@ def construct_function(self, name, node): return (function, args, kwargs) + def construct_lambda(self, node): + lambda_val = 'lambda ' + self.construct_scalar(node) + return eval(lambda_val, {}) + def construct_quantity(self, node): value = self.construct_scalar(node) return Quantity(value) @@ -89,6 +93,9 @@ def construct_quantity(self, node): # constructor for generic functions SkyPyLoader.add_multi_constructor('!', SkyPyLoader.construct_function) +# constructor for lambda functions +SkyPyLoader.add_constructor('!lambda', SkyPyLoader.construct_lambda) + # constructor for quantities SkyPyLoader.add_constructor('!quantity', SkyPyLoader.construct_quantity) # Implicitly resolve quantities using the regex from astropy diff --git a/skypy/pipeline/tests/data/lambda_function.yml b/skypy/pipeline/tests/data/lambda_function.yml new file mode 100644 index 000000000..12f646bda --- /dev/null +++ b/skypy/pipeline/tests/data/lambda_function.yml @@ -0,0 +1,3 @@ +a: 0.5 +tau: !lambda 'a : 1. / a' +twice: !lambda 'a : a * 2' diff --git a/skypy/pipeline/tests/test_config.py b/skypy/pipeline/tests/test_config.py index 34bac602c..87b34d8a2 100644 --- a/skypy/pipeline/tests/test_config.py +++ b/skypy/pipeline/tests/test_config.py @@ -65,3 +65,11 @@ def test_kwarg_must_be_strings(): with pytest.raises(ValueError) as e: load_skypy_yaml(filename) assert("Invalid key found in config" in e.value.args[0]) + + +def test_lambda_functions(): + filename = get_pkg_data_filename('data/lambda_function.yml') + config = load_skypy_yaml(filename) + a = config['a'] + assert config['tau'](a) == 1. / a + assert config['twice'](a) == 1.0