From 4857294f2ed6c97e62a467ba96a7045805876510 Mon Sep 17 00:00:00 2001 From: TEC Date: Mon, 27 Jun 2022 16:21:32 +0800 Subject: [PATCH] Fix tree pruning with NTuples --- src/DecisionTree.jl | 6 +++--- src/classification/main.jl | 28 +++++++++++++--------------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl index b71841b0..7eb059e3 100644 --- a/src/DecisionTree.jl +++ b/src/DecisionTree.jl @@ -42,10 +42,10 @@ struct Node{S, T, N} right :: Union{Leaf{T, N}, Node{S, T, N}} end -const LeafOrNode{S, T} = Union{Leaf{T}, Node{S, T}} +const LeafOrNode{S, T, N} = Union{Leaf{T, N}, Node{S, T, N}} -struct Ensemble{S, T} - trees :: Vector{LeafOrNode{S, T}} +struct Ensemble{S, T, N} + trees :: Vector{LeafOrNode{S, T, N}} end Leaf(features::NTuple{T, N}) where {T, N} = diff --git a/src/classification/main.jl b/src/classification/main.jl index 18e283bc..72f12ec2 100644 --- a/src/classification/main.jl +++ b/src/classification/main.jl @@ -113,27 +113,25 @@ function prune_tree(tree::LeafOrNode{S, T}, purity_thresh=1.0) where {S, T} if purity_thresh >= 1.0 return tree end - function _prune_run(tree::LeafOrNode{S, T}, purity_thresh::Real) where {S, T} - N = length(tree) - if N == 1 ## a Leaf + function _prune_run(tree::LeafOrNode{S, T, N}, purity_thresh::Real) where {S, T, N} + L = length(tree) + if L == 1 ## a Leaf return tree - elseif N == 2 ## a stump - all_labels = [tree.left.values; tree.right.values] - majority = majority_vote(all_labels) - matches = findall(all_labels .== majority) - purity = length(matches) / length(all_labels) + elseif L == 2 ## a stump + combined = tree.left.values .+ tree.right.values + total = tree.left.total + tree.right.total + majority = argmax(combined) + purity = combined[majority] / total if purity >= purity_thresh - features = Tuple(unique(all_labels)) - featfreq = Tuple(sum(all_labels .== f) for f in features) - return Leaf{T}(features, argmax(featfreq), - featfreq, length(all_labels)) + return Leaf{T, N}(tree.left.features, majority, combined, total) else return tree end else - return Node{S, T}(tree.featid, tree.featval, - _prune_run(tree.left, purity_thresh), - _prune_run(tree.right, purity_thresh)) + return Node{S, T, N}( + tree.featid, tree.featval, + _prune_run(tree.left, purity_thresh), + _prune_run(tree.right, purity_thresh)) end end pruned = _prune_run(tree, purity_thresh)