Skip to content

Commit

Permalink
moved extensions into single file to avoid issues with file rewrite o…
Browse files Browse the repository at this point in the history
…f PackageExtensionsCompat.jl
  • Loading branch information
pat-alt committed Aug 30, 2023
1 parent 8362b62 commit 139aae0
Show file tree
Hide file tree
Showing 11 changed files with 336 additions and 311 deletions.
36 changes: 35 additions & 1 deletion ext/MPIExt/MPIExt.jl → ext/MPIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,41 @@ export MPIParallelizer
using CounterfactualExplanations
using MPI

include("utils.jl")
### BEGIN utils.jl

"""
split_count(N::Integer, n::Integer)
Return a vector of `n` integers which are approximately equally sized and sum to `N`. Lifted from https://juliaparallel.org/MPI.jl/v0.20/examples/06-scatterv/.
"""
function split_count(N::Integer, n::Integer)
q, r = divrem(N, n)
return [i <= r ? q + 1 : q for i in 1:n]
end

"""
split_obs(obs::AbstractVector, n::Integer)
Return a vector of `n` group indices for `obs`.
"""
function split_obs(obs::AbstractVector, n::Integer)
N = length(obs)
N_counts = split_count(N, n)
_start = cumsum([1; N_counts[1:(end - 1)]])
_stop = cumsum(N_counts)
return [obs[_start[i]:_stop[i]] for i in 1:n]
end

vectorize_collection(collection::Vector) = collection

vectorize_collection(collection::Base.Iterators.Zip) = map(x -> x[1], collect(collection))

function vectorize_collection(collection::Matrix)
@warn "It looks like there is only one observation in the collection. Are you sure you want to parallelize?"
return [collection]
end

### END

"The `MPIParallelizer` type is used to parallelize the evaluation of a function using `MPI.jl`."
struct MPIParallelizer <: CounterfactualExplanations.AbstractParallelizer
Expand Down
31 changes: 0 additions & 31 deletions ext/MPIExt/utils.jl

This file was deleted.

221 changes: 221 additions & 0 deletions ext/PythonCallExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
module PythonCallExt

using CounterfactualExplanations
using Flux
using PythonCall

### BEGIN utils.jl

"""
CounterfactualExplanations.pytorch_model_loader(model_path::String, model_file::String, class_name::String, pickle_path::String)
Loads a previously saved PyTorch model.
# Arguments
- `model_path::String`: Path to the directory containing the model file.
- `model_file::String`: Name of the model file.
- `class_name::String`: Name of the model class.
- `pickle_path::String`: Path to the pickle file containing the model.
# Returns
- `model`: The loaded PyTorch model.
# Example
```{julia}
model = pytorch_model_loader(
"src/models/pretrained/pytorch",
"pytorch_model.py",
"PyTorchModel",
"src/models/pretrained/pytorch/pytorch_model.pt",
)
```
"""
function CounterfactualExplanations.pytorch_model_loader(
model_path::String, model_file::String, class_name::String, pickle_path::String
)
sys = PythonCall.pyimport("sys")
torch = PythonCall.pyimport("torch")

# Check whether the path is correct
if !endswith(pickle_path, ".pt")
throw(ArgumentError("pickle_path must end with '.pt'"))
end

# Make sure Python is able to import the module
if !in(model_path, sys.path)
sys.path.append(model_path)
end

PythonCall.pyimport(model_file => class_name)
model = torch.load(pickle_path)
return model
end

"""
CounterfactualExplanations.preprocess_python_data(data::CounterfactualData)
Converts a `CounterfactualData` object to an input tensor and a label tensor.
# Arguments
- `data::CounterfactualData`: The data to be converted.
# Returns
- `(x_python::Py, y_python::Py)`: A tuple of tensors resulting from the conversion, `x_python` holding the features and `y_python` holding the labels.
# Example
x_python, y_python = preprocess_python_data(counterfactual_data) # converts `counterfactual_data` to tensors `x_python` and `y_python
"""
function CounterfactualExplanations.preprocess_python_data(data::CounterfactualData)
x_julia = data.X
y_julia = data.y

# Convert data to tensors
torch = PythonCall.pyimport("torch")
np = PythonCall.pyimport("numpy")

x_python = Float32.(x_julia)
x_python = np.array(x_python)
x_python = torch.tensor(x_python).T

y_python = Float32.(y_julia)
y_python = np.array(y_python)
y_python = torch.tensor(y_python)

