Skip to content

Commit

Permalink
Fix more test results
Browse files Browse the repository at this point in the history
  • Loading branch information
tecosaur committed Dec 2, 2022
1 parent 4f94a27 commit 700afd2
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 38 deletions.
15 changes: 10 additions & 5 deletions src/DecisionTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export InfoNode, InfoLeaf, wrap
########## Types ##########

struct Leaf{T, N}
features :: NTuple{N, T}
classes :: NTuple{N, T}
majority :: Int
values :: NTuple{N, Int}
total :: Int
Expand All @@ -54,15 +54,20 @@ struct Ensemble{S, T, N}
featim :: Vector{Float64}
end

Leaf(features::NTuple{T, N}) where {T, N} =
Leaf(features::NTuple{N, T}) where {T, N} =
Leaf(features, 0, Tuple(zeros(T, N)), 0)
Leaf(features::NTuple{N, T}, frequencies::NTuple{N, Int}) where {T, N} =
Leaf(features, argmax(frequencies), frequencies, sum(frequencies))
Leaf(features::Union{<:AbstractVector, <:Tuple},
frequencies::Union{<:AbstractVector{Int}, <:Tuple}) =
Leaf(Tuple(features), Tuple(frequencies))

is_leaf(l::Leaf) = true
is_leaf(n::Node) = false

_zero(::Type{String}) = ""
_zero(x::Any) = zero(x)
convert(::Type{Node{S, T, N}}, lf::Leaf{T, N}) where {S, T, N} = Node(0, _zero(S), lf, Leaf(lf.features))
convert(::Type{Node{S, T, N}}, lf::Leaf{T, N}) where {S, T, N} = Node(0, _zero(S), lf, Leaf(lf.classes))
convert(::Type{Root{S, T, N}}, node::LeafOrNode{S, T, N}) where {S, T, N} = Root{S, T, N}(node, 0, Float64[])
convert(::Type{LeafOrNode{S, T, N}}, tree::Root{S, T, N}) where {S, T, N} = tree.node
promote_rule(::Type{Node{S, T, N}}, ::Type{Leaf{T, N}}) where {S, T, N} = Node{S, T, N}
Expand Down Expand Up @@ -101,7 +106,7 @@ depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))
depth(tree::Root) = depth(tree.node)

