Skip to content

Commit

Permalink
Fix tree pruning with NTuples
Browse files Browse the repository at this point in the history
  • Loading branch information
tecosaur committed Jun 27, 2022
1 parent 1a452f4 commit 4857294
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 18 deletions.
6 changes: 3 additions & 3 deletions src/DecisionTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ struct Node{S, T, N}
right :: Union{Leaf{T, N}, Node{S, T, N}}
end

const LeafOrNode{S, T} = Union{Leaf{T}, Node{S, T}}
const LeafOrNode{S, T, N} = Union{Leaf{T, N}, Node{S, T, N}}

struct Ensemble{S, T}
trees :: Vector{LeafOrNode{S, T}}
struct Ensemble{S, T, N}
trees :: Vector{LeafOrNode{S, T, N}}
end

Leaf(features::NTuple{T, N}) where {T, N} =
Expand Down
28 changes: 13 additions & 15 deletions src/classification/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,27 +113,25 @@ function prune_tree(tree::LeafOrNode{S, T}, purity_thresh=1.0) where {S, T}
if purity_thresh >= 1.0
return tree
end
function _prune_run(tree::LeafOrNode{S, T}, purity_thresh::Real) where {S, T}
N = length(tree)
if N == 1 ## a Leaf
function _prune_run(tree::LeafOrNode{S, T, N}, purity_thresh::Real) where {S, T, N}
L = length(tree)
if L == 1 ## a Leaf
return tree
elseif N == 2 ## a stump
all_labels = [tree.left.values; tree.right.values]
majority = majority_vote(all_labels)
matches = findall(all_labels .== majority)
purity = length(matches) / length(all_labels)
elseif L == 2 ## a stump
combined = tree.left.values .+ tree.right.values
total = tree.left.total + tree.right.total
majority = argmax(combined)
purity = combined[majority] / total
if purity >= purity_thresh
features = Tuple(unique(all_labels))
featfreq = Tuple(sum(all_labels .== f) for f in features)
return Leaf{T}(features, argmax(featfreq),
featfreq, length(all_labels))
return Leaf{T, N}(tree.left.features, majority, combined, total)
else
return tree
end
else
return Node{S, T}(tree.featid, tree.featval,
_prune_run(tree.left, purity_thresh),
_prune_run(tree.right, purity_thresh))
return Node{S, T, N}(
tree.featid, tree.featval,
_prune_run(tree.left, purity_thresh),
_prune_run(tree.right, purity_thresh))
end
end
pruned = _prune_run(tree, purity_thresh)
Expand Down

0 comments on commit 4857294

Please sign in to comment.