return x_python, y_python
end

### END

### BEGIN models.jl

using CounterfactualExplanations.Models

"""
PyTorchModel <: AbstractDifferentiableModel
Constructor for models trained in `PyTorch`.
"""
struct PyTorchModel <: AbstractDifferentiableModel
model::Any
likelihood::Symbol
function PyTorchModel(model, likelihood)
if likelihood [:classification_binary, :classification_multi]
new(model, likelihood)
else
throw(
ArgumentError(
"`type` should be in `[:classification_binary,:classification_multi]`"
),
)
end
end
end

"Outer constructor that extends method from parent package."
CounterfactualExplanations.PyTorchModel(args...) = PyTorchModel(args...)

"""
function Models.logits(M::PyTorchModel, x::AbstractArray)
Calculates the logit scores output by the model `M` for the input data `X`.
# Arguments
- `M::PyTorchModel`: The model selected by the user. Must be a model defined using PyTorch.
- `X::AbstractArray`: The feature vector for which the logit scores are calculated.
# Returns
- `logits::AbstractArray`: The logit scores for each output class for the data points in `X`.
# Example
logits = Models.logits(M, x) # calculates the logit scores for each output class for the data points in `X`
"""
function Models.logits(M::PyTorchModel, x::AbstractArray)
torch = PythonCall.pyimport("torch")
np = PythonCall.pyimport("numpy")

if !isa(x, Matrix)
x = reshape(x, length(x), 1)
end

ŷ_python = M.model(torch.tensor(np.array(x)).T).detach().numpy()
= PythonCall.pyconvert(Matrix, ŷ_python)

return transpose(ŷ)
end

"""
function Models.probs(M::PyTorchModel, x::AbstractArray)
Calculates the output probabilities of the model `M` for the input data `X`.
# Arguments
- `M::PyTorchModel`: The model selected by the user. Must be a model defined using PyTorch.
- `X::AbstractArray`: The feature vector for which the logit scores are calculated.
# Returns
- `logits::AbstractArray`: The probabilities for each output class for the data points in `X`.
# Example
logits = Models.logits(M, x) # calculates the probabilities for each output class for the data points in `X`
"""
function Models.probs(M::PyTorchModel, x::AbstractArray)
if M.likelihood == :classification_binary
return Flux.σ.(logits(M, x))
elseif M.likelihood == :classification_multi
return Flux.softmax(logits(M, x))
end
end

### END

### BEGIN generators.jl

using CounterfactualExplanations.Generators

"""
Generators.∂ℓ(generator::AbstractGradientBasedGenerator, M::PyTorchModel, ce::AbstractCounterfactualExplanation)
Method for computing the gradient of the loss function at the current counterfactual state for gradient-based generators operating on PyTorch models.
The gradients are calculated through PyTorch using PythonCall.jl.
# Arguments
- `generator::AbstractGradientBasedGenerator`: The generator object that is used to generate the counterfactual explanation.
- `M::Models.PyTorchModel`: The PyTorch model for which the counterfactual is generated.
- `ce::AbstractCounterfactualExplanation`: The counterfactual explanation object for which the gradient is calculated.
# Returns
- `grad::AbstractArray`: The gradient of the loss function at the current counterfactual state.
# Example
grad = ∂ℓ(generator, M, ce) # calculates the gradient of the loss function at the current counterfactual state.
"""
function Generators.∂ℓ(
generator::AbstractGradientBasedGenerator,
M::PyTorchModel,
ce::AbstractCounterfactualExplanation,
)
torch = PythonCall.pyimport("torch")
np = PythonCall.pyimport("numpy")

x = ce.x
target = Float32.(ce.target_encoded)

x = torch.tensor(np.array(reshape(x, 1, length(x))))
x.requires_grad = true

target = torch.tensor(np.array(reshape(target, 1, length(target))))
target = target.squeeze()

output = M.model(x).squeeze()

obj_loss = torch.nn.BCEWithLogitsLoss()(output, target)
obj_loss.backward()

grad = PythonCall.pyconvert(Matrix, x.grad.t().detach().numpy())

return grad
end

### END

end
11 changes: 0 additions & 11 deletions ext/PythonCallExt/PythonCallExt.jl

This file was deleted.

45 changes: 0 additions & 45 deletions ext/PythonCallExt/generators.jl

This file was deleted.

Loading

0 comments on commit 139aae0

Please sign in to comment.