[VL] Fix function input_file_name() outputs empty string in certain query plan patterns (apache#7124)
zml1206 authored and hengzhen.sq committed Sep 11, 2024
1 parent b3a9512 commit 24d650f
Showing 7 changed files with 142 additions and 152 deletions.
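For context, one query pattern exercised by the new test in this commit computes input_file_name() over a union of scans from different data sources. The sketch below is adapted from that test; json_table is a temporary view registered over a JSON file and lineitem is the TPC-H table already used by the suite:

SELECT input_file_name(), a
FROM
  (SELECT a FROM json_table
   UNION ALL
   SELECT l_orderkey AS a FROM lineitem)
LIMIT 100

Previously, a transformer node between the project evaluating input_file_name() and the scan caused the function to return an empty string; the changes below push the input-file expressions down so they are evaluated next to the scan.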
@@ -282,8 +282,6 @@ object VeloxBackendSettings extends BackendSettingsApi {

override def supportNativeRowIndexColumn(): Boolean = true

override def supportNativeInputFileRelatedExpr(): Boolean = true

override def supportExpandExec(): Boolean = true

override def supportSortExec(): Boolean = true
@@ -52,6 +52,7 @@ private object VeloxRuleApi {
def injectLegacy(injector: LegacyInjector): Unit = {
// Gluten columnar: Transform rules.
injector.injectTransform(_ => RemoveTransitions)
injector.injectTransform(_ => PushDownInputFileExpression.PreOffload)
injector.injectTransform(c => FallbackOnANSIMode.apply(c.session))
injector.injectTransform(c => FallbackMultiCodegens.apply(c.session))
injector.injectTransform(c => PlanOneRowRelation.apply(c.session))
@@ -64,6 +65,7 @@ private object VeloxRuleApi {
injector.injectTransform(_ => TransformPreOverrides())
injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject())
injector.injectTransform(c => RewriteTransformer.apply(c.session))
injector.injectTransform(_ => PushDownInputFileExpression.PostOffload)
injector.injectTransform(_ => EnsureLocalSortRequirements)
injector.injectTransform(_ => EliminateLocalSort)
injector.injectTransform(_ => CollapseProjectExecTransformer)
@@ -28,7 +28,7 @@ class ScalarFunctionsValidateSuiteRasOff extends ScalarFunctionsValidateSuite {
super.sparkConf
.set("spark.gluten.ras.enabled", "false")
}

import testImplicits._
// Since https://github.com/apache/incubator-gluten/pull/6200.
test("Test input_file_name function") {
runQueryAndCompare("""SELECT input_file_name(), l_orderkey
@@ -44,6 +44,21 @@ class ScalarFunctionsValidateSuiteRasOff extends ScalarFunctionsValidateSuite {
| limit 100""".stripMargin) {
checkGlutenOperatorMatch[ProjectExecTransformer]
}
withTempPath {
path =>
Seq(1, 2, 3).toDF("a").write.json(path.getCanonicalPath)
spark.read.json(path.getCanonicalPath).createOrReplaceTempView("json_table")
val sql =
"""
|SELECT input_file_name(), a
|FROM
|(SELECT a FROM json_table
|UNION ALL
|SELECT l_orderkey as a FROM lineitem)
|LIMIT 100
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })
}
}
}

@@ -45,7 +45,6 @@ trait BackendSettingsApi {
def supportNativeWrite(fields: Array[StructField]): Boolean = true
def supportNativeMetadataColumns(): Boolean = false
def supportNativeRowIndexColumn(): Boolean = false
def supportNativeInputFileRelatedExpr(): Boolean = false

def supportExpandExec(): Boolean = false
def supportSortExec(): Boolean = false
@@ -34,7 +34,7 @@ object MiscColumnarRules {
object TransformPreOverrides {
def apply(): TransformPreOverrides = {
TransformPreOverrides(
List(OffloadProject(), OffloadFilter()),
List(OffloadFilter()),
List(
OffloadOthers(),
OffloadAggregate(),
@@ -26,21 +26,17 @@ import org.apache.gluten.utils.{LogLevelUtil, PlanUtil}

import org.apache.spark.api.python.EvalPythonExecTransformer
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, NamedExpression}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.datasources.WriteFilesExec
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanExecBase}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.execution.python.{ArrowEvalPythonExec, BatchEvalPythonExec}
import org.apache.spark.sql.execution.window.{WindowExec, WindowGroupLimitExecShim}
import org.apache.spark.sql.hive.HiveTableScanExecTransformer
import org.apache.spark.sql.types.{LongType, StringType}

import scala.collection.mutable.Map

/**
* Converts a vanilla Spark plan node into Gluten plan node. Gluten plan is supposed to be executed
@@ -224,148 +220,6 @@ object OffloadJoin {
}
}

case class OffloadProject() extends OffloadSingleNode with LogLevelUtil {
private def containsInputFileRelatedExpr(expr: Expression): Boolean = {
expr match {
case _: InputFileName | _: InputFileBlockStart | _: InputFileBlockLength => true
case _ => expr.children.exists(containsInputFileRelatedExpr)
}
}

private def rewriteExpr(
expr: Expression,
replacedExprs: Map[String, AttributeReference]): Expression = {
expr match {
case _: InputFileName =>
replacedExprs.getOrElseUpdate(
expr.prettyName,
AttributeReference(expr.prettyName, StringType, false)())
case _: InputFileBlockStart =>
replacedExprs.getOrElseUpdate(
expr.prettyName,
AttributeReference(expr.prettyName, LongType, false)())
case _: InputFileBlockLength =>
replacedExprs.getOrElseUpdate(
expr.prettyName,
AttributeReference(expr.prettyName, LongType, false)())
case other =>
other.withNewChildren(other.children.map(child => rewriteExpr(child, replacedExprs)))
}
}

private def addMetadataCol(
plan: SparkPlan,
replacedExprs: Map[String, AttributeReference]): SparkPlan = {
def genNewOutput(output: Seq[Attribute]): Seq[Attribute] = {
var newOutput = output
for ((_, newAttr) <- replacedExprs) {
if (!newOutput.exists(attr => attr.exprId == newAttr.exprId)) {
newOutput = newOutput :+ newAttr
}
}
newOutput
}
def genNewProjectList(projectList: Seq[NamedExpression]): Seq[NamedExpression] = {
var newProjectList = projectList
for ((_, newAttr) <- replacedExprs) {
if (!newProjectList.exists(attr => attr.exprId == newAttr.exprId)) {
newProjectList = newProjectList :+ newAttr.toAttribute
}
}
newProjectList
}

plan match {
case f: FileSourceScanExec =>
f.copy(output = genNewOutput(f.output))
case f: FileSourceScanExecTransformer =>
f.copy(output = genNewOutput(f.output))
case b: BatchScanExec =>
b.copy(output = genNewOutput(b.output).asInstanceOf[Seq[AttributeReference]])
case b: BatchScanExecTransformer =>
b.copy(output = genNewOutput(b.output).asInstanceOf[Seq[AttributeReference]])
case p @ ProjectExec(projectList, child) =>
p.copy(genNewProjectList(projectList), addMetadataCol(child, replacedExprs))
case p @ ProjectExecTransformer(projectList, child) =>
p.copy(genNewProjectList(projectList), addMetadataCol(child, replacedExprs))
case u @ UnionExec(children) =>
val newFirstChild = addMetadataCol(children.head, replacedExprs)
val newOtherChildren = children.tail.map {
child =>
// Make sure exprId is unique in each child of Union.
val newReplacedExprs = replacedExprs.map {
expr => (expr._1, AttributeReference(expr._2.name, expr._2.dataType, false)())
}
addMetadataCol(child, newReplacedExprs)
}
u.copy(children = newFirstChild +: newOtherChildren)
case _ => plan.withNewChildren(plan.children.map(addMetadataCol(_, replacedExprs)))
}
}

private def tryOffloadProjectExecWithInputFileRelatedExprs(
projectExec: ProjectExec): SparkPlan = {
def findScanNodes(plan: SparkPlan): Seq[SparkPlan] = {
plan.collect {
case f @ (_: FileSourceScanExec | _: AbstractFileSourceScanExec |
_: DataSourceV2ScanExecBase) =>
f
}
}
val addHint = AddFallbackTagRule()
val newProjectList = projectExec.projectList.filterNot(containsInputFileRelatedExpr)
val newProjectExec = ProjectExec(newProjectList, projectExec.child)
addHint.apply(newProjectExec)
if (FallbackTags.nonEmpty(newProjectExec)) {
// Project is still not transformable after remove `input_file_name` expressions.
projectExec
} else {
// the project with `input_file_name` expression may have multiple data source
// by union all, reference:
// https://github.com/apache/spark/blob/e459674127e7b21e2767cc62d10ea6f1f941936c
// /sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala#L519
val leafScans = findScanNodes(projectExec)
if (leafScans.isEmpty || leafScans.exists(FallbackTags.nonEmpty)) {
// It means
// 1. projectExec has `input_file_name` but no scan child.
// 2. It has scan children node but the scan node fallback.
projectExec
} else {
val replacedExprs = scala.collection.mutable.Map[String, AttributeReference]()
val newProjectList = projectExec.projectList.map {
expr => rewriteExpr(expr, replacedExprs).asInstanceOf[NamedExpression]
}
val newChild = addMetadataCol(projectExec.child, replacedExprs)
logDebug(
s"Columnar Processing for ${projectExec.getClass} with " +
s"ProjectList ${projectExec.projectList} is currently supported.")
ProjectExecTransformer(newProjectList, newChild)
}
}
}

private def genProjectExec(projectExec: ProjectExec): SparkPlan = {
if (
FallbackTags.nonEmpty(projectExec) &&
BackendsApiManager.getSettings.supportNativeInputFileRelatedExpr() &&
projectExec.projectList.exists(containsInputFileRelatedExpr)
) {
tryOffloadProjectExecWithInputFileRelatedExprs(projectExec)
} else if (FallbackTags.nonEmpty(projectExec)) {
projectExec
} else {
logDebug(s"Columnar Processing for ${projectExec.getClass} is currently supported.")
ProjectExecTransformer(projectExec.projectList, projectExec.child)
}
}

override def offload(plan: SparkPlan): SparkPlan = plan match {
case p: ProjectExec =>
genProjectExec(p)
case other => other
}
}

// Filter transformation.
case class OffloadFilter() extends OffloadSingleNode with LogLevelUtil {
import OffloadOthers._
@@ -443,6 +297,10 @@ object OffloadOthers {
case plan: CoalesceExec =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
ColumnarCoalesceExec(plan.numPartitions, plan.child)
case plan: ProjectExec =>
val columnarChild = plan.child
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
ProjectExecTransformer(plan.projectList, columnarChild)
case plan: SortAggregateExec =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
HashAggregateExecBaseTransformer.from(plan)
@@ -0,0 +1,118 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension.columnar

import org.apache.gluten.execution.{BatchScanExecTransformer, FileSourceScanExecTransformer}

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, NamedExpression}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{DeserializeToObjectExec, LeafExecNode, ProjectExec, SerializeFromObjectExec, SparkPlan, UnionExec}

import scala.collection.mutable

/**
* The Spark implementations of input_file_name/input_file_block_start/input_file_block_length use a
* thread local to stash the file name, which the functions then read back. If a transformer node
* sits between the project that evaluates an input-file function and the scan, input_file_name
* returns an empty string. So we either push the input-file expressions down into the transformer
* scan, or evaluate them in a fallback project placed directly above a fallback scan.
*
* Two rules are involved:
* - Before offload, add a new project above the leaf node and push the input-file expressions
* down into the new project
* - After offload, if the scan has been offloaded, push the input-file expressions into the scan
* and remove the project
*/
object PushDownInputFileExpression {
def containsInputFileRelatedExpr(expr: Expression): Boolean = {
expr match {
case _: InputFileName | _: InputFileBlockStart | _: InputFileBlockLength => true
case _ => expr.children.exists(containsInputFileRelatedExpr)
}
}

object PreOffload extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case ProjectExec(projectList, child) if projectList.exists(containsInputFileRelatedExpr) =>
val replacedExprs = mutable.Map[String, Alias]()
val newProjectList = projectList.map {
expr => rewriteExpr(expr, replacedExprs).asInstanceOf[NamedExpression]
}
val newChild = addMetadataCol(child, replacedExprs)
ProjectExec(newProjectList, newChild)
}

private def rewriteExpr(
expr: Expression,
replacedExprs: mutable.Map[String, Alias]): Expression =
expr match {
case _: InputFileName =>
replacedExprs
.getOrElseUpdate(expr.prettyName, Alias(InputFileName(), expr.prettyName)())
.toAttribute
case _: InputFileBlockStart =>
replacedExprs
.getOrElseUpdate(expr.prettyName, Alias(InputFileBlockStart(), expr.prettyName)())
.toAttribute
case _: InputFileBlockLength =>
replacedExprs
.getOrElseUpdate(expr.prettyName, Alias(InputFileBlockLength(), expr.prettyName)())
.toAttribute
case other =>
other.withNewChildren(other.children.map(child => rewriteExpr(child, replacedExprs)))
}

private def addMetadataCol(
plan: SparkPlan,
replacedExprs: mutable.Map[String, Alias]): SparkPlan =
plan match {
case p: LeafExecNode =>
ProjectExec(p.output ++ replacedExprs.values, p)
// Output of SerializeFromObjectExec's child and output of DeserializeToObjectExec must be
// a single-field row.
case p @ (_: SerializeFromObjectExec | _: DeserializeToObjectExec) =>
ProjectExec(p.output ++ replacedExprs.values, p)
case p: ProjectExec =>
p.copy(
projectList = p.projectList ++ replacedExprs.values.toSeq.map(_.toAttribute),
child = addMetadataCol(p.child, replacedExprs))
case u @ UnionExec(children) =>
val newFirstChild = addMetadataCol(children.head, replacedExprs)
val newOtherChildren = children.tail.map {
child =>
// Make sure exprId is unique in each child of Union.
val newReplacedExprs = replacedExprs.map {
expr => (expr._1, Alias(expr._2.child, expr._2.name)())
}
addMetadataCol(child, newReplacedExprs)
}
u.copy(children = newFirstChild +: newOtherChildren)
case p => p.withNewChildren(p.children.map(child => addMetadataCol(child, replacedExprs)))
}
}

object PostOffload extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case p @ ProjectExec(projectList, child: FileSourceScanExecTransformer)
if projectList.exists(containsInputFileRelatedExpr) =>
child.copy(output = p.output)
case p @ ProjectExec(projectList, child: BatchScanExecTransformer)
if projectList.exists(containsInputFileRelatedExpr) =>
child.copy(output = p.output.asInstanceOf[Seq[AttributeReference]])
}
}
}
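As a rough illustration of how the two rules cooperate, here is a simplified plan sketch (hypothetical, not actual Gluten output; attribute names are invented, and it assumes the scan is offloaded while the project holding the raw input-file expression is not transformed):

Original plan:
  Project [input_file_name() AS fname, a]
  +- Scan [a]

After PreOffload, the raw expression is evaluated in a new project placed directly above the leaf node, and the original project only references the resulting attribute:
  Project [input_file_name#0 AS fname, a]
  +- Project [a, input_file_name() AS input_file_name#0]
     +- Scan [a]

After offload and PostOffload, the scan has become a transformer, the remaining vanilla project above it is removed, and input_file_name#0 becomes part of the transformer scan's output:
  Project [input_file_name#0 AS fname, a]
  +- FileSourceScanExecTransformer [a, input_file_name#0]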
