From 446749bd3f62503d1974b77bf7ca52870919f280 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Mon, 8 Apr 2024 13:52:52 +0800 Subject: [PATCH] [VL] RAS: Group reduction support --- .github/workflows/velox_docker.yml | 1 + .../org/apache/gluten/ras/memo/Memo.scala | 16 +++++++++--- .../org/apache/gluten/ras/RasSuite.scala | 26 +++++++++++++++++++ 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/.github/workflows/velox_docker.yml b/.github/workflows/velox_docker.yml index 32c4ccc95c12..afadfa5d203d 100644 --- a/.github/workflows/velox_docker.yml +++ b/.github/workflows/velox_docker.yml @@ -24,6 +24,7 @@ on: - 'gluten-celeborn/common' - 'gluten-celeborn/package' - 'gluten-celeborn/velox' + - 'gluten-ras/**' - 'gluten-core/**' - 'gluten-data/**' - 'gluten-delta/**' diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala index 66626b756c30..49281a82d581 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala @@ -126,8 +126,20 @@ object Memo { extends MemoLike[T] { private val ras = parent.ras + // TODO: Traverse up the tree to do more merges. private def prepareInsert(node: T): Prepare[T] = { - assert(!ras.isGroupLeaf(node)) + if (ras.isGroupLeaf(node)) { + val group = parent.memoTable.allGroups()(ras.planModel.getGroupId(node)) + val residentCluster = group.clusterKey() + + if (residentCluster == targetCluster) { + return Prepare.cluster(parent, targetCluster) + } + // The resident cluster of group leaf is not the same with target cluster. + // Merge. + parent.memoTable.mergeClusters(residentCluster, targetCluster) + return Prepare.cluster(parent, targetCluster) + } val childrenPrepares = ras.planModel.childrenOf(node).map(child => parent.prepareInsert(child)) @@ -155,8 +167,6 @@ object Memo { } // The new node already memorized to memo, but in the different cluster. // Merge the two clusters. - // - // TODO: Traverse up the tree to do more merges. parent.memoTable.mergeClusters(cachedCluster, targetCluster) Prepare.tree(parent, targetCluster, childrenPrepares) } diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala index 0ad82518128f..f8a3d0799f15 100644 --- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala +++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala @@ -172,6 +172,32 @@ abstract class RasSuite extends AnyFunSuite { assert(optimized == Unary(23, Unary(23, Leaf(70)))) } + test(s"Group reduction") { + object RemoveUnary extends RasRule[TestNode] { + override def shift(node: TestNode): Iterable[TestNode] = node match { + case Unary(cost, child) => List(child) + case other => List.empty + } + + override def shape(): Shape[TestNode] = Shapes.fixedHeight(1) + } + + val ras = + Ras[TestNode]( + PlanModelImpl, + CostModelImpl, + MetadataModelImpl, + PropertyModelImpl, + ExplainImpl, + RasRule.Factory.reuse(List(RemoveUnary))) + .withNewConfig(_ => conf) + val plan = Unary(60, Unary(90, Leaf(70))) + val planner = ras.newPlanner(plan) + val optimized = planner.plan() + + assert(optimized == Leaf(70)) + } + test(s"Unary node insertion") { object InsertUnary2 extends RasRule[TestNode] { override def shift(node: TestNode): Iterable[TestNode] = node match {