Make prediction with probability free #180

Open · wants to merge 7 commits into dev
58 changes: 33 additions & 25 deletions src/DecisionTree.jl
@@ -26,47 +26,56 @@ export InfoNode, InfoLeaf, wrap
###########################
########## Types ##########

-struct Leaf{T}
-    majority :: T
-    values   :: Vector{T}
+struct Leaf{T, N}
+    classes  :: NTuple{N, T}
+    majority :: Int
+    values   :: NTuple{N, Int}
+    total    :: Int
 end

-struct Node{S, T}
+struct Node{S, T, N}
     featid  :: Int
     featval :: S
-    left    :: Union{Leaf{T}, Node{S, T}}
-    right   :: Union{Leaf{T}, Node{S, T}}
+    left    :: Union{Leaf{T, N}, Node{S, T, N}}
+    right   :: Union{Leaf{T, N}, Node{S, T, N}}
 end

-const LeafOrNode{S, T} = Union{Leaf{T}, Node{S, T}}
+const LeafOrNode{S, T, N} = Union{Leaf{T, N}, Node{S, T, N}}

-struct Root{S, T}
-    node    :: LeafOrNode{S, T}
+struct Root{S, T, N}
+    node    :: LeafOrNode{S, T, N}
     n_feat  :: Int
     featim  :: Vector{Float64} # impurity importance
 end

-struct Ensemble{S, T}
-    trees   :: Vector{LeafOrNode{S, T}}
+struct Ensemble{S, T, N}
+    trees   :: Vector{LeafOrNode{S, T, N}}
     n_feat  :: Int
     featim  :: Vector{Float64}
 end

+Leaf(features::NTuple{N, T}) where {T, N} =
+    Leaf(features, 0, ntuple(_ -> 0, N), 0)
+Leaf(features::NTuple{N, T}, frequencies::NTuple{N, Int}) where {T, N} =
+    Leaf(features, argmax(frequencies), frequencies, sum(frequencies))
+Leaf(features::Union{<:AbstractVector, <:Tuple},
+     frequencies::Union{<:AbstractVector{Int}, <:Tuple}) =
+    Leaf(Tuple(features), Tuple(frequencies))

is_leaf(l::Leaf) = true
is_leaf(n::Node) = false

_zero(::Type{String}) = ""
_zero(x::Any) = zero(x)
-convert(::Type{Node{S, T}}, lf::Leaf{T}) where {S, T} = Node(0, _zero(S), lf, Leaf(_zero(T), [_zero(T)]))
-convert(::Type{Root{S, T}}, node::LeafOrNode{S, T}) where {S, T} = Root{S, T}(node, 0, Float64[])
-convert(::Type{LeafOrNode{S, T}}, tree::Root{S, T}) where {S, T} = tree.node
-promote_rule(::Type{Node{S, T}}, ::Type{Leaf{T}}) where {S, T} = Node{S, T}
-promote_rule(::Type{Leaf{T}}, ::Type{Node{S, T}}) where {S, T} = Node{S, T}
-promote_rule(::Type{Root{S, T}}, ::Type{Leaf{T}}) where {S, T} = Root{S, T}
-promote_rule(::Type{Leaf{T}}, ::Type{Root{S, T}}) where {S, T} = Root{S, T}
-promote_rule(::Type{Root{S, T}}, ::Type{Node{S, T}}) where {S, T} = Root{S, T}
-promote_rule(::Type{Node{S, T}}, ::Type{Root{S, T}}) where {S, T} = Root{S, T}
+convert(::Type{Node{S, T, N}}, lf::Leaf{T, N}) where {S, T, N} = Node(0, _zero(S), lf, Leaf(lf.classes))
+convert(::Type{Root{S, T, N}}, node::LeafOrNode{S, T, N}) where {S, T, N} = Root{S, T, N}(node, 0, Float64[])
+convert(::Type{LeafOrNode{S, T, N}}, tree::Root{S, T, N}) where {S, T, N} = tree.node
+promote_rule(::Type{Node{S, T, N}}, ::Type{Leaf{T, N}}) where {S, T, N} = Node{S, T, N}
+promote_rule(::Type{Leaf{T, N}}, ::Type{Node{S, T, N}}) where {S, T, N} = Node{S, T, N}
+promote_rule(::Type{Root{S, T, N}}, ::Type{Leaf{T, N}}) where {S, T, N} = Root{S, T, N}
+promote_rule(::Type{Leaf{T, N}}, ::Type{Root{S, T, N}}) where {S, T, N} = Root{S, T, N}
+promote_rule(::Type{Root{S, T, N}}, ::Type{Node{S, T, N}}) where {S, T, N} = Root{S, T, N}
+promote_rule(::Type{Node{S, T, N}}, ::Type{Root{S, T, N}}) where {S, T, N} = Root{S, T, N}

# make a Random Number Generator object
mk_rng(rng::Random.AbstractRNG) = rng
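
Note: the heart of this change is the new Leaf layout above. A leaf now stores the tree's shared class list, per-class counts, and a cached total, so the majority label and the class probabilities can be read off without scanning a label vector. A standalone sketch (a mock type, not the package's) of how the fields interact:

    # Mock of the new Leaf layout; type and data are illustrative only.
    struct MockLeaf{T, N}
        classes  :: NTuple{N, T}    # class list shared across the tree
        majority :: Int             # index into `classes`, not a label
        values   :: NTuple{N, Int}  # per-class sample counts in this leaf
        total    :: Int             # cached sum(values)
    end

    leaf = MockLeaf(("setosa", "versicolor", "virginica"), 2, (1, 5, 2), 8)
    leaf.classes[leaf.majority]   # "versicolor", as apply_tree reads it
    leaf.values ./ leaf.total     # (0.125, 0.625, 0.25) — a tuple, no Vector allocated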
@@ -97,9 +106,8 @@ depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))
depth(tree::Root) = depth(tree.node)

function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
-    n_matches = count(leaf.values .== leaf.majority)
-    ratio = string(n_matches, "/", length(leaf.values))
-    println(io, "$(leaf.majority) : $(ratio)")
+    println(io, leaf.classes[leaf.majority], " : ",
+            leaf.values[leaf.majority], '/', leaf.total)
end
function print_tree(leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
return print_tree(stdout, leaf, depth, indent; sigdigits, feature_names)
Expand Down Expand Up @@ -162,8 +170,8 @@ end

function show(io::IO, leaf::Leaf)
println(io, "Decision Leaf")
println(io, "Majority: $(leaf.majority)")
print(io, "Samples: $(length(leaf.values))")
println(io, "Majority: ", leaf.classes[leaf.majority])
print(io, "Samples: ", leaf.total)
end

function show(io::IO, tree::Node)
93 changes: 54 additions & 39 deletions src/classification/main.jl
@@ -41,11 +41,14 @@ function _convert(
) where {S, T}

if node.is_leaf
-        return Leaf{T}(list[node.label], labels[node.region])
+        classfreq = Tuple(sum(labels[node.region] .== l) for l in list)
+        return Leaf{T, length(list)}(
+            Tuple(list), argmax(classfreq), classfreq, length(node.region))
     else
         left = _convert(node.l, list, labels)
         right = _convert(node.r, list, labels)
-        return Node{S, T}(node.feature, node.threshold, left, right)
+        return Node{S, T, length(list)}(
+            node.feature, node.threshold, left, right)
end
end
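
Note: `_convert` now builds the count tuple once per leaf instead of storing raw labels. A minimal standalone rehearsal of that computation, with made-up data:

    list   = ["a", "b", "c"]               # ordered class list for the tree
    region = ["b", "a", "b", "b", "c"]     # labels that fell into one leaf
    classfreq = Tuple(sum(region .== l) for l in list)  # (1, 3, 1)
    argmax(classfreq)                      # 2: class "b" is the majority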

Expand Down Expand Up @@ -114,6 +117,7 @@ function build_stump(
labels :: AbstractVector{T},
features :: AbstractMatrix{S},
weights = nothing;
+    n_classes :: Int = length(unique(labels)),
rng = Random.GLOBAL_RNG,
impurity_importance :: Bool = true) where {S, T}

Expand All @@ -130,7 +134,7 @@ function build_stump(
min_purity_increase = 0.0;
rng = rng)

-    return _build_tree(t, labels, size(features, 2), size(features, 1), impurity_importance)
+    return _build_tree(t, labels, n_classes, size(features, 2), size(features, 1), impurity_importance)
end

function build_tree(
@@ -141,6 +145,7 @@ function build_tree(
min_samples_leaf = 1,
min_samples_split = 2,
min_purity_increase = 0.0;
+    n_classes :: Int    = length(unique(labels)),
loss = util.entropy :: Function,
rng = Random.GLOBAL_RNG,
impurity_importance :: Bool = true) where {S, T}
@@ -165,23 +170,24 @@
min_purity_increase = Float64(min_purity_increase),
rng = rng)

-    return _build_tree(t, labels, size(features, 2), size(features, 1), impurity_importance)
+    return _build_tree(t, labels, n_classes, size(features, 2), size(features, 1), impurity_importance)
end

function _build_tree(
tree::treeclassifier.Tree{S, T},
labels::AbstractVector{T},
+    n_classes::Int,
n_features,
n_samples,
impurity_importance::Bool
) where {S, T}
node = _convert(tree.root, tree.list, labels[tree.labels])
if !impurity_importance
-        return Root{S, T}(node, n_features, Float64[])
+        return Root{S, T, n_classes}(node, n_features, Float64[])
else
fi = zeros(Float64, n_features)
update_using_impurity!(fi, tree.root)
-        return Root{S, T}(node, n_features, fi ./ n_samples)
+        return Root{S, T, n_classes}(node, n_features, fi ./ n_samples)
end
end

@@ -221,42 +227,43 @@ function prune_tree(
end
ntt = nsample(tree)
function _prune_run_stump(
-        tree::LeafOrNode{S, T},
+        tree::LeafOrNode{S, T, N},
         purity_thresh::Real,
         fi::Vector{Float64} = Float64[]
-    ) where {S, T}
-        all_labels = [tree.left.values; tree.right.values]
-        majority = majority_vote(all_labels)
-        matches = findall(all_labels .== majority)
-        purity = length(matches) / length(all_labels)
+    ) where {S, T, N}
+        combined = tree.left.values .+ tree.right.values
+        total    = tree.left.total + tree.right.total
+        majority = argmax(combined)
+        purity   = combined[majority] / total
if purity >= purity_thresh
if !isempty(fi)
update_pruned_impurity!(tree, fi, ntt, loss)
end
-            return Leaf{T}(majority, all_labels)
+            return Leaf{T, N}(tree.left.classes, majority, combined, total)
else
return tree
end
end
-    function _prune_run(tree::Root{S, T}, purity_thresh::Real) where {S, T}
+    function _prune_run(tree::Root{S, T, N}, purity_thresh::Real) where {S, T, N}
fi = deepcopy(tree.featim) ## recalculate feature importances
node = _prune_run(tree.node, purity_thresh, fi)
-        return Root{S, T}(node, tree.n_feat, fi)
+        return Root{S, T, N}(node, tree.n_feat, fi)
end
function _prune_run(
-        tree::LeafOrNode{S, T},
+        tree::LeafOrNode{S, T, N},
         purity_thresh::Real,
         fi::Vector{Float64} = Float64[]
-    ) where {S, T}
-        N = length(tree)
-        if N == 1        ## a Leaf
+    ) where {S, T, N}
+        L = length(tree)
+        if L == 1        ## a Leaf
             return tree
-        elseif N == 2    ## a stump
+        elseif L == 2    ## a stump
             return _prune_run_stump(tree, purity_thresh, fi)
         else
-            left = _prune_run(tree.left, purity_thresh, fi)
-            right = _prune_run(tree.right, purity_thresh, fi)
-            return Node{S, T}(tree.featid, tree.featval, left, right)
+            return Node{S, T, N}(
+                tree.featid, tree.featval,
+                _prune_run(tree.left, purity_thresh, fi),
+                _prune_run(tree.right, purity_thresh, fi))
end
end
pruned = _prune_run(tree, purity_thresh)
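
Note: pruning a stump no longer concatenates and rescans label vectors; it adds the two leaves' count tuples and reads the purity directly. The same arithmetic, rehearsed standalone on made-up counts:

    left_values, right_values = (4, 1, 0), (3, 2, 0)   # counts from two sibling leaves
    combined = left_values .+ right_values             # (7, 3, 0)
    total    = sum(combined)                           # 10
    combined[argmax(combined)] / total                 # purity = 0.7; prune if >= purity_thresh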
@@ -268,7 +275,7 @@
end


-apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.majority
+apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.classes[leaf.majority]
apply_tree(
tree::Root{S, T},
features::AbstractVector{S}
@@ -314,7 +321,7 @@ of the output matrix.
apply_tree_proba(tree::Root{S, T}, features::AbstractVector{S}, labels) where {S, T} =
apply_tree_proba(tree.node, features, labels)
apply_tree_proba(leaf::Leaf{T}, features::AbstractVector{S}, labels) where {S, T} =
-    compute_probabilities(labels, leaf.values)
+    leaf.values ./ leaf.total

function apply_tree_proba(
tree::Node{S, T},
Expand All @@ -329,10 +336,13 @@ function apply_tree_proba(
return apply_tree_proba(tree.right, features, labels)
end
end
-apply_tree_proba(tree::Root{S, T}, features::AbstractMatrix{S}, labels) where {S, T} =
-    apply_tree_proba(tree.node, features, labels)
-apply_tree_proba(tree::LeafOrNode{S, T}, features::AbstractMatrix{S}, labels) where {S, T} =
-    stack_function_results(row->apply_tree_proba(tree, row, labels), features)
+function apply_tree_proba(tree::Root{S, T}, features::AbstractMatrix{S}, labels) where {S, T}
+    predictions = Vector{NTuple{length(labels), Float64}}(undef, size(features, 1))
+    for i in 1:size(features, 1)
+        predictions[i] = apply_tree_proba(tree, view(features, i, :), labels)
+    end
+    reinterpret(reshape, Float64, predictions) |> transpose |> Matrix
+end
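
Note: the matrix method above collects one NTuple of probabilities per row, then flattens them with `reinterpret(reshape, ...)` (available since Julia 1.6). A standalone illustration of that reshaping step, with fabricated predictions:

    predictions = [(0.8, 0.2), (0.1, 0.9), (0.5, 0.5)]   # one tuple per sample
    reinterpret(reshape, Float64, predictions) |> transpose |> Matrix
    # 3×2 Matrix{Float64}: rows are samples, columns follow the class order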

function build_forest(
labels :: AbstractVector{T},
@@ -361,10 +371,11 @@ function build_forest(

t_samples = length(labels)
n_samples = floor(Int, partial_sampling * t_samples)
+    n_classes = length(unique(labels))

forest = impurity_importance ?
-        Vector{Root{S, T}}(undef, n_trees) :
-        Vector{LeafOrNode{S, T}}(undef, n_trees)
+        Vector{Root{S, T, n_classes}}(undef, n_trees) :
+        Vector{LeafOrNode{S, T, n_classes}}(undef, n_trees)

entropy_terms = util.compute_entropy_terms(n_samples)
loss = (ns, n) -> util.entropy(ns, n, entropy_terms)
@@ -382,7 +393,8 @@
max_depth,
min_samples_leaf,
min_samples_split,
-            min_purity_increase,
+            min_purity_increase;
+            n_classes,
loss = loss,
rng = _rng,
impurity_importance = impurity_importance)
@@ -398,7 +410,8 @@
max_depth,
min_samples_leaf,
min_samples_split,
-            min_purity_increase,
+            min_purity_increase;
+            n_classes,
loss = loss,
impurity_importance = impurity_importance)
end
@@ -408,13 +421,13 @@
end

function _build_forest(
-    forest              :: Vector{<: Union{Root{S, T}, LeafOrNode{S, T}}},
+    forest              :: Vector{<: Union{Root{S, T, N}, LeafOrNode{S, T, N}}},
n_features ,
n_trees ,
-    impurity_importance :: Bool) where {S, T}
+    impurity_importance :: Bool) where {S, T, N}

if !impurity_importance
-        return Ensemble{S, T}(forest, n_features, Float64[])
+        return Ensemble{S, T, N}(forest, n_features, Float64[])
else
fi = zeros(Float64, n_features)
for tree in forest
@@ -424,12 +437,12 @@
end
end

-        forest_new = Vector{LeafOrNode{S, T}}(undef, n_trees)
+        forest_new = Vector{LeafOrNode{S, T, N}}(undef, n_trees)
Threads.@threads for i in 1:n_trees
forest_new[i] = forest[i].node
end

-        return Ensemble{S, T}(forest_new, n_features, fi ./ n_trees)
+        return Ensemble{S, T, N}(forest_new, n_features, fi ./ n_trees)
end
end

@@ -506,11 +519,13 @@ function build_adaboost_stumps(
stumps = Node{S, T}[]
coeffs = Float64[]
n_features = size(features, 2)
+    n_classes = length(unique(labels))
for i in 1:n_iterations
new_stump = build_stump(
labels,
features,
weights;
+            n_classes,
rng=mk_rng(rng),
impurity_importance=false
)
Expand All @@ -530,7 +545,7 @@ function build_adaboost_stumps(
break
end
end
-    return (Ensemble{S, T}(stumps, n_features, Float64[]), coeffs)
+    return (Ensemble{S, T, n_classes}(stumps, n_features, Float64[]), coeffs)
end

apply_adaboost_stumps(
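
Note: with this branch, callers can pin the number of classes up front, which matters when a subsample's labels might not cover every class; by default it is inferred as length(unique(labels)). A hypothetical call site, assuming this PR's `n_classes` keyword and mirroring the package's README workflow:

    using DecisionTree

    features, labels = load_data("iris")
    features = float.(features)
    labels   = string.(labels)

    tree = build_tree(labels, features; n_classes = 3)   # keyword added in this PR
    apply_tree_proba(tree, features, sort(unique(labels)))  # n×3 probability matrix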