Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[webgpu-native] Add Tile Op #22133

Open
wants to merge 3 commits into
base: fs-eire/webgpu-ep
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/tile.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
Fixed Show fixed Hide fixed
// Licensed under the MIT License.

#include <locale>
#include <sstream>
#include <string>

#include "core/common/inlined_containers.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/tensor/tile.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {

// Registers Tile for ONNX opsets 6 through 12 on the WebGPU execution
// provider; the opset-13+ registration is declared separately below.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Tile,
    kOnnxDomain,
    6, 12,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedNumberTypes()),
    Tile);

// Registers Tile for ONNX opset 13 and later on the WebGPU execution provider.
ONNX_OPERATOR_KERNEL_EX(
    Tile,
    kOnnxDomain,
    13,  // fixed: the original omitted the comma after the opset version,
         // which breaks the macro's argument list.
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedNumberTypes()),
    Tile);

const std::string AppendTileFunction(ShaderVariableHelper& input, ShaderVariableHelper& output) {

Check warning on line 31 in onnxruntime/core/providers/webgpu/tensor/tile.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/tensor/tile.cc:31: Add #include <string> for string [build/include_what_you_use] [4]
std::ostringstream ss;
ss.imbue(std::locale::classic());
const TensorShape& input_shape = input.GetShape();
int32_t input_rank = gsl::narrow_cast<int32_t>(input_shape.NumDimensions());
ss << "fn tile(i: output_indices_t)->input_indices_t {\n"
" var input_indices;\n";
for (auto i = 0; i < input_rank; i++) {
ss << " input_dim_i = input.GetDimensionByIndex(" << i << ");\n";
ss << " input_dim_value = output.GetDimensionByIndex(" << i << ") % input_dim_i;\n";
ss << " input.indicesSet('input_indices', '" << i << "', 'input_dim_value');\n";
}
ss << " return input_indices;\n"
"}\n";
return ss.str();
}

// Generates the Tile compute shader: one invocation per output element; each
// invocation resolves its output indices, maps them to input indices via the
// `tile` helper, and copies the value.
Status TileProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const ShaderVariableHelper& input = shader.AddInput("input", ShaderUsage::UseUniform);
  const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform);
  shader.AppendImplementation(AppendTileFunction(input, output));
  // Fixed: `tile` is declared as `fn tile(i: output_indices_t)`, so it must be
  // called with the WGSL `output_indices` value; the original passed the C++
  // helper names (`tile(input, output)`), which is not valid shader code.
  shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"),
                             "  let output_indices = ", output.OffsetToIndices("global_idx"),
                             ";\n"
                             "  let input_indices = tile(output_indices);\n"
                             "  ",
                             output.SetByOffset("global_idx", input.GetByIndices("input_indices")));

  return Status::OK();
}

// Computes Tile: output dim i = input dim i * repeats[i], each input element
// replicated repeats[i] times along axis i.
Status Tile::ComputeInternal(ComputeContext& context) const {
  const auto* input_tensor = context.Input(0);
  const TensorShape& input_shape = input_tensor->Shape();
  const size_t input_rank = input_shape.NumDimensions();

  // Fixed: the ONNX spec defines `repeats` as a 1-D tensor(int64), not int32.
  // NOTE(review): reading repeats on the CPU requires the kernel def to place
  // input 1 in CPU memory -- confirm the registration marks it accordingly.
  const auto* repeats_tensor = context.Input(1);
  const auto* repeats = repeats_tensor->Data<int64_t>();

  auto output_dims = input_shape.AsShapeVector();
  // Fixed: loop index and bound now share one type (was size_t vs int32_t).
  for (size_t axis = 0; axis < input_rank; axis++) {
    output_dims[axis] *= repeats[axis];
  }

  TensorShape output_shape(output_dims);
  auto* output_tensor = context.Output(0, output_shape);
  // Fixed: narrow to uint32_t (was narrow_cast<int32_t> assigned to uint32_t),
  // matching the Uint32 uniform declared by TileProgram.
  uint32_t output_size = gsl::narrow_cast<uint32_t>(output_tensor->Shape().Size());

  TileProgram program{};
  program
      .AddInputs({{input_tensor}, {repeats_tensor}})
      .AddOutputs({output_tensor})
      .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
      // Fixed: the shader guards on uniforms.output_size, but the uniform
      // declared in TileProgram was never populated.
      .AddUniformVariables({{output_size}});
  return context.RunProgram(program);
}

} // namespace webgpu
} // namespace onnxruntime

Check warning on line 88 in onnxruntime/core/providers/webgpu/tensor/tile.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Could not find a newline character at the end of the file. [whitespace/ending_newline] [5] Raw Output: onnxruntime/core/providers/webgpu/tensor/tile.cc:88: Could not find a newline character at the end of the file. [whitespace/ending_newline] [5]
29 changes: 29 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/tile.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
Fixed Show fixed Hide fixed
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/program.h"

namespace onnxruntime {
namespace webgpu {

// Shader program for the Tile operator: every output element is copied from
// the input element at (output index mod input dimension) along each axis.
// Fixed: removed PR-annotation junk that had been pasted into the class body
// and the cpplint-flagged trailing whitespace.
class TileProgram final : public Program<TileProgram> {
 public:
  TileProgram() : Program{"Tile"} {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // Total number of output elements; the shader uses it to guard invocations
  // beyond the dispatched range.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32});
};

// WebGPU kernel wrapper for the ONNX Tile operator.
class Tile final : public WebGpuKernel {
 public:
  // `explicit` prevents accidental implicit conversion from OpKernelInfo;
  // kernels are always constructed explicitly by the registry.
  explicit Tile(const OpKernelInfo& info) : WebGpuKernel(info) {}

  Status ComputeInternal(ComputeContext& context) const override;
};

} // namespace webgpu
} // namespace onnxruntime
Loading