Skip to content

Commit

Permalink
Directly operate on leaf tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
tecosaur committed Jun 24, 2022
1 parent 3df9fef commit 908701f
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/classification/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ n_labels` matrix of probabilities, each row summing up to 1.
(eg. ["versicolor", "virginica", "setosa"]). It specifies the column ordering
of the output matrix. """
apply_tree_proba(leaf::Leaf{T}, features::AbstractVector{S}, labels) where {S, T} =
collect(leaf.values ./ leaf.total)
leaf.values ./ leaf.total

function apply_tree_proba(tree::Node{S, T}, features::AbstractVector{S}, labels) where {S, T}
if tree.featval === nothing
Expand All @@ -192,8 +192,13 @@ function apply_tree_proba(tree::Node{S, T}, features::AbstractVector{S}, labels)
end
end

apply_tree_proba(tree::LeafOrNode{S, T}, features::AbstractMatrix{S}, labels) where {S, T} =
stack_function_results(row->apply_tree_proba(tree, row, labels), features)
function apply_tree_proba(tree::LeafOrNode{S, T}, features::AbstractMatrix{S}, labels) where {S, T}
predictions = Vector{NTuple{length(labels), Float64}}(undef, size(features, 1))
for i in 1:size(features, 1)
predictions[i] = apply_tree_proba(tree, view(features, i, :), labels)
end
reinterpret(reshape, Float64, predictions) |> transpose |> Matrix
end

function build_forest(
labels :: AbstractVector{T},
Expand Down

0 comments on commit 908701f

Please sign in to comment.