function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
println(io, leaf.features[leaf.majority], " : ",
println(io, leaf.classes[leaf.majority], " : ",
leaf.values[leaf.majority], '/', leaf.total)
end
function print_tree(leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
Expand Down Expand Up @@ -165,7 +170,7 @@ end

function show(io::IO, leaf::Leaf)
println(io, "Decision Leaf")
println(io, "Majority: ", leaf.features[leaf.majority])
println(io, "Majority: ", leaf.classes[leaf.majority])
print(io, "Samples: ", leaf.total)
end

Expand Down
43 changes: 25 additions & 18 deletions src/classification/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ function _convert(
) where {S, T}

if node.is_leaf
featfreq = Tuple(sum(labels[node.region] .== l) for l in list)
classfreq = Tuple(sum(labels[node.region] .== l) for l in list)
return Leaf{T, length(list)}(
Tuple(list), argmax(featfreq), featfreq, length(node.region))
Tuple(list), argmax(classfreq), classfreq, length(node.region))
else
left = _convert(node.l, list, labels)
right = _convert(node.r, list, labels)
Expand Down Expand Up @@ -117,6 +117,7 @@ function build_stump(
labels :: AbstractVector{T},
features :: AbstractMatrix{S},
weights = nothing;
n_classes :: Int = length(unique(labels)),
rng = Random.GLOBAL_RNG,
impurity_importance :: Bool = true) where {S, T}

Expand All @@ -133,7 +134,7 @@ function build_stump(
min_purity_increase = 0.0;
rng = rng)

return _build_tree(t, labels, size(features, 2), size(features, 1), impurity_importance)
return _build_tree(t, labels, n_classes, size(features, 2), size(features, 1), impurity_importance)
end

function build_tree(
Expand All @@ -144,6 +145,7 @@ function build_tree(
min_samples_leaf = 1,
min_samples_split = 2,
min_purity_increase = 0.0;
n_classes :: Int = length(unique(labels)),
loss = util.entropy :: Function,
rng = Random.GLOBAL_RNG,
impurity_importance :: Bool = true) where {S, T}
Expand All @@ -168,18 +170,18 @@ function build_tree(
min_purity_increase = Float64(min_purity_increase),
rng = rng)

return _build_tree(t, labels, size(features, 2), size(features, 1), impurity_importance)
return _build_tree(t, labels, n_classes, size(features, 2), size(features, 1), impurity_importance)
end

function _build_tree(
tree::treeclassifier.Tree{S, T},
labels::AbstractVector{T},
n_classes::Int,
n_features,
n_samples,
impurity_importance::Bool
) where {S, T}
node = _convert(tree.root, tree.list, labels[tree.labels])
n_classes = unique(labels) |> length
if !impurity_importance
return Root{S, T, n_classes}(node, n_features, Float64[])
else
Expand Down Expand Up @@ -237,15 +239,15 @@ function prune_tree(
if !isempty(fi)
update_pruned_impurity!(tree, fi, ntt, loss)
end
return Leaf{T, N}(tree.left.features, majority, combined, total)
return Leaf{T, N}(tree.left.classes, majority, combined, total)
else
return tree
end
end
function _prune_run(tree::Root{S, T, N}, purity_thresh::Real) where {S, T, N}
fi = deepcopy(tree.featim) ## recalculate feature importances
node = _prune_run(tree.node, purity_thresh, fi)
return Root{S, T, N}(node, fi)
return Root{S, T, N}(node, tree.n_feat, fi)
end
function _prune_run(
tree::LeafOrNode{S, T, N},
Expand Down Expand Up @@ -273,7 +275,7 @@ function prune_tree(
end


apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.features[leaf.majority]
apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.classes[leaf.majority]
apply_tree(
tree::Root{S, T},
features::AbstractVector{S}
Expand Down Expand Up @@ -369,10 +371,11 @@ function build_forest(

t_samples = length(labels)
n_samples = floor(Int, partial_sampling * t_samples)
n_classes = length(unique(labels))

forest = impurity_importance ?
Vector{Root{S, T}}(undef, n_trees) :
Vector{LeafOrNode{S, T}}(undef, n_trees)
Vector{Root{S, T, n_classes}}(undef, n_trees) :
Vector{LeafOrNode{S, T, n_classes}}(undef, n_trees)

entropy_terms = util.compute_entropy_terms(n_samples)
loss = (ns, n) -> util.entropy(ns, n, entropy_terms)
Expand All @@ -390,7 +393,8 @@ function build_forest(
max_depth,
min_samples_leaf,
min_samples_split,
min_purity_increase,
min_purity_increase;
n_classes,
loss = loss,
rng = _rng,
impurity_importance = impurity_importance)
Expand All @@ -406,7 +410,8 @@ function build_forest(
max_depth,
min_samples_leaf,
min_samples_split,
min_purity_increase,
min_purity_increase;
n_classes,
loss = loss,
impurity_importance = impurity_importance)
end
Expand All @@ -416,13 +421,13 @@ function build_forest(
end

function _build_forest(
forest :: Vector{<: Union{Root{S, T}, LeafOrNode{S, T}}},
forest :: Vector{<: Union{Root{S, T, N}, LeafOrNode{S, T, N}}},
n_features ,
n_trees ,
impurity_importance :: Bool) where {S, T}
impurity_importance :: Bool) where {S, T, N}

if !impurity_importance
return Ensemble{S, T}(forest, n_features, Float64[])
return Ensemble{S, T, N}(forest, n_features, Float64[])
else
fi = zeros(Float64, n_features)
for tree in forest
Expand All @@ -432,12 +437,12 @@ function _build_forest(
end
end

forest_new = Vector{LeafOrNode{S, T}}(undef, n_trees)
forest_new = Vector{LeafOrNode{S, T, N}}(undef, n_trees)
Threads.@threads for i in 1:n_trees
forest_new[i] = forest[i].node
end

return Ensemble{S, T}(forest_new, n_features, fi ./ n_trees)
return Ensemble{S, T, N}(forest_new, n_features, fi ./ n_trees)
end
end

Expand Down Expand Up @@ -514,11 +519,13 @@ function build_adaboost_stumps(
stumps = Node{S, T}[]
coeffs = Float64[]
n_features = size(features, 2)
n_classes = length(unique(labels))
for i in 1:n_iterations
new_stump = build_stump(
labels,
features,
weights;
n_classes,
rng=mk_rng(rng),
impurity_importance=false
)
Expand All @@ -538,7 +545,7 @@ function build_adaboost_stumps(
break
end
end
return (Ensemble{S, T}(stumps, n_features, Float64[]), coeffs)
return (Ensemble{S, T, n_classes}(stumps, n_features, Float64[]), coeffs)
end

apply_adaboost_stumps(
Expand Down
22 changes: 13 additions & 9 deletions src/regression/main.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
include("tree.jl")

function _convert(node::treeregressor.NodeMeta{S}, labels::Array{T}) where {S, T <: Float64}
classes = Tuple(unique(labels))
if node.is_leaf
features = Tuple(unique(labels))
featfreq = Tuple(sum(labels[node.region] .== f) for f in features)
return Leaf{T, length(features)}(
features, argmax(featfreq), featfreq, length(node.region))
classfreq = Tuple(sum(labels[node.region] .== f) for f in classes)
return Leaf{T, length(classes)}(
classes, argmax(classfreq), classfreq, length(node.region))
else
left = _convert(node.l, labels)
right = _convert(node.r, labels)
return Node{S, T}(node.feature, node.threshold, left, right)
return Node{S, T, length(classes)}(node.feature, node.threshold, left, right)
end
end

Expand All @@ -34,6 +34,7 @@ function build_tree(
min_samples_leaf = 5,
min_samples_split = 2,
min_purity_increase = 0.0;
n_classes :: Int = length(unique(labels)),
rng = Random.GLOBAL_RNG,
impurity_importance:: Bool = true) where {S, T <: Float64}

Expand All @@ -59,11 +60,11 @@ function build_tree(
node = _convert(t.root, labels[t.labels])
n_features = size(features, 2)
if !impurity_importance
return Root{S, T}(node, n_features, Float64[])
return Root{S, T, n_classes}(node, n_features, Float64[])
else
fi = zeros(Float64, n_features)
update_using_impurity!(fi, t.root)
return Root{S, T}(node, n_features, fi ./ size(features, 1))
return Root{S, T, n_classes}(node, n_features, fi ./ size(features, 1))
end
end

Expand All @@ -77,6 +78,7 @@ function build_forest(
min_samples_leaf = 5,
min_samples_split = 2,
min_purity_increase = 0.0;
n_classes :: Int = length(unique(labels)),
rng::Union{Integer,AbstractRNG} = Random.GLOBAL_RNG,
impurity_importance :: Bool = true) where {S, T <: Float64}

Expand Down Expand Up @@ -110,7 +112,8 @@ function build_forest(
max_depth,
min_samples_leaf,
min_samples_split,
min_purity_increase,
min_purity_increase;
n_classes,
rng = _rng,
impurity_importance = impurity_importance)
end
Expand All @@ -125,7 +128,8 @@ function build_forest(
max_depth,
min_samples_leaf,
min_samples_split,
min_purity_increase,
min_purity_increase;
n_classes,
impurity_importance = impurity_importance)
end
end
Expand Down
8 changes: 4 additions & 4 deletions test/miscellaneous/abstract_trees_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ clabel_pattern(clabel) = "─ " * clabel * " (" # class labels are embedde
check_occurence(str_tree, pool, pattern) = count(map(elem -> occursin(pattern(elem), str_tree), pool)) == length(pool)

@info("Test base functionality")
l1 = Leaf(1, [1,1,2])
l2 = Leaf(2, [1,2,2])
l3 = Leaf(3, [3,3,1])
l1 = Leaf((1,2,3), 1, (2, 1, 0), 3)
l2 = Leaf((1,2,3), 2, (1, 2, 0), 3)
l3 = Leaf((1,2,3), 3, (1, 0, 2), 3)
n2 = Node(2, 0.5, l2, l3)
n1 = Node(1, 0.7, l1, n2)
feature_names = ["firstFt", "secondFt"]
Expand Down Expand Up @@ -81,4 +81,4 @@ end
traverse_tree(leaf::InfoLeaf) = nothing

traverse_tree(wrapped_tree)
end
end
4 changes: 2 additions & 2 deletions test/miscellaneous/convert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

@testset "convert.jl" begin

lf = Leaf(1, [1])
lf = Leaf([1], [1])
nv = Node{Int, Int}[]
rv = Root{Int, Int}[]
push!(nv, lf)
Expand All @@ -22,7 +22,7 @@ push!(rv, nv[1])
@test apply_tree(rv[1], [0]) == 1.0
@test apply_tree(rv[2], [0]) == 1.0

lf = Leaf("A", ["B", "A"])
lf = Leaf(["A", "B"], [2, 1])
nv = Node{Int, String}[]
rv = Root{Int, String}[]
push!(nv, lf)
Expand Down

0 comments on commit 700afd2

Please sign in to comment.