Skip to content

Commit

Permalink
[js/webgpu] Fix issue to run model demucs
Browse files Browse the repository at this point in the history
This is to fix issue microsoft#22031 to run model demucs.
For conv-transpose, outputPadding.length could be 1, while spatialRank
is 2. The fix is to append enough 0s to outputPadding.
For conv, the issue is similar. kernelShape.length sometimes could be 1,
while inputs[1].dims.length is 4. The fix is also to append enough 0s to
kernelShape.
  • Loading branch information
gyagp committed Sep 12, 2024
1 parent d495e6c commit c2eb157
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ const calculateOutputShapeAndPads = (
) => {
const spatialRank = inputShape.length - 2;
const updateOutputShape = outputShape.length === 0;
if (outputPadding.length === 0) {
for (let i = 0; i < spatialRank; ++i) {
if (outputPadding.length < spatialRank) {
for (let i = 0; i < spatialRank - outputPadding.length; ++i) {
outputPadding.push(0);
}
}
Expand Down
7 changes: 6 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/conv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,12 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvAttribute

const getAdjustedConvAttributes = <T extends ConvAttributes>(attributes: T, inputs: readonly TensorView[]): T => {
const kernelShape = attributes.kernelShape.slice();
// if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims
// if kernelShape is not well specified in the attributes, infer it from the weight tensor dims
if (kernelShape.length < inputs[1].dims.length - 2) {
for (let i = 0; i < inputs[1].dims.length - 2 - kernelShape.length; ++i) {
kernelShape.push(0);
}
}
for (let i = 2; i < inputs[1].dims.length; ++i) {
if (kernelShape[i - 2] === 0) {
kernelShape[i - 2] = inputs[1].dims[i];
Expand Down

0 comments on commit c2eb157

Please sign in to comment.