
Commit

add n_bins as argument to functions
pasq-cat committed Jun 29, 2024
1 parent 6a9ee1b commit 203513d
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/utils.jl
@@ -51,13 +51,14 @@ cumulative distribution function of the predicted distribution targeting y_t.
 Source: https://arxiv.org/abs/1807.00263
 Inputs:
-- Y_cal: a vector of values y_t
-- sampled_distributions: an array of sampled distributions F(x_t) stacked column-wise.
+- 'Y_cal': a vector of values y_t
+- 'sampled_distributions': an array of sampled distributions F(x_t) stacked column-wise.
+- 'n_bins': number of equally spaced bins to use.
 Outputs:
 - counts: an array cointaining the empirical frequencies for each quantile interval.
 """
-function empirical_frequency_regression(Y_cal, sampled_distributions)
-    quantiles = collect(0:0.05:1)
+function empirical_frequency_regression(Y_cal, sampled_distributions, n_bins)
+    quantiles = collect(range(0; stop=1, length=n_bins + 1))
     quantiles_matrix = hcat(
         [quantile(samples, quantiles) for samples in sampled_distributions]...
     )
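
Note on the hunk above: the old grid `0:0.05:1` has 21 edges, so passing `n_bins = 20` to the new signature should reproduce the previous behaviour. A minimal check, written as a standalone sketch rather than as part of the commit (the variable names `old_grid` and `new_grid` are illustrative only):

```julia
# Illustrative check, not part of the commit: compare the old hard-coded
# quantile grid with the new parameterized one for n_bins = 20.
n_bins = 20
old_grid = collect(0:0.05:1)
new_grid = collect(range(0; stop=1, length=n_bins + 1))

# Both grids have n_bins + 1 edges and agree up to floating-point rounding.
@assert length(old_grid) == length(new_grid) == n_bins + 1
@assert all(isapprox.(old_grid, new_grid; atol=1e-12))
```

Any other value of `n_bins` gives `n_bins + 1` equally spaced edges on [0, 1], i.e. a bin width of `1 / n_bins`.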
@@ -101,16 +102,17 @@ Inputs:
 - y_binary: the array of outputs y_t numerically coded: 1 for the target class, 0 for the null class.
 - sampled_distributions: an array of sampled distributions stacked column-wise so that in the first row
 there is the probability for the target class y_1 and in the second row the probability for the null class y_0.
+- 'n_bins': number of equally spaced bins to use.
 Outputs:
 - num_p_per_interval: array with the number of probabilities falling within interval
 - emp_avg: array with the observed empirical average per interval
 - bin_centers: array with the centers of the bins
 """
-function empirical_frequency_binary_classification(y_binary, sampled_distributions)
+function empirical_frequency_binary_classification(y_binary, sampled_distributions, n_bins)
 
     # Number of bins
-    n_bins = 20
+    n_bins = n_bins
     #intervals boundaries
     int_bds = collect(range(0; stop=1, length=n_bins + 1))
     #bin centers
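
Note on the signature change: `n_bins` is now a required positional argument in both functions, so existing two-argument calls will no longer match these methods unless another method is defined elsewhere. One possible way to keep the old behaviour available is a default value of 20. The sketch below shows the idea on a stand-in helper; `interval_bounds` is a hypothetical name and does not exist in `src/utils.jl`.

```julia
# Hypothetical helper, not part of the commit: mirrors the interval-boundary
# line from the diff, with a default matching the old hard-coded value of 20.
interval_bounds(n_bins::Integer=20) = collect(range(0; stop=1, length=n_bins + 1))

bds_default = interval_bounds()    # 20 bins, as before the commit
bds_coarse = interval_bounds(10)   # 10 equally spaced bins

@assert length(bds_default) == 21
@assert length(bds_coarse) == 11
@assert bds_coarse[2] - bds_coarse[1] ≈ 0.1
```

Whether a default is desirable here is a design choice for the maintainers; requiring the argument makes the bin count explicit at every call site.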