Skip to content

Commit

Permalink
[js/webgpu] Fix issue to run model demucs (#22074)
Browse files Browse the repository at this point in the history
This is to fix issue #22031 to run model demucs.
For conv-transpose, outputPadding.length could be 1, while spatialRank
is 2. The fix is to append enough 0s to outputPadding. For conv, the
issue is similar. kernelShape.length sometimes could be 1, while
inputs[1].dims.length is 4. The fix is also to append enough 0s to
kernelShape.
  • Loading branch information
gyagp committed Sep 17, 2024
1 parent 291a535 commit 2db6b73
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
6 changes: 2 additions & 4 deletions js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,8 @@ const calculateOutputShapeAndPads = (
) => {
const spatialRank = inputShape.length - 2;
const updateOutputShape = outputShape.length === 0;
if (outputPadding.length === 0) {
for (let i = 0; i < spatialRank; ++i) {
outputPadding.push(0);
}
if (outputPadding.length < spatialRank) {
outputPadding.push(...Array(spatialRank - outputPadding.length).fill(0));
}
const batchSize = inputShape[0];
const outChannels = kernelShape[isChannelLast ? 3 : 1] * group;
Expand Down
5 changes: 4 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/conv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvAttribute

const getAdjustedConvAttributes = <T extends ConvAttributes>(attributes: T, inputs: readonly TensorView[]): T => {
const kernelShape = attributes.kernelShape.slice();
// if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims
// if kernelShape is not well specified in the attributes, infer it from the weight tensor dims
if (kernelShape.length < inputs[1].dims.length - 2) {
kernelShape.push(...Array(inputs[1].dims.length - 2 - kernelShape.length).fill(0));
}
for (let i = 2; i < inputs[1].dims.length; ++i) {
if (kernelShape[i - 2] === 0) {
kernelShape[i - 2] = inputs[1].dims[i];
Expand Down

0 comments on commit 2db6b73

Please sign in to comment.