how to guide for EnbPI

MojiFarmanbar committed Jul 22, 2023
1 parent f68afcc commit 7b735f8

Showing 3 changed files with 172 additions and 1 deletion.
14 changes: 14 additions & 0 deletions bib.bib
@@ -3032,4 +3032,18 @@ @TechReport{martens2020optimizing
keywords = {Computer Science - Machine Learning, Computer Science - Neural and Evolutionary Computing, Statistics - Machine Learning},
}

@InProceedings{chen2021conformal,
title = {Conformal prediction interval for dynamic time-series},
author = {Xu, Chen and Xie, Yao},
booktitle = {Proceedings of the 38th International Conference on Machine Learning},
pages = {11559--11569},
year = {2021},
editor = {Meila, Marina and Zhang, Tong},
volume = {139},
series = {Proceedings of Machine Learning Research},
month = {18--24 Jul},
publisher = {PMLR},
file = {Full Text PDF:https\://proceedings.mlr.press/v139/xu21h/xu21h.pdf:application/pdf},
}

@Comment{jabref-meta: databaseType:biblatex;}
157 changes: 157 additions & 0 deletions docs/src/how_to_guides/timeseries.qmd
@@ -0,0 +1,157 @@
```@meta
CurrentModule = ConformalPrediction
```

# How to Conformalize a Time Series Forecast

```{julia}
#| echo: false
using Pkg; Pkg.activate("docs")
using Plots
theme(:wong)
using Random
Random.seed!(2022)
```

Time series data is prevalent across various domains, such as finance, weather forecasting, energy, and supply chains. However, accurately quantifying uncertainty in time series predictions is often complex due to inherent temporal dependencies, non-stationarity, and noise in the data. Conformal Prediction offers a valuable solution here by providing prediction intervals that quantify uncertainty in a statistically sound way. This how-to guide demonstrates how you can conformalize time series data using Ensemble Batch Prediction Intervals (EnbPI) [@chen2021conformal]. This method updates prediction intervals whenever new observations become available, allowing it to adapt to changing conditions and to account for degrading predictions or increasing noise levels in the data.
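
For intuition, here is a minimal sketch of the core idea behind EnbPI (illustration only, not the package API used later in this guide): an ensemble produces a point forecast, and the interval is that forecast plus or minus an empirical quantile of the absolute residuals. The helper name `naive_enbpi_interval` and the assumption that each element of `models` is a callable `m(X) -> ŷ` are hypothetical; the actual method additionally relies on out-of-bag residuals from bootstrapped models.

```{julia}
#| eval: false
using Statistics

# Simplified EnbPI-style interval (hypothetical helper, for intuition only;
# the real method uses out-of-bag residuals from bootstrap ensembles):
function naive_enbpi_interval(models, X, y, Xnew; α=0.05)
    ŷ = mean(m(X) for m in models)           # ensemble point forecast
    residuals = abs.(y .- ŷ)                 # absolute residuals
    q = quantile(residuals, 1 - α)           # empirical residual quantile
    ŷnew = mean(m(Xnew) for m in models)     # forecast for new inputs
    return ŷnew .- q, ŷnew .+ q              # (lower, upper) bounds
end
```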


## The Task at Hand

Inspired by [MAPIE](https://mapie.readthedocs.io/en/latest/examples_regression/4-tutorials/plot_ts-tutorial.html), we employ the Victoria electricity demand dataset. This dataset contains hourly electricity demand (in GW) for the state of Victoria, Australia, along with corresponding temperature data (in degrees Celsius).

```{julia}
using CSV, DataFrames
df = CSV.read("./dev/artifacts/electricity_demand.csv", DataFrame)
```

## Feature engineering

In this how-to guide, we focus only on date, time, and lag features.

### Date and time-related features
We create temporal features out of the date and hour:

```{julia}
using Dates
df.Datetime = Dates.DateTime.(df.Datetime, "yyyy-mm-dd HH:MM:SS")
df.Weekofyear = Dates.week.(df.Datetime)
df.Weekday = Dates.dayofweek.(df.Datetime)
df.hour = Dates.hour.(df.Datetime)
```

Additionally, to simulate sudden changes caused by unforeseen events, such as blackouts or lockdowns, we deliberately reduce the electricity demand by 2 GW from February 22nd onward.

```{julia}
condition = df.Datetime .>= Date("2014-02-22")
df[condition, :Demand] .= df[condition, :Demand] .- 2
```

### Lag features

```{julia}
using ShiftedArrays
n_lags = 5
# Create lagged demand features (lags of 1 to n_lags hours):
for i in 1:n_lags
    DataFrames.transform!(df, "Demand" => (x -> ShiftedArrays.lag(x, i)) => "lag_hour_$i")
end
# Lagging leaves missing values in the first n_lags rows; drop them:
df_dropped_missing = dropmissing(df)
df_dropped_missing
```

## Train-test split


```{julia}
features_cols = select(df_dropped_missing, Not([:Datetime, :Demand]))
X = Matrix(features_cols)
y = Matrix(df_dropped_missing[:, [:Demand]])
split_index = floor(Int, 0.9 * size(y, 1))  # 90% train, 10% test
println(split_index)
X_train = X[1:split_index, :]
y_train = y[1:split_index, :]
X_test = X[split_index+1:end, :]
y_test = y[split_index+1:end, :]
```

## Loading the model using the MLJ interface

```{julia}
using MLJ
EvoTreeRegressor = @load EvoTreeRegressor pkg=EvoTrees verbosity=0
model = EvoTreeRegressor(nrounds=100, max_depth=10, rng=123)
```

## Conformal time series

We start off by using EnbPI without updating the training-set residuals to build prediction intervals:

```{julia}
using ConformalPrediction
using Statistics: mean

conf_model = conformal_model(model; method=:time_series_ensemble_batch, coverage=0.95)
mach = machine(conf_model, X_train, y_train)
train = [1:split_index;]
fit!(mach, rows=train)
y_pred_interval = predict(conf_model, mach.fitresult, X_test)
lb = [minimum(tuple_data) for tuple_data in y_pred_interval]
ub = [maximum(tuple_data) for tuple_data in y_pred_interval]
y_pred = [mean(tuple_data) for tuple_data in y_pred_interval]
```

```{julia}
cutoff_point = findfirst(df_dropped_missing.Datetime .== Date("2014-02-15"))
p1 = plot(df_dropped_missing[cutoff_point:split_index, :Datetime], y_train[cutoff_point:split_index], label="train", color=:blue, legend=:bottomleft)
plot!(df_dropped_missing[split_index+1:end, :Datetime], y_test, label="test", color=:orange)
plot!(df_dropped_missing[split_index+1:end, :Datetime], y_pred, label="prediction", color=:green)
plot!(df_dropped_missing[split_index+1:end, :Datetime],
    lb, fillrange=ub, fillalpha=0.2, label="PI without update of residuals", color=:green, linewidth=0)
```


We can use the `partial_fit` method of the EnbPI implementation in ConformalPrediction to adjust prediction intervals to sudden change points in test data that the model has not seen during training. In the experiment below, `sample_size` indicates the batch size of new observations. You can decide whether to update the residuals with each batch of size `sample_size` only, or to additionally remove the oldest `n` residuals (`shift_size = n`). The latter removes early residuals that no longer have a positive impact on the current observations. The sketch below illustrates this update conceptually.
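
Conceptually, the residual update amounts to a sliding buffer. Here is a minimal sketch with a hypothetical helper name (`update_residuals`), not the package internals:

```{julia}
#| eval: false
# Hypothetical illustration of the residual buffer update: append the
# residuals of the newest batch and, if shift_size > 0, drop the oldest ones.
function update_residuals(residuals, y_new, ŷ_new; shift_size=0)
    buffer = vcat(residuals, abs.(y_new .- ŷ_new))
    return buffer[shift_size+1:end]
end
```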

```{julia}
sample_size = 10  # batch size of new observations
shift_size = 10   # number of oldest residuals to discard at each update
last_index = size(X_test, 1)
lb_updated, ub_updated = ([], [])
for step in 1:sample_size:last_index
    # Predict the next batch, then update the residuals with its true values:
    if last_index - step < sample_size
        y_interval = predict(conf_model, mach.fitresult, X_test[step:last_index, :])
        partial_fit(mach.model, mach.fitresult, X_test[step:last_index, :], y_test[step:last_index, :], shift_size)
    else
        y_interval = predict(conf_model, mach.fitresult, X_test[step:step+sample_size-1, :])
        partial_fit(mach.model, mach.fitresult, X_test[step:step+sample_size-1, :], y_test[step:step+sample_size-1, :], shift_size)
    end
    lb_updatedᵢ = [minimum(tuple_data) for tuple_data in y_interval]
    push!(lb_updated, lb_updatedᵢ)
    ub_updatedᵢ = [maximum(tuple_data) for tuple_data in y_interval]
    push!(ub_updated, ub_updatedᵢ)
end
lb_updated = reduce(vcat, lb_updated)
ub_updated = reduce(vcat, ub_updated)
```

```{julia}
p2 = plot(df_dropped_missing[cutoff_point:split_index, :Datetime], y_train[cutoff_point:split_index], label="train", color=:blue, legend=:bottomleft)
plot!(df_dropped_missing[split_index+1:end, :Datetime], y_test, label="test", color=:orange)
plot!(df_dropped_missing[split_index+1:end, :Datetime], y_pred, label="prediction", color=:green)
plot!(df_dropped_missing[split_index+1:end, :Datetime],
    lb_updated, fillrange=ub_updated, fillalpha=0.2, label="PI with adjusted residuals", color=:green, linewidth=0)
plot(p1, p2, layout=(2, 1))
```

## Results

In time series problems, unexpected incidents can lead to sudden changes, and such scenarios are highly probable. As illustrated above, the model's training data contains no information about these change points, so it cannot anticipate them. The top figure shows that when residuals are not updated, the prediction intervals rely solely on the distribution of residuals from the training set. Consequently, these intervals fail to cover the true observations after the change point, resulting in a sudden drop in coverage.

By partially updating the residuals, however, the method captures the increasing uncertainty in the model's predictions. Note that the change in uncertainty occurs approximately one day after the change point. This delay arises because a sufficient number of new residuals is needed to shift the quantiles obtained from the residual distribution.
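
To quantify this drop, we can compare the empirical coverage of both interval sets on the test period. This is a quick sanity check using the variables defined above; the `coverage` helper is our own, not part of the package:

```{julia}
#| eval: false
# Fraction of test observations falling inside the prediction intervals:
coverage(lo, hi, y) = mean(lo .<= vec(y) .<= hi)
println("Coverage without residual updates: ", coverage(lb, ub, y_test))
println("Coverage with updated residuals:   ", coverage(lb_updated, ub_updated, y_test))
```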
2 changes: 1 addition & 1 deletion src/conformal_models/transductive_regression.jl
@@ -849,7 +849,7 @@ For the [`TimeSeriesRegressorEnsembleBatch`](@ref) Non-conformity scores are upd
determines how many points in Non-conformity scores will be discarded.
"""
- function partial_fit(conf_model::TimeSeriesRegressorEnsembleBatch, fitresult, X, y, shift_size)
+ function partial_fit(conf_model::TimeSeriesRegressorEnsembleBatch, fitresult, X, y, shift_size=0)
ŷ = [reformat_mlj_prediction(MMI.predict(conf_model.model, μ̂₋ₜ, MMI.reformat(conf_model.model, X)...)) for μ̂₋ₜ in fitresult]
aggregate = conf_model.aggregate
ŷₜ = _aggregate(ŷ, aggregate)
