Skip to content

Commit

Permalink
[js/webgpu] Optimize ConvTranspose (Continue) (#23429)
Browse files Browse the repository at this point in the history
BUG #23273

This PR does below optimizations:
1. When output channels is one, 1) calculate the offset before the
inchannel loop to reduce indices to offsets calculation, 2) split the
`inputChannelsPerGroup` into `inputChannelsPerGroupInt` and
`inputChannelsRemainder` parts so that we can always access 4 data for
`inputChannelsPerGroupInt`.
2. Use precise initial value to reduce useless loop iterations. Thanks
@jiangzhaoming 's suggestion's on this.

With this PR, ConvTranspose becomes 3.7s from 8.4s on Intel Meteor Lake.
On NV RTX 2000 Ada, it becomes 1.6s from 2.7s.
  • Loading branch information
qjia7 authored and ashrit-ms committed Jan 23, 2025
1 parent 04a4a69 commit 009cae0
Show file tree
Hide file tree
Showing 2 changed files with 236 additions and 24 deletions.
114 changes: 90 additions & 24 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ export const createConvTranspose2DProgramInfo = (
const inputChannelsPerGroup = wShape[2] / group;
const outputChannelsPerGroup = wShape[3];
const aComponents = isChannelsLast ? getMaxComponents(inputChannelsPerGroup) : 1;
const packInputAs4 = isChannelsLast && outputChannelsPerGroup === 1;
const inputChannelsPerGroupInt = packInputAs4
? Math.floor(inputChannelsPerGroup / 4) * 4
: Math.floor(inputChannelsPerGroup / aComponents) * aComponents;
const inputChannelsRemainder = inputChannelsPerGroup - inputChannelsPerGroupInt;
const components = isChannelsLast ? getMaxComponents(outputChannelsPerGroup) : 1;
const bComponents = isChannelsLast ? (outputChannelsPerGroup === 1 ? aComponents : components) : 1;
const outputSize = ShapeUtil.size(outputShape) / components;
Expand Down Expand Up @@ -78,7 +83,7 @@ export const createConvTranspose2DProgramInfo = (
{ type: DataType.uint32, data: dilations },
{ type: DataType.uint32, data: effectiveFilterDims },
{ type: DataType.int32, data: pads },
{ type: DataType.uint32, data: inputChannelsPerGroup },
{ type: DataType.uint32, data: inputChannelsPerGroupInt },
{ type: DataType.uint32, data: outputChannelsPerGroup },
...createTensorShapeVariables(inputs[0].dims, inputs[1].dims),
];
Expand Down Expand Up @@ -114,16 +119,40 @@ export const createConvTranspose2DProgramInfo = (

const calculateResult = (): string => {
let calcStr = '';
if (aComponents === 1) {
calcStr += `
let w_offset = ${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)};
let wValue = ${w.getByOffset(`w_offset / ${bComponents}`)};
dotProd = dotProd + xValue * wValue;`;
if (packInputAs4) {
if (aComponents === 4) {
calcStr += `
let xValue = ${dy.getByOffset('x_offset')};
let wValue = ${w.getByOffset('w_offset')};
dotProd = dotProd + dot(xValue, wValue);
x_offset += 1u;
w_offset += 1u;`;
} else if (aComponents === 2) {
calcStr += `
dotProd = dotProd + dot(vec4<${dataType}>(${dy.getByOffset('x_offset')}, ${dy.getByOffset('x_offset + 1u')}), vec4<${dataType}>(${w.getByOffset('w_offset')}, ${w.getByOffset('w_offset + 1u')}));
x_offset += 2u;
w_offset += 2u;`;
} else if (aComponents === 1) {
calcStr += `
dotProd = dotProd + dot(vec4<${dataType}>(${dy.getByOffset('x_offset')}, ${dy.getByOffset('x_offset + 1u')}, ${dy.getByOffset('x_offset + 2u')}, ${dy.getByOffset('x_offset + 3u')}), vec4<${dataType}>(${w.getByOffset('w_offset')}, ${w.getByOffset('w_offset + 1u')}, ${w.getByOffset('w_offset + 2u')}, ${w.getByOffset('w_offset + 3u')}));
x_offset += 4u;
w_offset += 4u;`;
}
} else {
if (outputChannelsPerGroup === 1) {
calcStr += `
let xValue = ${
isChannelsLast
? dy.getByOffset(
`${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents}`,
)
: dy.get('batch', 'inputChannel', 'idyR', 'idyC')
};
`;
if (aComponents === 1) {
calcStr += `
let wValue = ${w.getByOffset(`${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)} / ${bComponents}`)};
dotProd = dotProd + dot(xValue, wValue);`;
let w_offset = ${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)};
let wValue = ${w.getByOffset(`w_offset / ${bComponents}`)};
dotProd = dotProd + xValue * wValue;`;
} else {
for (let c = 0; c < aComponents; c++) {
calcStr += `
Expand All @@ -134,6 +163,32 @@ export const createConvTranspose2DProgramInfo = (
}
return calcStr;
};
const calculateRemainder = (): string => {
if (inputChannelsRemainder === 0) {
return '';
}
if (!packInputAs4) {
throw new Error(`packInputAs4 ${packInputAs4} is not true.`);
}
let calcStr = '';
if (aComponents === 1) {
calcStr += 'dotProd = dotProd';
for (let i = 0; i < inputChannelsRemainder; i++) {
calcStr += `
+ ${dy.getByOffset(`x_offset + ${i}`)} * ${w.getByOffset(`w_offset + ${i}`)}`;
}
calcStr += ';';
} else if (aComponents === 2) {
if (inputChannelsRemainder !== 2) {
throw new Error(`Invalid inputChannelsRemainder ${inputChannelsRemainder}.`);
}
calcStr += `
let xValue = ${dy.getByOffset('x_offset')};
let wValue = ${w.getByOffset('w_offset')};
dotProd = dotProd + dot(xValue, wValue);`;
}
return calcStr;
};
const codeSnippet = `
let outputIndices = ${output.offsetToIndices(`global_idx * ${components}`)};
let batch = ${output.indicesGet('outputIndices', 0)};
Expand All @@ -148,7 +203,12 @@ export const createConvTranspose2DProgramInfo = (
// Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
// ? = to be determined. : = across all values in that axis.
var dotProd = ${output.type.value}(0.0);
for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) {
var wR: u32 = 0;
if (uniforms.dilations.x == 1) {
// Minimum wR >= 0 that satisfies (dyRCorner + wR) % (uniforms.strides.x) == 0
wR = u32(((dyRCorner + i32(uniforms.strides.x) - 1) / i32(uniforms.strides.x)) * i32(uniforms.strides.x) - dyRCorner);
}
for (; wR < uniforms.effective_filter_dims.x; wR = wR + 1) {
if (wR % uniforms.dilations.x != 0) {
continue;
}
Expand All @@ -158,10 +218,13 @@ export const createConvTranspose2DProgramInfo = (
wRPerm < 0) {
continue;
}
wR = wR + uniforms.strides[0] - 1;
let idyR: u32 = u32(dyR);
for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) {
var wC: u32 = 0;
if (uniforms.dilations.y == 1) {
// Minimum wC >= 0 that satisfies (dyCCorner + wC) % (uniforms.strides.y) == 0
wC = u32(((dyCCorner + i32(uniforms.strides.y) - 1) / i32(uniforms.strides.y)) * i32(uniforms.strides.y) - dyCCorner);
}
for (; wC < uniforms.effective_filter_dims.y; wC = wC + 1) {
if (wC % uniforms.dilations.y != 0) {
continue;
}
Expand All @@ -171,21 +234,24 @@ export const createConvTranspose2DProgramInfo = (
fract(dyC) > 0.0 || wCPerm < 0) {
continue;
}
wC = wC + uniforms.strides.y - 1;
let idyC: u32 = u32(dyC);
var inputChannel = groupId * uniforms.input_channels_per_group;
for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + ${aComponents}) {
let xValue = ${
isChannelsLast
? dy.getByOffset(
`${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents}`,
)
: dy.get('batch', 'inputChannel', 'idyR', 'idyC')
};
${
packInputAs4
? `
var x_offset = ${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents};
var w_offset = ${w.indicesToOffset(`${w.type.indices}(wRPerm, wCPerm, inputChannel, wOutChannel)`)} / ${bComponents};
`
: ''
}
for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + ${packInputAs4 ? 4 : aComponents}) {
${calculateResult()}
inputChannel = inputChannel + ${aComponents};
inputChannel = inputChannel + ${packInputAs4 ? 4 : aComponents};
}
${calculateRemainder()}
wC = wC + uniforms.strides.y - 1;
}
wR = wR + uniforms.strides[0] - 1;
}
let value = dotProd${hasBias ? ` + bias[d1 / ${components}]` : ''};
${output.setByOffset('global_idx', 'value')};
Expand All @@ -201,7 +267,7 @@ export const createConvTranspose2DProgramInfo = (
return {
name: 'ConvTranspose2D',
shaderCache: {
hint: `${attributes.cacheKey};${aComponents}${bComponents}${components}${outputChannelsPerGroup === 1}`,
hint: `${attributes.cacheKey};${aComponents}${bComponents}${components}${outputChannelsPerGroup === 1}${inputChannelsRemainder}`,
inputDependencies,
},
getRunData: () => ({
Expand Down
146 changes: 146 additions & 0 deletions js/web/test/data/ops/conv-transpose.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,152 @@
}
]
},
{
"name": "ConvTranspose with output channels = 1",
"operator": "ConvTranspose",
"inputShapeDefinitions": "rankOnly",
"opset": { "domain": "", "version": 17 },
"attributes": [
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" },
{ "name": "strides", "data": [2, 2], "type": "ints" }
],
"cases": [
{
"name": "inChannels = 5",
"inputs": [
{
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45
],
"dims": [1, 5, 3, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8],
"dims": [5, 1, 2, 2],
"type": "float32"
},
{
"data": [2],
"dims": [1],
"type": "float32"
}
],
"outputs": [
{
"data": [
437, 532, 458, 558, 479, 584, 627, 722, 658, 758, 689, 794, 500, 610, 521, 636, 542, 662, 720, 830, 751,
866, 782, 902, 563, 688, 584, 714, 605, 740, 813, 938, 844, 974, 875, 1010
],
"dims": [1, 1, 6, 6],
"type": "float32"
}
]
},
{
"name": "inChannels = 6",
"inputs": [
{
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 1, 2, 3, 4, 5, 6, 7, 8, 9
],
"dims": [1, 6, 3, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4],
"dims": [6, 1, 2, 2],
"type": "float32"
},
{
"data": [2],
"dims": [1],
"type": "float32"
}
],
"outputs": [
{
"data": [
438, 534, 460, 562, 482, 590, 630, 726, 664, 766, 698, 806, 504, 618, 526, 646, 548, 674, 732, 846, 766,
886, 800, 926, 570, 702, 592, 730, 614, 758, 834, 966, 868, 1006, 902, 1046
],
"dims": [1, 1, 6, 6],
"type": "float32"
}
]
},
{
"name": "inChannels = 7",
"inputs": [
{
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18
],
"dims": [1, 7, 3, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8],
"dims": [7, 1, 2, 2],
"type": "float32"
},
{
"data": [2],
"dims": [1],
"type": "float32"
}
],
"outputs": [
{
"data": [
488, 594, 515, 628, 542, 662, 700, 806, 741, 854, 782, 902, 569, 696, 596, 730, 623, 764, 823, 950, 864,
998, 905, 1046, 650, 798, 677, 832, 704, 866, 946, 1094, 987, 1142, 1028, 1190
],
"dims": [1, 1, 6, 6],
"type": "float32"
}
]
},
{
"name": "inChannels = 8",
"inputs": [
{
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 1, 2, 3, 4, 5, 6, 7, 8, 9
],
"dims": [1, 8, 3, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4],
"dims": [8, 1, 2, 2],
"type": "float32"
},
{
"data": [2],
"dims": [1],
"type": "float32"
}
],
"outputs": [
{
"data": [
489, 596, 517, 632, 545, 668, 703, 810, 747, 862, 791, 914, 573, 704, 601, 740, 629, 776, 835, 966, 879,
1018, 923, 1070, 657, 812, 685, 848, 713, 884, 967, 1122, 1011, 1174, 1055, 1226
],
"dims": [1, 1, 6, 6],
"type": "float32"
}
]
}
]
},
{
"name": "ConvTranspose without bias addition C",
"operator": "ConvTranspose",
Expand Down

0 comments on commit 009cae0

Please sign in to comment.