Skip to content

Commit

Permalink
[Cleanup] Remove redundant function "evaluate_sample" (#124)
Browse files Browse the repository at this point in the history
* [Sampling] Remove redundant function "evaluate_sample"

* [Changelog] Update changelog
  • Loading branch information
kaiserls committed May 28, 2024
1 parent 89ba1f5 commit b828a50
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 37 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ All notable changes to this project will be documented in this file.

### Changed

- Removed the redundant function `evaluate_sample` from `sampling.py`

### Fixed

Expand Down
38 changes: 1 addition & 37 deletions eulerpi/core/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,42 +27,6 @@
from eulerpi.core.result_manager import ResultManager
from eulerpi.core.transformations import eval_log_transformed_density

# TODO: This works on the blob
# Return the samples.
# return sampler.get_chain(discard=0, thin=1, flat=True)
# TODO: This stores the sample as 2d array in the format walker1_step1, walker2_step1, walker3_step1, walker1_step2, walker2_step2, walker3_step2, ...
# sampler_results = samplerBlob.reshape(
# num_walkers * num_steps, sampling_dim + data_dim + 1
# )


def evaluate_sample(
param: np.ndarray,
model: Model,
data: np.ndarray,
data_transformation: DataTransformation,
data_stdevs: np.ndarray,
slice: np.ndarray,
) -> typing.Tuple[float, np.ndarray]:
"""Evaluate the log transformed density at the given parameter values.
Args:
param (np.ndarray): parameter values
model (Model): The model which will be sampled
data (np.ndarray): data
data_transformation (DataTransformation): The data transformation used to normalize the data.
data_stdevs (np.ndarray): kernel width for the data
slice (np.ndarray): slice of the parameter space which will be sampled
Returns:
typing.Tuple[float, np.ndarray]: log transformed density and the sampler result
"""

log_samplerresult = eval_log_transformed_density(
param, model, data, data_transformation, data_stdevs, slice
)
return log_samplerresult


def run_emcee_once(
model: Model,
Expand Down Expand Up @@ -102,7 +66,7 @@ def run_emcee_once(
sampler = emcee.EnsembleSampler(
num_walkers,
sampling_dim,
evaluate_sample,
eval_log_transformed_density,
pool=pool,
args=[model, data, data_transformation, data_stdevs, slice],
)
Expand Down

0 comments on commit b828a50

Please sign in to comment.