Skip to content

Commit

Permalink
[CORE][VL] Avoid re-exploring explored nodes in DpPlanner (apache#5363)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhztheplayer committed Apr 11, 2024
1 parent 1601f26 commit bc99c49
Show file tree
Hide file tree
Showing 14 changed files with 184 additions and 68 deletions.
14 changes: 8 additions & 6 deletions gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class Ras[T <: AnyRef] private (

private[ras] def isInfCost(cost: Cost) = costModel.costComparator().equiv(cost, infCost)

private[ras] def toUnsafeKey(node: T): UnsafeKey[T] = UnsafeKey(this, node)
private[ras] def toHashKey(node: T): UnsafeHashKey[T] = UnsafeHashKey(this, node)
}

object Ras {
Expand Down Expand Up @@ -251,15 +251,17 @@ object Ras {
}
}

trait UnsafeKey[T]
trait UnsafeHashKey[T]

private object UnsafeKey {
def apply[T <: AnyRef](ras: Ras[T], self: T): UnsafeKey[T] = new UnsafeKeyImpl(ras, self)
private class UnsafeKeyImpl[T <: AnyRef](ras: Ras[T], val self: T) extends UnsafeKey[T] {
private object UnsafeHashKey {
def apply[T <: AnyRef](ras: Ras[T], self: T): UnsafeHashKey[T] =
new UnsafeHashKeyImpl(ras, self)
private class UnsafeHashKeyImpl[T <: AnyRef](ras: Ras[T], val self: T)
extends UnsafeHashKey[T] {
override def hashCode(): Int = ras.planModel.hashCode(self)
override def equals(other: Any): Boolean = {
other match {
case that: UnsafeKeyImpl[T] => ras.planModel.equals(self, that.self)
case that: UnsafeHashKeyImpl[T] => ras.planModel.equals(self, that.self)
case _ => false
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package org.apache.gluten.ras

import org.apache.gluten.ras.Ras.UnsafeKey
import org.apache.gluten.ras.Ras.UnsafeHashKey
import org.apache.gluten.ras.memo.MemoTable
import org.apache.gluten.ras.property.PropertySet

Expand Down Expand Up @@ -55,16 +55,16 @@ object RasCluster {
override val ras: Ras[T],
metadata: Metadata)
extends MutableRasCluster[T] {
private val deDup: mutable.Set[UnsafeKey[T]] = mutable.Set()
private val deDup: mutable.Set[UnsafeHashKey[T]] = mutable.Set()
private val buffer: mutable.ListBuffer[CanonicalNode[T]] =
mutable.ListBuffer()

override def contains(t: CanonicalNode[T]): Boolean = {
deDup.contains(t.toUnsafeKey())
deDup.contains(t.toHashKey())
}

override def add(t: CanonicalNode[T]): Unit = {
val key = t.toUnsafeKey()
val key = t.toHashKey()
assert(!deDup.contains(key))
ras.metadataModel.verify(metadata, ras.metadataModel.metadataOf(t.self()))
deDup += key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package org.apache.gluten.ras

import org.apache.gluten.ras.Ras.UnsafeKey
import org.apache.gluten.ras.Ras.UnsafeHashKey
import org.apache.gluten.ras.property.PropertySet

trait RasNode[T <: AnyRef] {
Expand All @@ -43,7 +43,7 @@ object RasNode {
node.asInstanceOf[GroupNode[T]]
}

def toUnsafeKey(): UnsafeKey[T] = node.ras().toUnsafeKey(node.self())
def toHashKey(): UnsafeHashKey[T] = node.ras().toHashKey(node.self())
}
}

Expand Down Expand Up @@ -131,16 +131,16 @@ object InGroupNode {
private case class InGroupNodeImpl[T <: AnyRef](groupId: Int, can: CanonicalNode[T])
extends InGroupNode[T]

trait HashKey extends Any
trait UniqueKey extends Any

implicit class InGroupNodeImplicits[T <: AnyRef](n: InGroupNode[T]) {
import InGroupNodeImplicits._
def toHashKey: HashKey =
InGroupNodeHashKeyImpl(n.groupId, System.identityHashCode(n.can))
def toUniqueKey: UniqueKey =
InGroupNodeUniqueKeyImpl(n.groupId, System.identityHashCode(n.can))
}

private object InGroupNodeImplicits {
private case class InGroupNodeHashKeyImpl(gid: Int, cid: Int) extends HashKey
private case class InGroupNodeUniqueKeyImpl(gid: Int, cid: Int) extends UniqueKey
}
}

Expand All @@ -159,15 +159,16 @@ object InClusterNode {
can: CanonicalNode[T])
extends InClusterNode[T]

trait HashKey extends Any
trait UniqueKey extends Any

implicit class InClusterNodeImplicits[T <: AnyRef](n: InClusterNode[T]) {
import InClusterNodeImplicits._
def toHashKey: HashKey =
InClusterNodeHashKeyImpl(n.clusterKey, System.identityHashCode(n.can))
def toUniqueKey: UniqueKey =
InClusterNodeUniqueKeyImpl(n.clusterKey, System.identityHashCode(n.can))
}

private object InClusterNodeImplicits {
private case class InClusterNodeHashKeyImpl(clusterKey: RasClusterKey, cid: Int) extends HashKey
private case class InClusterNodeUniqueKeyImpl(clusterKey: RasClusterKey, cid: Int)
extends UniqueKey
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ object Best {
bestPath: KnownCostPath[T],
winnerNodes: Seq[InGroupNode[T]],
costs: InGroupNode[T] => Option[Cost]): Best[T] = {
val bestNodes = mutable.Set[InGroupNode.HashKey]()
val bestNodes = mutable.Set[InGroupNode.UniqueKey]()

def dfs(groupId: Int, cursor: RasPath.PathNode[T]): Unit = {
val can = cursor.self().asCanonical()
bestNodes += InGroupNode(groupId, can).toHashKey
bestNodes += InGroupNode(groupId, can).toUniqueKey
cursor.zipChildrenWithGroupIds().foreach {
case (childPathNode, childGroupId) =>
dfs(childGroupId, childPathNode)
Expand All @@ -76,14 +76,14 @@ object Best {
dfs(rootGroupId, bestPath.rasPath.node())

val bestNodeSet = bestNodes.toSet
val winnerNodeSet = winnerNodes.map(_.toHashKey).toSet
val winnerNodeSet = winnerNodes.map(_.toUniqueKey).toSet

BestImpl(
ras,
rootGroupId,
bestPath,
n => bestNodeSet.contains(n.toHashKey),
n => winnerNodeSet.contains(n.toHashKey),
n => bestNodeSet.contains(n.toUniqueKey),
n => winnerNodeSet.contains(n.toUniqueKey),
costs)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,17 @@ object BestFinder {

val bestPath = groupToCosts(group.id()).best()
val winnerNodes = groupToCosts.map { case (id, g) => InGroupNode(id, g.bestNode) }.toSeq
val costsMap = mutable.Map[InGroupNode.HashKey, Cost]()
val costsMap = mutable.Map[InGroupNode.UniqueKey, Cost]()
groupToCosts.foreach {
case (gid, g) =>
g.nodes.foreach {
n =>
val c = g.nodeToCost(n)
if (c.nonEmpty) {
costsMap += (InGroupNode(gid, n).toHashKey -> c.get.cost)
costsMap += (InGroupNode(gid, n).toUniqueKey -> c.get.cost)
}
}
}
Best(ras, group.id(), bestPath, winnerNodes, ign => costsMap.get(ign.toHashKey))
Best(ras, group.id(), bestPath, winnerNodes, ign => costsMap.get(ign.toUniqueKey))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ object DpClusterAlgo {
clusterAlgoDef: DpClusterAlgoDef[T, NodeOutput, ClusterOutput])
extends DpZipperAlgoDef[InClusterNode[T], RasClusterKey, NodeOutput, ClusterOutput] {
override def idOfX(x: InClusterNode[T]): Any = {
x.toHashKey
x.toUniqueKey
}

override def idOfY(y: RasClusterKey): Any = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ object DpGroupAlgo {
groupAlgoDef: DpGroupAlgoDef[T, NodeOutput, GroupOutput])
extends DpZipperAlgoDef[InGroupNode[T], RasGroup[T], NodeOutput, GroupOutput] {
override def idOfX(x: InGroupNode[T]): Any = {
x.toHashKey
x.toUniqueKey
}

override def idOfY(y: RasGroup[T]): Any = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.gluten.ras.Best.KnownCostPath
import org.apache.gluten.ras.best.BestFinder
import org.apache.gluten.ras.dp.DpZipperAlgo.Adjustment.Panel
import org.apache.gluten.ras.memo.{Memo, MemoTable}
import org.apache.gluten.ras.path.{InClusterPath, PathFinder, RasPath}
import org.apache.gluten.ras.path._
import org.apache.gluten.ras.property.PropertySet
import org.apache.gluten.ras.rule.{EnforcerRuleSet, RuleApplier, Shape}

Expand Down Expand Up @@ -99,10 +99,16 @@ object DpPlanner {
rules: Seq[RuleApplier[T]],
enforcerRuleSet: EnforcerRuleSet[T])
extends DpClusterAlgo.Adjustment[T] {
import ExploreAdjustment._

private val ruleShapes: Seq[Shape[T]] = rules.map(_.shape())

override def exploreChildX(
panel: Panel[InClusterNode[T], RasClusterKey],
x: InClusterNode[T]): Unit = {}
x: InClusterNode[T]): Unit = {
applyRulesOnNode(panel, x.clusterKey, x.can)
}

override def exploreChildY(
panel: Panel[InClusterNode[T], RasClusterKey],
y: RasClusterKey): Unit = {}
Expand All @@ -115,20 +121,24 @@ object DpPlanner {
cKey: RasClusterKey): Unit = {
memoTable.doExhaustively {
applyEnforcerRules(panel, cKey)
applyRules(panel, cKey)
}
}

private def applyRules(
private def applyRulesOnNode(
panel: Panel[InClusterNode[T], RasClusterKey],
cKey: RasClusterKey): Unit = {
cKey: RasClusterKey,
can: CanonicalNode[T]): Unit = {
if (rules.isEmpty) {
return
}
val dummyGroup = memoTable.getDummyGroup(cKey)
val shapes = rules.map(_.shape())
findPaths(GroupNode(ras, dummyGroup), shapes) {
path => rules.foreach(rule => applyRule(panel, cKey, rule, path))
findPaths(GroupNode(ras, dummyGroup), ruleShapes, List(new FromSingleNode[T](can))) {
path =>
val rootNode = path.node().self()
if (rootNode.isCanonical) {
assert(rootNode.asCanonical() eq can)
}
rules.foreach(rule => applyRule(panel, cKey, rule, path))
}
}

Expand All @@ -137,27 +147,34 @@ object DpPlanner {
cKey: RasClusterKey): Unit = {
val dummyGroup = memoTable.getDummyGroup(cKey)
cKey.propSets(memoTable).foreach {
constraintSet =>
constraintSet: PropertySet[T] =>
val enforcerRules = enforcerRuleSet.rulesOf(constraintSet)
if (enforcerRules.nonEmpty) {
val shapes = enforcerRules.map(_.shape())
findPaths(GroupNode(ras, dummyGroup), shapes) {
val shapes = enforcerRuleSet.ruleShapesOf(constraintSet)
findPaths(GroupNode(ras, dummyGroup), shapes, List.empty) {
path => enforcerRules.foreach(rule => applyRule(panel, cKey, rule, path))
}
}
}
}

private def findPaths(gn: GroupNode[T], shapes: Seq[Shape[T]])(
private def findPaths(gn: GroupNode[T], shapes: Seq[Shape[T]], filters: Seq[FilterWizard[T]])(
onFound: RasPath[T] => Unit): Unit = {
val finder = shapes
val finderBuilder = shapes
.foldLeft(
PathFinder
.builder(ras, memoTable)) {
case (builder, shape) =>
builder.output(shape.wizard())
}

val finder = filters
.foldLeft(finderBuilder) {
case (builder, filter) =>
builder.filter(filter)
}
.build()

finder.find(gn).foreach(path => onFound(path))
}

Expand Down Expand Up @@ -191,5 +208,22 @@ object DpPlanner {
}
}

private object ExploreAdjustment {}
private object ExploreAdjustment {
private class FromSingleNode[T <: AnyRef](from: CanonicalNode[T]) extends FilterWizard[T] {
override def omit(can: CanonicalNode[T]): FilterWizard.FilterAction[T] = {
if (can eq from) {
return FilterWizard.FilterAction.Continue(this)
}
FilterWizard.FilterAction.omit
}

override def omit(group: GroupNode[T]): FilterWizard.FilterAction[T] =
FilterWizard.FilterAction.Continue(this)

override def advance(offset: Int, count: Int): FilterWizard.FilterAdvanceAction[T] = {
// We only filter on nodes from the root group. So continue with a noop filter.
FilterWizard.FilterAdvanceAction.Continue(FilterWizards.none())
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package org.apache.gluten.ras.memo

import org.apache.gluten.ras._
import org.apache.gluten.ras.Ras.UnsafeKey
import org.apache.gluten.ras.Ras.UnsafeHashKey
import org.apache.gluten.ras.property.PropertySet
import org.apache.gluten.ras.vis.GraphvizVisualizer

Expand Down Expand Up @@ -236,11 +236,11 @@ object Memo {
private object MemoCacheKey {
def apply[T <: AnyRef](ras: Ras[T], self: T): MemoCacheKey[T] = {
assert(ras.isCanonical(self))
MemoCacheKey[T](ras.toUnsafeKey(self))
MemoCacheKey[T](ras.toHashKey(self))
}
}

private case class MemoCacheKey[T <: AnyRef] private (delegate: UnsafeKey[T])
private case class MemoCacheKey[T <: AnyRef] private (delegate: UnsafeHashKey[T])
}

trait MemoStore[T <: AnyRef] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,21 @@ object FilterWizards {
OmitCycles[T](CycleDetector[GroupNode[T]]((one, other) => one.groupId() == other.groupId()))
}

def none[T <: AnyRef](): FilterWizard[T] = {
None[T]()
}

private class None[T <: AnyRef] private () extends FilterWizard[T] {
override def omit(can: CanonicalNode[T]): FilterAction[T] = FilterAction.Continue(this)
override def omit(group: GroupNode[T]): FilterAction[T] = FilterAction.Continue(this)
override def advance(offset: Int, count: Int): FilterAdvanceAction[T] =
FilterAdvanceAction.Continue(this)
}

private object None {
def apply[T <: AnyRef](): None[T] = new None[T]()
}

// Cycle detection starts from the first visited group in the input path.
private class OmitCycles[T <: AnyRef] private (detector: CycleDetector[GroupNode[T]])
extends FilterWizard[T] {
Expand Down
Loading

0 comments on commit bc99c49

Please sign in to comment.