Skip to content

Commit

Permalink
Relax input types (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill committed Jul 25, 2024
1 parent 3f872ed commit 4b53570
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions src/analyze.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,30 @@ See also [`Explanation`](@ref).
- `add_batch_dim`: add batch dimension to the input without allocating. Default is `false`.
"""
function analyze(
input::AbstractArray{<:Real},
input,
method::AbstractXAIMethod,
output_selection::Union{Integer,Tuple{<:Integer}};
kwargs...,
)
return _analyze(input, method, IndexSelector(output_selection); kwargs...)
end

function analyze(input::AbstractArray{<:Real}, method::AbstractXAIMethod; kwargs...)
function analyze(input, method::AbstractXAIMethod; kwargs...)
return _analyze(input, method, MaxActivationSelector(); kwargs...)
end

function (method::AbstractXAIMethod)(
input::AbstractArray{<:Real},
output_selection::Union{Integer,Tuple{<:Integer}};
kwargs...,
input, output_selection::Union{Integer,Tuple{<:Integer}}; kwargs...
)
return _analyze(input, method, IndexSelector(output_selection); kwargs...)
end
function (method::AbstractXAIMethod)(input::AbstractArray{<:Real}; kwargs...)
function (method::AbstractXAIMethod)(input; kwargs...)
return _analyze(input, method, MaxActivationSelector(); kwargs...)
end

# lower-level call to method
function _analyze(
input::AbstractArray{T,N},
input,
method::AbstractXAIMethod,
sel::AbstractOutputSelector;
add_batch_dim::Bool=false,
Expand All @@ -52,6 +50,5 @@ function _analyze(
if add_batch_dim
return method(batch_dim_view(input), sel; kwargs...)
end
N < 2 && throw(BATCHDIM_MISSING)
return method(input, sel; kwargs...)
end

0 comments on commit 4b53570

Please sign in to comment.