diff --git a/CHANGELOG.md b/CHANGELOG.md index a876485..fb6687e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ All notable changes to this project will be documented in this file. ### Changed +- Removed the redundant function `evaluate_sample` from `sampling.py` ### Fixed diff --git a/eulerpi/core/sampling.py b/eulerpi/core/sampling.py index be8a982..0a2fda0 100644 --- a/eulerpi/core/sampling.py +++ b/eulerpi/core/sampling.py @@ -27,42 +27,6 @@ from eulerpi.core.result_manager import ResultManager from eulerpi.core.transformations import eval_log_transformed_density -# TODO: This works on the blob -# Return the samples. -# return sampler.get_chain(discard=0, thin=1, flat=True) -# TODO: This stores the sample as 2d array in the format walker1_step1, walker2_step1, walker3_step1, walker1_step2, walker2_step2, walker3_step2, ... -# sampler_results = samplerBlob.reshape( -# num_walkers * num_steps, sampling_dim + data_dim + 1 -# ) - - -def evaluate_sample( - param: np.ndarray, - model: Model, - data: np.ndarray, - data_transformation: DataTransformation, - data_stdevs: np.ndarray, - slice: np.ndarray, -) -> typing.Tuple[float, np.ndarray]: - """Evaluate the log transformed density at the given parameter values. - - Args: - param (np.ndarray): parameter values - model (Model): The model which will be sampled - data (np.ndarray): data - data_transformation (DataTransformation): The data transformation used to normalize the data. - data_stdevs (np.ndarray): kernel width for the data - slice (np.ndarray): slice of the parameter space which will be sampled - - Returns: - typing.Tuple[float, np.ndarray]: log transformed density and the sampler result - """ - - log_samplerresult = eval_log_transformed_density( - param, model, data, data_transformation, data_stdevs, slice - ) - return log_samplerresult - def run_emcee_once( model: Model, @@ -102,7 +66,7 @@ def run_emcee_once( sampler = emcee.EnsembleSampler( num_walkers, sampling_dim, - evaluate_sample, + eval_log_transformed_density, pool=pool, args=[model, data, data_transformation, data_stdevs, slice], )