From bc99c4910e26648195fec789c090b06c2b4379e2 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Thu, 11 Apr 2024 15:53:26 +0800 Subject: [PATCH] [CORE][VL] Avoid re-exploring explored nodes in DpPlanner (#5363) --- .../scala/org/apache/gluten/ras/Ras.scala | 14 +++-- .../org/apache/gluten/ras/RasCluster.scala | 8 +-- .../scala/org/apache/gluten/ras/RasNode.scala | 21 ++++--- .../org/apache/gluten/ras/RasPlanner.scala | 10 +-- .../apache/gluten/ras/best/BestFinder.scala | 6 +- .../apache/gluten/ras/dp/DpClusterAlgo.scala | 2 +- .../apache/gluten/ras/dp/DpGroupAlgo.scala | 2 +- .../org/apache/gluten/ras/dp/DpPlanner.scala | 62 ++++++++++++++----- .../org/apache/gluten/ras/memo/Memo.scala | 6 +- .../apache/gluten/ras/path/OutputFilter.scala | 15 +++++ .../apache/gluten/ras/rule/EnforcerRule.scala | 39 +++++++----- .../apache/gluten/ras/rule/RuleApplier.scala | 10 +-- .../gluten/ras/vis/GraphvizVisualizer.scala | 4 +- .../org/apache/gluten/ras/RasSuite.scala | 53 ++++++++++++++++ 14 files changed, 184 insertions(+), 68 deletions(-) diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala index f3d46847e6a0..804d04d814e5 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala @@ -171,7 +171,7 @@ class Ras[T <: AnyRef] private ( private[ras] def isInfCost(cost: Cost) = costModel.costComparator().equiv(cost, infCost) - private[ras] def toUnsafeKey(node: T): UnsafeKey[T] = UnsafeKey(this, node) + private[ras] def toHashKey(node: T): UnsafeHashKey[T] = UnsafeHashKey(this, node) } object Ras { @@ -251,15 +251,17 @@ object Ras { } } - trait UnsafeKey[T] + trait UnsafeHashKey[T] - private object UnsafeKey { - def apply[T <: AnyRef](ras: Ras[T], self: T): UnsafeKey[T] = new UnsafeKeyImpl(ras, self) - private class UnsafeKeyImpl[T <: AnyRef](ras: Ras[T], val self: T) extends UnsafeKey[T] { + private 
object UnsafeHashKey { + def apply[T <: AnyRef](ras: Ras[T], self: T): UnsafeHashKey[T] = + new UnsafeHashKeyImpl(ras, self) + private class UnsafeHashKeyImpl[T <: AnyRef](ras: Ras[T], val self: T) + extends UnsafeHashKey[T] { override def hashCode(): Int = ras.planModel.hashCode(self) override def equals(other: Any): Boolean = { other match { - case that: UnsafeKeyImpl[T] => ras.planModel.equals(self, that.self) + case that: UnsafeHashKeyImpl[T] => ras.planModel.equals(self, that.self) case _ => false } } diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala index 1b30e1242c82..eb2b41a91fa0 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala @@ -16,7 +16,7 @@ */ package org.apache.gluten.ras -import org.apache.gluten.ras.Ras.UnsafeKey +import org.apache.gluten.ras.Ras.UnsafeHashKey import org.apache.gluten.ras.memo.MemoTable import org.apache.gluten.ras.property.PropertySet @@ -55,16 +55,16 @@ object RasCluster { override val ras: Ras[T], metadata: Metadata) extends MutableRasCluster[T] { - private val deDup: mutable.Set[UnsafeKey[T]] = mutable.Set() + private val deDup: mutable.Set[UnsafeHashKey[T]] = mutable.Set() private val buffer: mutable.ListBuffer[CanonicalNode[T]] = mutable.ListBuffer() override def contains(t: CanonicalNode[T]): Boolean = { - deDup.contains(t.toUnsafeKey()) + deDup.contains(t.toHashKey()) } override def add(t: CanonicalNode[T]): Unit = { - val key = t.toUnsafeKey() + val key = t.toHashKey() assert(!deDup.contains(key)) ras.metadataModel.verify(metadata, ras.metadataModel.metadataOf(t.self())) deDup += key diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala index 65ff8b735e18..710a4e682293 100644 --- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala @@ -16,7 +16,7 @@ */ package org.apache.gluten.ras -import org.apache.gluten.ras.Ras.UnsafeKey +import org.apache.gluten.ras.Ras.UnsafeHashKey import org.apache.gluten.ras.property.PropertySet trait RasNode[T <: AnyRef] { @@ -43,7 +43,7 @@ object RasNode { node.asInstanceOf[GroupNode[T]] } - def toUnsafeKey(): UnsafeKey[T] = node.ras().toUnsafeKey(node.self()) + def toHashKey(): UnsafeHashKey[T] = node.ras().toHashKey(node.self()) } } @@ -131,16 +131,16 @@ object InGroupNode { private case class InGroupNodeImpl[T <: AnyRef](groupId: Int, can: CanonicalNode[T]) extends InGroupNode[T] - trait HashKey extends Any + trait UniqueKey extends Any implicit class InGroupNodeImplicits[T <: AnyRef](n: InGroupNode[T]) { import InGroupNodeImplicits._ - def toHashKey: HashKey = - InGroupNodeHashKeyImpl(n.groupId, System.identityHashCode(n.can)) + def toUniqueKey: UniqueKey = + InGroupNodeUniqueKeyImpl(n.groupId, System.identityHashCode(n.can)) } private object InGroupNodeImplicits { - private case class InGroupNodeHashKeyImpl(gid: Int, cid: Int) extends HashKey + private case class InGroupNodeUniqueKeyImpl(gid: Int, cid: Int) extends UniqueKey } } @@ -159,15 +159,16 @@ object InClusterNode { can: CanonicalNode[T]) extends InClusterNode[T] - trait HashKey extends Any + trait UniqueKey extends Any implicit class InClusterNodeImplicits[T <: AnyRef](n: InClusterNode[T]) { import InClusterNodeImplicits._ - def toHashKey: HashKey = - InClusterNodeHashKeyImpl(n.clusterKey, System.identityHashCode(n.can)) + def toUniqueKey: UniqueKey = + InClusterNodeUniqueKeyImpl(n.clusterKey, System.identityHashCode(n.can)) } private object InClusterNodeImplicits { - private case class InClusterNodeHashKeyImpl(clusterKey: RasClusterKey, cid: Int) extends HashKey + private case class InClusterNodeUniqueKeyImpl(clusterKey: RasClusterKey, cid: Int) + extends 
UniqueKey } } diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala index 74793a3d0fbc..327b980f38ec 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala @@ -62,11 +62,11 @@ object Best { bestPath: KnownCostPath[T], winnerNodes: Seq[InGroupNode[T]], costs: InGroupNode[T] => Option[Cost]): Best[T] = { - val bestNodes = mutable.Set[InGroupNode.HashKey]() + val bestNodes = mutable.Set[InGroupNode.UniqueKey]() def dfs(groupId: Int, cursor: RasPath.PathNode[T]): Unit = { val can = cursor.self().asCanonical() - bestNodes += InGroupNode(groupId, can).toHashKey + bestNodes += InGroupNode(groupId, can).toUniqueKey cursor.zipChildrenWithGroupIds().foreach { case (childPathNode, childGroupId) => dfs(childGroupId, childPathNode) @@ -76,14 +76,14 @@ object Best { dfs(rootGroupId, bestPath.rasPath.node()) val bestNodeSet = bestNodes.toSet - val winnerNodeSet = winnerNodes.map(_.toHashKey).toSet + val winnerNodeSet = winnerNodes.map(_.toUniqueKey).toSet BestImpl( ras, rootGroupId, bestPath, - n => bestNodeSet.contains(n.toHashKey), - n => winnerNodeSet.contains(n.toHashKey), + n => bestNodeSet.contains(n.toUniqueKey), + n => winnerNodeSet.contains(n.toUniqueKey), costs) } diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala index 90a0adfb2144..601cd72e5d6b 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala @@ -57,17 +57,17 @@ object BestFinder { val bestPath = groupToCosts(group.id()).best() val winnerNodes = groupToCosts.map { case (id, g) => InGroupNode(id, g.bestNode) }.toSeq - val costsMap = mutable.Map[InGroupNode.HashKey, Cost]() 
+ val costsMap = mutable.Map[InGroupNode.UniqueKey, Cost]() groupToCosts.foreach { case (gid, g) => g.nodes.foreach { n => val c = g.nodeToCost(n) if (c.nonEmpty) { - costsMap += (InGroupNode(gid, n).toHashKey -> c.get.cost) + costsMap += (InGroupNode(gid, n).toUniqueKey -> c.get.cost) } } } - Best(ras, group.id(), bestPath, winnerNodes, ign => costsMap.get(ign.toHashKey)) + Best(ras, group.id(), bestPath, winnerNodes, ign => costsMap.get(ign.toUniqueKey)) } } diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala index 6fd95772b243..046760ceb2ec 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala @@ -73,7 +73,7 @@ object DpClusterAlgo { clusterAlgoDef: DpClusterAlgoDef[T, NodeOutput, ClusterOutput]) extends DpZipperAlgoDef[InClusterNode[T], RasClusterKey, NodeOutput, ClusterOutput] { override def idOfX(x: InClusterNode[T]): Any = { - x.toHashKey + x.toUniqueKey } override def idOfY(y: RasClusterKey): Any = { diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala index c824fda8e367..f88f7b6e4116 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala @@ -66,7 +66,7 @@ object DpGroupAlgo { groupAlgoDef: DpGroupAlgoDef[T, NodeOutput, GroupOutput]) extends DpZipperAlgoDef[InGroupNode[T], RasGroup[T], NodeOutput, GroupOutput] { override def idOfX(x: InGroupNode[T]): Any = { - x.toHashKey + x.toUniqueKey } override def idOfY(y: RasGroup[T]): Any = { diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala 
index 4a9e3f0f0f14..3f2590dff8a4 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala @@ -21,7 +21,7 @@ import org.apache.gluten.ras.Best.KnownCostPath import org.apache.gluten.ras.best.BestFinder import org.apache.gluten.ras.dp.DpZipperAlgo.Adjustment.Panel import org.apache.gluten.ras.memo.{Memo, MemoTable} -import org.apache.gluten.ras.path.{InClusterPath, PathFinder, RasPath} +import org.apache.gluten.ras.path._ import org.apache.gluten.ras.property.PropertySet import org.apache.gluten.ras.rule.{EnforcerRuleSet, RuleApplier, Shape} @@ -99,10 +99,16 @@ object DpPlanner { rules: Seq[RuleApplier[T]], enforcerRuleSet: EnforcerRuleSet[T]) extends DpClusterAlgo.Adjustment[T] { + import ExploreAdjustment._ + + private val ruleShapes: Seq[Shape[T]] = rules.map(_.shape()) override def exploreChildX( panel: Panel[InClusterNode[T], RasClusterKey], - x: InClusterNode[T]): Unit = {} + x: InClusterNode[T]): Unit = { + applyRulesOnNode(panel, x.clusterKey, x.can) + } + override def exploreChildY( panel: Panel[InClusterNode[T], RasClusterKey], y: RasClusterKey): Unit = {} @@ -115,20 +121,24 @@ object DpPlanner { cKey: RasClusterKey): Unit = { memoTable.doExhaustively { applyEnforcerRules(panel, cKey) - applyRules(panel, cKey) } } - private def applyRules( + private def applyRulesOnNode( panel: Panel[InClusterNode[T], RasClusterKey], - cKey: RasClusterKey): Unit = { + cKey: RasClusterKey, + can: CanonicalNode[T]): Unit = { if (rules.isEmpty) { return } val dummyGroup = memoTable.getDummyGroup(cKey) - val shapes = rules.map(_.shape()) - findPaths(GroupNode(ras, dummyGroup), shapes) { - path => rules.foreach(rule => applyRule(panel, cKey, rule, path)) + findPaths(GroupNode(ras, dummyGroup), ruleShapes, List(new FromSingleNode[T](can))) { + path => + val rootNode = path.node().self() + if (rootNode.isCanonical) { + assert(rootNode.asCanonical() eq can) + } + 
rules.foreach(rule => applyRule(panel, cKey, rule, path)) } } @@ -137,27 +147,34 @@ object DpPlanner { cKey: RasClusterKey): Unit = { val dummyGroup = memoTable.getDummyGroup(cKey) cKey.propSets(memoTable).foreach { - constraintSet => + constraintSet: PropertySet[T] => val enforcerRules = enforcerRuleSet.rulesOf(constraintSet) if (enforcerRules.nonEmpty) { - val shapes = enforcerRules.map(_.shape()) - findPaths(GroupNode(ras, dummyGroup), shapes) { + val shapes = enforcerRuleSet.ruleShapesOf(constraintSet) + findPaths(GroupNode(ras, dummyGroup), shapes, List.empty) { path => enforcerRules.foreach(rule => applyRule(panel, cKey, rule, path)) } } } } - private def findPaths(gn: GroupNode[T], shapes: Seq[Shape[T]])( + private def findPaths(gn: GroupNode[T], shapes: Seq[Shape[T]], filters: Seq[FilterWizard[T]])( onFound: RasPath[T] => Unit): Unit = { - val finder = shapes + val finderBuilder = shapes .foldLeft( PathFinder .builder(ras, memoTable)) { case (builder, shape) => builder.output(shape.wizard()) } + + val finder = filters + .foldLeft(finderBuilder) { + case (builder, filter) => + builder.filter(filter) + } .build() + finder.find(gn).foreach(path => onFound(path)) } @@ -191,5 +208,22 @@ object DpPlanner { } } - private object ExploreAdjustment {} + private object ExploreAdjustment { + private class FromSingleNode[T <: AnyRef](from: CanonicalNode[T]) extends FilterWizard[T] { + override def omit(can: CanonicalNode[T]): FilterWizard.FilterAction[T] = { + if (can eq from) { + return FilterWizard.FilterAction.Continue(this) + } + FilterWizard.FilterAction.omit + } + + override def omit(group: GroupNode[T]): FilterWizard.FilterAction[T] = + FilterWizard.FilterAction.Continue(this) + + override def advance(offset: Int, count: Int): FilterWizard.FilterAdvanceAction[T] = { + // We only filter on nodes from the root group. So continue with a noop filter. 
+ FilterWizard.FilterAdvanceAction.Continue(FilterWizards.none()) + } + } + } } diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala index 6406b8fb12eb..c67120357311 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala @@ -17,7 +17,7 @@ package org.apache.gluten.ras.memo import org.apache.gluten.ras._ -import org.apache.gluten.ras.Ras.UnsafeKey +import org.apache.gluten.ras.Ras.UnsafeHashKey import org.apache.gluten.ras.property.PropertySet import org.apache.gluten.ras.vis.GraphvizVisualizer @@ -236,11 +236,11 @@ object Memo { private object MemoCacheKey { def apply[T <: AnyRef](ras: Ras[T], self: T): MemoCacheKey[T] = { assert(ras.isCanonical(self)) - MemoCacheKey[T](ras.toUnsafeKey(self)) + MemoCacheKey[T](ras.toHashKey(self)) } } - private case class MemoCacheKey[T <: AnyRef] private (delegate: UnsafeKey[T]) + private case class MemoCacheKey[T <: AnyRef] private (delegate: UnsafeHashKey[T]) } trait MemoStore[T <: AnyRef] { diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala index 126ae7766164..253e9ec84db1 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala @@ -54,6 +54,21 @@ object FilterWizards { OmitCycles[T](CycleDetector[GroupNode[T]]((one, other) => one.groupId() == other.groupId())) } + def none[T <: AnyRef](): FilterWizard[T] = { + None[T]() + } + + private class None[T <: AnyRef] private () extends FilterWizard[T] { + override def omit(can: CanonicalNode[T]): FilterAction[T] = FilterAction.Continue(this) + override def omit(group: GroupNode[T]): FilterAction[T] = FilterAction.Continue(this) + override 
def advance(offset: Int, count: Int): FilterAdvanceAction[T] = + FilterAdvanceAction.Continue(this) + } + + private object None { + def apply[T <: AnyRef](): None[T] = new None[T]() + } + // Cycle detection starts from the first visited group in the input path. private class OmitCycles[T <: AnyRef] private (detector: CycleDetector[GroupNode[T]]) extends FilterWizard[T] { diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRule.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRule.scala index c189369737b6..439b88a2cb57 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRule.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRule.scala @@ -54,6 +54,7 @@ object EnforcerRule { trait EnforcerRuleSet[T <: AnyRef] { def rulesOf(constraintSet: PropertySet[T]): Seq[RuleApplier[T]] + def ruleShapesOf(constraintSet: PropertySet[T]): Seq[Shape[T]] } object EnforcerRuleSet { @@ -73,21 +74,31 @@ object EnforcerRuleSet { mutable.Map[PropertyDef[T, _ <: Property[T]], EnforcerRuleFactory[T]]() private val buffer = mutable.Map[Property[T], Seq[RuleApplier[T]]]() + private val rulesBuffer = mutable.Map[PropertySet[T], Seq[RuleApplier[T]]]() + private val shapesBuffer = mutable.Map[PropertySet[T], Seq[Shape[T]]]() + override def rulesOf(constraintSet: PropertySet[T]): Seq[RuleApplier[T]] = { - constraintSet.getMap.flatMap { - case (constraintDef, constraint) => - buffer.getOrElseUpdate( - constraint, { - val factory = - factoryBuffer.getOrElseUpdate( - constraintDef, - newEnforcerRuleFactory(ras, constraintDef)) - RuleApplier(ras, closure, EnforcerRule.builtin(constraint)) +: factory - .newEnforcerRules(constraint) - .map(rule => RuleApplier(ras, closure, EnforcerRule(rule, constraint))) - } - ) - }.toSeq + rulesBuffer.getOrElseUpdate( + constraintSet, + constraintSet.getMap.flatMap { + case (constraintDef, constraint) => + buffer.getOrElseUpdate( + constraint, { + val 
factory = + factoryBuffer.getOrElseUpdate( + constraintDef, + newEnforcerRuleFactory(ras, constraintDef)) + RuleApplier(ras, closure, EnforcerRule.builtin(constraint)) +: factory + .newEnforcerRules(constraint) + .map(rule => RuleApplier(ras, closure, EnforcerRule(rule, constraint))) + } + ) + }.toSeq + ) + } + + override def ruleShapesOf(constraintSet: PropertySet[T]): Seq[Shape[T]] = + // Go through rulesOf so rulesBuffer is populated on demand instead of throwing on a miss. + shapesBuffer.getOrElseUpdate(constraintSet, rulesOf(constraintSet).map(_.shape())) } } diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala index 01e826f065f1..6b4082c7eb15 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala @@ -17,7 +17,7 @@ package org.apache.gluten.ras.rule import org.apache.gluten.ras._ -import org.apache.gluten.ras.Ras.UnsafeKey +import org.apache.gluten.ras.Ras.UnsafeHashKey import org.apache.gluten.ras.memo.Closure import org.apache.gluten.ras.path.InClusterPath import org.apache.gluten.ras.property.PropertySet @@ -43,14 +43,14 @@ object RuleApplier { private class RegularRuleApplier[T <: AnyRef](ras: Ras[T], closure: Closure[T], rule: RasRule[T]) extends RuleApplier[T] { - private val deDup = mutable.Map[RasClusterKey, mutable.Set[UnsafeKey[T]]]() + private val deDup = mutable.Map[RasClusterKey, mutable.Set[UnsafeHashKey[T]]]() override def apply(icp: InClusterPath[T]): Unit = { val cKey = icp.cluster() val path = icp.path() val plan = path.plan() val appliedPlans = deDup.getOrElseUpdate(cKey, mutable.Set()) - val pKey = ras.toUnsafeKey(plan) + val pKey = ras.toHashKey(plan) if (appliedPlans.contains(pKey)) { return } @@ -76,7 +76,7 @@ object RuleApplier { closure: Closure[T], rule: EnforcerRule[T]) extends RuleApplier[T] { - private val deDup = mutable.Map[RasClusterKey, mutable.Set[UnsafeKey[T]]]() + private 
val deDup = mutable.Map[RasClusterKey, mutable.Set[UnsafeHashKey[T]]]() private val constraint = rule.constraint() private val constraintDef = constraint.definition() @@ -88,7 +88,7 @@ object RuleApplier { return } val plan = path.plan() - val pKey = ras.toUnsafeKey(plan) + val pKey = ras.toHashKey(plan) val appliedPlans = deDup.getOrElseUpdate(cKey, mutable.Set()) if (appliedPlans.contains(pKey)) { return diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala index 018c8087ed3e..b420d8c2978a 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala @@ -29,7 +29,7 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState: MemoState[T], best private val allGroups = memoState.allGroups() private val allClusters = memoState.clusterLookup() - private val nodeToId = mutable.Map[InGroupNode.HashKey, Int]() + private val nodeToId = mutable.Map[InGroupNode.UniqueKey, Int]() def format(): String = { val rootGroupId = best.rootGroupId() @@ -156,7 +156,7 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState: MemoState[T], best group: RasGroup[T], node: CanonicalNode[T]): String = { val ign = InGroupNode(group.id(), node) - val nodeId = nodeToId.getOrElseUpdate(ign.toHashKey, nodeToId.size) + val nodeId = nodeToId.getOrElseUpdate(ign.toUniqueKey, nodeToId.size) s"[$nodeId][Cost ${costs(ign) .map { case c if ras.isInfCost(c) => "" diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala index b29e0c267ffe..2f3ef348cb8c 100644 --- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala +++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala @@ -20,6 +20,7 @@ import 
org.apache.gluten.ras.RasConfig.PlannerType import org.apache.gluten.ras.RasSuiteBase._ import org.apache.gluten.ras.memo.Memo import org.apache.gluten.ras.path.Pattern +import org.apache.gluten.ras.path.Pattern.Matchers import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes} import org.scalatest.funsuite.AnyFunSuite @@ -265,6 +266,58 @@ abstract class RasSuite extends AnyFunSuite { assert(allPaths.size == 15) } + test(s"Rule dependency") { + // Op3 relies on Op2 relies on Op1 + + object Op1 extends RasRule[TestNode] { + override def shift(node: TestNode): Iterable[TestNode] = node match { + case Leaf(70) => + List(Leaf(69)) + case other => List.empty + } + + override def shape(): Shape[TestNode] = Shapes.fixedHeight(1) + } + + object Op2 extends RasRule[TestNode] { + override def shift(node: TestNode): Iterable[TestNode] = node match { + case Leaf(69) => + List(Leaf(68)) + case other => List.empty + } + + override def shape(): Shape[TestNode] = + Shapes.pattern(Pattern.leaf[TestNode](Matchers.clazz(classOf[Leaf])).build()) + } + + object Op3 extends RasRule[TestNode] { + override def shift(node: TestNode): Iterable[TestNode] = node match { + case Leaf(68) => + List(Leaf(67)) + case other => List.empty + } + + override def shape(): Shape[TestNode] = + Shapes.pattern(Pattern.any[TestNode].build()) + } + + val ras = + Ras[TestNode]( + PlanModelImpl, + CostModelImpl, + MetadataModelImpl, + PropertyModelImpl, + ExplainImpl, + RasRule.Factory.reuse(List(Op3, Op1, Op2))) + .withNewConfig(_ => conf) + + val plan = Unary(90, Unary(90, Leaf(70))) + val planner = ras.newPlanner(plan) + val optimized = planner.plan() + + assert(optimized == Unary(90, Unary(90, Leaf(67)))) + } + test(s"Unary node insertion") { object InsertUnary2 extends RasRule[TestNode] { override def shift(node: TestNode): Iterable[TestNode] = node match {