//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implement Winograd Conv2D algorithm. The implementation is based on the
// paper: Fast Algorithms for Convolutional Neural Networks
// (https://arxiv.org/abs/1509.09308)
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace linalg {

namespace {

// clang-format off
/// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
/// result. The formula of the minimal 2D filtering algorithm F(m x m, r x r),
/// where m is the output dimension and r is the filter dimension, is
///
/// Y = A^T x [ (G x g x G^T) (*) (B^T x d x B) ] x A
///
/// where g is the filter, d is the input data, and (*) denotes element-wise
/// multiplication. We need to prepare 6 constant transformation matrices,
/// G, G^T, B^T, B, A^T, and A for this formula.
///
/// The following tables define these constant transformation matrices for
/// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5).
///
/// To add more transformation matrices, we need to add the following
/// items:
/// 1. Add the constant transformation matrix to the corresponding
///    G, GT, BT, B, AT, or A array.
/// 2. Add the corresponding TransformMatrix to the GMatrices, GTMatrices,
///    BTMatrices, BMatrices, ATMatrices, or AMatrices map.
/// 3. Add an enum value F_m_r to the WinogradConv2DFmr enum.
///
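/// As a dimension check for F(2 x 2, 3 x 3): the input tile d is 4 x 4
/// (alpha = m + r - 1 = 4), the filter g is 3 x 3, G x g x G^T and
/// B^T x d x B are both 4 x 4, their element-wise product is 4 x 4, and
/// A^T x [.] x A reduces it to the 2 x 2 output tile Y.
///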
constexpr float G_2x2_3x3[] = {
  -1, 0, 0,
  1./2, -1./2, 1./2,
  1./2, 1./2, 1./2,
  0, 0, 1
};

constexpr float GT_2x2_3x3[] = {
  -1, 1./2, 1./2, 0,
  0, -1./2, 1./2, 0,
  0, 1./2, 1./2, 1
};

constexpr float BT_2x2_3x3[] = {
  -1, 0, 1, 0,
  0, -1, 1, 0,
  0, 1, 1, 0,
  0, -1, 0, 1
};

constexpr float B_2x2_3x3[] = {
  -1, 0, 0, 0,
  0, -1, 1, -1,
  1, 1, 1, 0,
  0, 0, 0, 1
};

constexpr float AT_2x2_3x3[] = {
  1, 1, 1, 0,
  0, -1, 1, 1
};

constexpr float A_2x2_3x3[] = {
  1, 0,
  1, -1,
  1, 1,
  0, 1
};

constexpr float G_4x4_3x3[] = {
  1, 0, 0,
  -1./3, 1./3, -1./3,
  -1./3, -1./3, -1./3,
  1./12, -1./6, 1./3,
  1./12, 1./6, 1./3,
  0, 0, 1
};

constexpr float GT_4x4_3x3[] = {
  1, -1./3, -1./3, 1./12, 1./12, 0,
  0, 1./3, -1./3, -1./6, 1./6, 0,
  0, -1./3, -1./3, 1./3, 1./3, 1
};

constexpr float BT_4x4_3x3[] = {
  1./4, 0, -5./16, 0, 1./16, 0,
  0, 1./4, -1./4, -1./16, 1./16, 0,
  0, -1./4, -1./4, 1./16, 1./16, 0,
  0, 1./4, -1./8, -1./4, 1./8, 0,
  0, -1./4, -1./8, 1./4, 1./8, 0,
  0, 1./4, 0, -5./16, 0, 1./16
};

constexpr float B_4x4_3x3[] = {
  1./4, 0, 0, 0, 0, 0,
  0, 1./4, -1./4, 1./4, -1./4, 1./4,
  -5./16, -1./4, -1./4, -1./8, -1./8, 0,
  0, -1./16, 1./16, -1./4, 1./4, -5./16,
  1./16, 1./16, 1./16, 1./8, 1./8, 0,
  0, 0, 0, 0, 0, 1./16
};

constexpr float AT_4x4_3x3[] = {
  1./8, 1./4, 1./4, 1./8, 1./8, 0,
  0, -1./4, 1./4, -1./4, 1./4, 0,
  0, 1./4, 1./4, 1./2, 1./2, 0,
  0, -1./4, 1./4, -1, 1, 1./2
};

constexpr float A_4x4_3x3[] = {
  1./8, 0, 0, 0,
  1./4, -1./4, 1./4, -1./4,
  1./4, 1./4, 1./4, 1./4,
  1./8, -1./4, 1./2, -1,
  1./8, 1./4, 1./2, 1,
  0, 0, 0, 1./2
};

constexpr float G_2x2_5x5[] = {
  1, 0, 0, 0, 0,
  1./6, -1./6, 1./6, -1./6, 1./6,
  -1./6, -1./6, -1./6, -1./6, -1./6,
  -4./15, 2./15, -1./15, 1./30, -1./60,
  1./60, 1./30, 1./15, 2./15, 4./15,
  0, 0, 0, 0, 1
};

constexpr float GT_2x2_5x5[] = {
  1, 1./6, -1./6, -4./15, 1./60, 0,
  0, -1./6, -1./6, 2./15, 1./30, 0,
  0, 1./6, -1./6, -1./15, 1./15, 0,
  0, -1./6, -1./6, 1./30, 2./15, 0,
  0, 1./6, -1./6, -1./60, 4./15, 1
};

constexpr float BT_2x2_5x5[] = {
  1./8, 3./16, -1./4, -3./16, 1./8, 0,
  0, 1./8, 1./16, -5./16, 1./8, 0,
  0, -1./8, -5./16, -1./16, 1./8, 0,
  0, 1./4, -1./8, -1./4, 1./8, 0,
  0, -1./8, -1./4, 1./8, 1./4, 0,
  0, 1./8, 3./16, -1./4, -3./16, 1./8
};

constexpr float B_2x2_5x5[] = {
  1./8, 0, 0, 0, 0, 0,
  3./16, 1./8, -1./8, 1./4, -1./8, 1./8,
  -1./4, 1./16, -5./16, -1./8, -1./4, 3./16,
  -3./16, -5./16, -1./16, -1./4, 1./8, -1./4,
  1./8, 1./8, 1./8, 1./8, 1./4, -3./16,
  0, 0, 0, 0, 0, 1./8
};

constexpr float AT_2x2_5x5[] = {
  1./2, 1, 1, 2, 1, 0,
  0, -1, 1, -1, 2, 1./2
};

constexpr float A_2x2_5x5[] = {
  1./2, 0,
  1, -1,
  1, 1,
  2, -1,
  1, 2,
  0, 1./2
};
// clang-format on

/// Structure to keep information of constant transform matrices.
struct TransformMatrix {
  TransformMatrix(ArrayRef<float> table, int64_t rows, int64_t cols,
                  int64_t scalarFactor = 1)
      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}

  ArrayRef<float> table;
  int64_t rows;
  int64_t cols;
  int64_t scalarFactor;
};

/// Utility function to convert a constant array to an arith.constant Value.
Value create2DTransformMatrix(OpBuilder &builder, Location loc,
                              TransformMatrix transform, Type type) {
  assert(transform.table.size() ==
         static_cast<size_t>(transform.rows * transform.cols));
  assert(type.isFloat() && "Only floats are supported by Winograd");
  ArrayRef<float> constVec(transform.table.data(),
                           transform.rows * transform.cols);
  auto constAttrVec =
      llvm::map_to_vector<>(constVec, [&](const float v) -> Attribute {
        return builder.getFloatAttr(type, v);
      });
  SmallVector<int64_t, 2> shape{transform.rows, transform.cols};
  return arith::ConstantOp::create(
      builder, loc,
      DenseFPElementsAttr::get(RankedTensorType::get(shape, type),
                               constAttrVec));
}
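
// For instance, for the G matrix of F(2 x 2, 3 x 3) with an f32 element type,
// create2DTransformMatrix materializes IR along the lines of:
//
//   %cst = arith.constant dense<[[-1.0, 0.0, 0.0],
//                                [0.5, -0.5, 0.5],
//                                [0.5, 0.5, 0.5],
//                                [0.0, 0.0, 1.0]]> : tensor<4x3xf32>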

/// Extract height x width data from 4D tensors.
Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
                          Value loopNorFIndex, Value loopCorFIndex,
                          Value heightOffset, Value widthOffset,
                          int64_t extractHeight, int64_t extractWidth,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto sourceType = cast<ShapedType>(source.getType());
  Type elementType = sourceType.getElementType();
  int64_t srcSize = sourceType.getRank();

  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets;
  offsets.resize(srcSize);
  offsets[loopNorFIdx] = loopNorFIndex;
  offsets[loopCorFIdx] = loopCorFIndex;
  offsets[heightIdx] = heightOffset;
  offsets[widthIdx] = widthOffset;
  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
  sizes[heightIdx] = builder.getIndexAttr(extractHeight);
  sizes[widthIdx] = builder.getIndexAttr(extractWidth);
  SmallVector<OpFoldResult> strides(srcSize, oneIndex);

  auto extractFilterType =
      RankedTensorType::get({extractHeight, extractWidth}, elementType);
  auto extractFilterOp = tensor::ExtractSliceOp::create(
      builder, loc, extractFilterType, source, offsets, sizes, strides);

  return extractFilterOp;
}
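
// For the FHWC filter case (loopNorFIdx = 0, loopCorFIdx = 3, heightIdx = 1,
// widthIdx = 2), this produces a rank-reducing slice along the lines of:
//
//   %slice = tensor.extract_slice %src[%f, 0, 0, %c] [1, H, W, 1] [1, 1, 1, 1]
//       : tensor<FxHxWxCxf32> to tensor<HxWxf32>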

/// Extract height x width data from 6D tensors.
Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
                          Value tileHIndex, Value tileWIndex,
                          Value loopNorFIndex, Value loopCorFIndex,
                          int64_t tileHIdx, int64_t tileWIdx,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto sourceType = cast<ShapedType>(source.getType());
  Type elementType = sourceType.getElementType();
  auto sourceShape = sourceType.getShape();
  int64_t srcSize = sourceType.getRank();
  int64_t height = sourceShape[heightIdx];
  int64_t width = sourceShape[widthIdx];

  auto zeroIndex = builder.getIndexAttr(0);
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets(srcSize, zeroIndex);
  offsets[tileHIdx] = tileHIndex;
  offsets[tileWIdx] = tileWIndex;
  offsets[loopNorFIdx] = loopNorFIndex;
  offsets[loopCorFIdx] = loopCorFIndex;
  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
  sizes[heightIdx] = builder.getIndexAttr(height);
  sizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(srcSize, oneIndex);

  auto extractFilterType = RankedTensorType::get({height, width}, elementType);
  auto extractFilterOp = tensor::ExtractSliceOp::create(
      builder, loc, extractFilterType, source, offsets, sizes, strides);

  return extractFilterOp;
}

/// Insert transformed height x width data into the 4D tensor from which it
/// was extracted.
Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value loopNorFIndex, Value loopCorFIndex,
                       Value heightOffset, Value widthOffset, int64_t height,
                       int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> retOffsets;
  retOffsets.resize(destSize);
  retOffsets[loopNorFIdx] = loopNorFIndex;
  retOffsets[loopCorFIdx] = loopCorFIndex;
  retOffsets[heightIdx] = heightOffset;
  retOffsets[widthIdx] = widthOffset;
  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
  retSizes[heightIdx] = builder.getIndexAttr(height);
  retSizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(destSize, oneIndex);

  auto insertSliceOp = tensor::InsertSliceOp::create(
      builder, loc, source, dest, retOffsets, retSizes, strides);

  return insertSliceOp;
}
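
// In the filter-transform case (loopNorFIdx = 3, loopCorFIdx = 2,
// heightIdx = 0, widthIdx = 1), the insertion back is roughly:
//
//   %r = tensor.insert_slice %tile into %dest[0, 0, %c, %f] [H, W, 1, 1]
//       [1, 1, 1, 1] : tensor<HxWxf32> into tensor<HxWxCxFxf32>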

/// Insert transformed height x width data into the 6D tensor from which it
/// was extracted.
Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value tileHIndex, Value tileWIndex,
                       Value loopNorFIndex, Value loopCorFIndex, int64_t height,
                       int64_t width, int64_t tileHIdx, int64_t tileWIdx,
                       int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
  auto zeroIndex = builder.getIndexAttr(0);
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex);
  retOffsets[tileHIdx] = tileHIndex;
  retOffsets[tileWIdx] = tileWIndex;
  retOffsets[loopNorFIdx] = loopNorFIndex;
  retOffsets[loopCorFIdx] = loopCorFIndex;
  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
  retSizes[heightIdx] = builder.getIndexAttr(height);
  retSizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(destSize, oneIndex);

  auto insertSliceOp = tensor::InsertSliceOp::create(
      builder, loc, source, dest, retOffsets, retSizes, strides);

  return insertSliceOp;
}

/// This function transforms the filter. The data layout of the filter is FHWC.
/// The transformation matrix is two-dimensional, so we first extract an
/// H x W slice from the FHWC filter, generating 2 levels of loops to iterate
/// over F and C. After the transformation, we get
///
/// scf.for %f = lo_f to hi_f step 1
///   scf.for %c = lo_c to hi_c step 1
///     %extracted = extract filter<h x w> from filter<f x h x w x c>
///     %ret = linalg.matmul G, %extracted
///     %ret = linalg.matmul %ret, GT
///     %inserted = insert %ret into filter<h x w x c x f>
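///
/// For F(2 x 2, 3 x 3), for example, G is 4 x 3 and GT is 3 x 4, so each
/// 3 x 3 filter slice g becomes a 4 x 4 tile G x g x GT before it is
/// inserted back.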
Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
                      Value retValue, WinogradConv2DFmr fmr,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to G transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      GMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
      };

  // Map from (m, r) to GT transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      GTMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
      };

  auto filterType = cast<ShapedType>(filter.getType());
  Type elementType = filterType.getElementType();
  auto filterShape = filterType.getShape(); // F, H, W, C
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];

  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  if (filterH != r && filterH != 1)
    return Value();
  if (filterW != r && filterW != 1)
    return Value();

  Value zeroIdx = arith::ConstantIndexOp::create(rewriter, loc, 0);
  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value FIter = ivs[0];
    Value CIter = ivs[1];

    // Extract (H, W) from (F, H, W, C).
    auto extractFilter =
        extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
                            zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    int64_t retRows = 1;
    Value matmulRetValue = extractFilter;
    Value zero = arith::ConstantOp::create(builder, loc,
                                           rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get constant transform matrix G.
      auto it = GMatrices.find(fmr);
      if (it == GMatrices.end())
        return {};
      const TransformMatrix &GMatrix = it->second;

      retRows = GMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
      auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
                                           elementType)
                       .getResult();
      auto init =
          linalg::FillOp::create(builder, loc, zero, empty).getResult(0);

      Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
      // Multiply G x g.
      auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                               ValueRange{G, extractFilter},
                                               ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get constant transform matrix GT.
      auto it = GTMatrices.find(fmr);
      if (it == GTMatrices.end())
        return {};
      const TransformMatrix &GTMatrix = it->second;

      auto matmulType =
          RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
      auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
                                           elementType)
                       .getResult();
      auto init =
          linalg::FillOp::create(builder, loc, zero, empty).getResult(0);

      Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
      // Multiply u = (G x g) x GT.
      auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                               ValueRange{matmulRetValue, GT},
                                               ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) to (H, W, C, F).
    int64_t retHeight = leftTransform ? m + r - 1 : 1;
    int64_t retWidth = rightTransform ? m + r - 1 : 1;

    auto insertSliceOp =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
                         zeroIdx, zeroIdx, retHeight, retWidth,
                         /*loopNorFIdx=*/3, /*loopCorFIdx=*/2,
                         /*heightIdx=*/0, /*widthIdx=*/1);

    return {insertSliceOp};
  };

  auto fUpperBound = arith::ConstantIndexOp::create(rewriter, loc, filterF);
  auto cUpperBound = arith::ConstantIndexOp::create(rewriter, loc, filterC);
  auto oneStep = arith::ConstantIndexOp::create(rewriter, loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
      {oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}

/// This function transforms the input. The data layout of the input is NHWC.
/// The transformation matrix is two-dimensional, so we first extract an
/// alphaH x alphaW tile from the NHWC input, generating 4 levels of loops to
/// iterate over tileH, tileW, N, and C. After the transformation, we get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %c = 0 to C step 1
///         %extracted = extract <alphaH x alphaW> from
///                      %input<N x H x W x C>
///                      at [%n, (%h x m), (%w x m), %c]
///         %ret = linalg.matmul BT, %extracted
///         %ret = linalg.matmul %ret, B
///         %inserted = insert %ret<alphaH x alphaW> into
///                     %output<alphaH x alphaW x tileH x tileW x N x C>
///                     at [0, 0, %h, %w, %n, %c]
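///
/// Note that tiles are read at stride m along H and W while each tile is
/// alphaH x alphaW = (m + r - 1) wide, so neighboring input tiles overlap
/// by r - 1 rows and columns.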
Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
                     Value retValue, WinogradConv2DFmr fmr,
                     bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to BT transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      BTMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
      };

  // Map from (m, r) to B transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      BMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
      };

  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  auto inputType = cast<ShapedType>(input.getType());
  Type elementType = inputType.getElementType();
  auto inputShape = inputType.getShape(); // N, H, W, C
  int64_t inputN = inputShape[0];
  int64_t inputC = inputShape[3];
  auto valueType = cast<ShapedType>(retValue.getType());
  auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C
  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value CIter = ivs[3];

    auto *context = builder.getContext();

    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    Value heightOffset = affine::AffineApplyOp::create(
        builder, loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
    Value widthOffset = affine::AffineApplyOp::create(
        builder, loc, rightTransform ? affineMap : identityAffineMap,
        tileWIter);

    // Extract (H, W) from (N, H, W, C).
    auto extractInput =
        extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset,
                            widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    int64_t retRows = 1;
    int64_t retCols = 1;
    Value matmulRetValue = extractInput;
    Value zero = arith::ConstantOp::create(builder, loc,
                                           rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get constant transform matrix BT.
      auto it = BTMatrices.find(fmr);
      if (it == BTMatrices.end())
        return {};
      const TransformMatrix &BTMatrix = it->second;

      retRows = BTMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
      auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
                                           elementType)
                       .getResult();
      auto init =
          linalg::FillOp::create(builder, loc, zero, empty).getResult(0);

      Value BT = create2DTransformMatrix(builder, loc, BTMatrix, elementType);
      // Multiply BT x d.
      auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                               ValueRange{BT, matmulRetValue},
                                               ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get constant transform matrix B.
      auto it = BMatrices.find(fmr);
      if (it == BMatrices.end())
        return {};
      const TransformMatrix &BMatrix = it->second;

      retCols = BMatrix.cols;
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
                                           elementType)
                       .getResult();
      auto init =
          linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
      Value B = create2DTransformMatrix(builder, loc, BMatrix, elementType);
      // Multiply v = (BT x d) x B.
      auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                               ValueRange{matmulRetValue, B},
                                               ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) to (H, W, tileH, tileW, N, C).
    auto combinedVal = insert2DDataTo6D(
        builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter,
        CIter, retRows, retCols, 2, 3, /*loopNorFIdx=*/4, /*loopCorFIdx=*/5,
        /*heightIdx=*/0, /*widthIdx=*/1);

    return {combinedVal};
  };

  auto zeroIdx = arith::ConstantIndexOp::create(rewriter, loc, 0);
  auto tileHBound = arith::ConstantIndexOp::create(rewriter, loc, tileH);
  auto tileWBound = arith::ConstantIndexOp::create(rewriter, loc, tileW);
  auto nUpperBound = arith::ConstantIndexOp::create(rewriter, loc, inputN);
  auto cUpperBound = arith::ConstantIndexOp::create(rewriter, loc, inputC);
  auto oneStep = arith::ConstantIndexOp::create(rewriter, loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, cUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}

/// This function generates linalg.batch_matmul to multiply input with filter.
/// linalg.batch_matmul only supports 3-dimensional inputs, so we treat the
/// alphaH x alphaW dimensions as the batch and flatten the remaining tile and
/// batch dimensions. That is, we convert [alphaH, alphaW, tileH, tileW, N, C]
/// to [alphaH x alphaW, tileH x tileW x N, C]. In this way, we can convert
/// the 6-dimensional inputs to a 3-dimensional representation that is
/// suitable for linalg.batch_matmul.
///
/// The batched matmul does the matrix multiplication with the reduction on
/// the channel dimension.
///
/// We get
///
/// %collapsed_input = tensor.collapse_shape %input
/// %collapsed_filter = tensor.collapse_shape %filter
/// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
/// %expanded_ret = tensor.expand_shape %ret
///
/// After this function, we get the return value with data layout
/// (alphaH, alphaW, tileH, tileW, N, F).
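///
/// For example, with F(4 x 4, 3 x 3) (alphaH = alphaW = 6), an input of shape
/// (6, 6, tileH, tileW, N, C) collapses to (36, tileH x tileW x N, C), the
/// filter of shape (6, 6, C, F) collapses to (36, C, F), and the batched
/// matmul produces (36, tileH x tileW x N, F).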
static Value matrixMultiply(RewriterBase &rewriter, Location loc,
                            Value transformedFilter, Value transformedInput,
                            Type outputElementType) {
  // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter.
  auto filterType = cast<ShapedType>(transformedFilter.getType());
  assert(filterType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> filterShape = filterType.getShape();
  Type filterElementType = filterType.getElementType();
  auto filterReassocType = RankedTensorType::get(
      {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
      filterElementType);
  SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
  Value collapseFilter = tensor::CollapseShapeOp::create(
      rewriter, loc, filterReassocType, transformedFilter, filterReassoc);

  // Convert (alphaH, alphaW, tileH, tileW, N, C) to
  // (alphaH x alphaW, tileH x tileW x N, C) for input.
  auto inputType = cast<ShapedType>(transformedInput.getType());
  assert(inputType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> inputShape = inputType.getShape();
  Type inputElementType = inputType.getElementType();
  auto inputReassocType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
      inputElementType);
  SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  Value collapseInput = tensor::CollapseShapeOp::create(
      rewriter, loc, inputReassocType, transformedInput, inputReassoc);

  // Batched matrix multiply.
  auto matmulType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
      outputElementType);
  Value empty = tensor::EmptyOp::create(rewriter, loc, matmulType.getShape(),
                                        outputElementType)
                    .getResult();
  Value zero = arith::ConstantOp::create(
      rewriter, loc, rewriter.getZeroAttr(outputElementType));
  Value init = linalg::FillOp::create(rewriter, loc, zero, empty).getResult(0);

  auto matmulOp = linalg::BatchMatmulOp::create(
      rewriter, loc, matmulType, ValueRange({collapseInput, collapseFilter}),
      ValueRange{init});

  // The result shape of batch matmul is (alphaH x alphaW, tileH x tileW x N, F).
  // Expand matmul result to (alphaH, alphaW, tileH, tileW, N, F).
  SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  auto outputReassocType =
      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
                             inputShape[3], inputShape[4], filterShape[3]},
                            outputElementType);
  auto expandOutput = tensor::ExpandShapeOp::create(
      rewriter, loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
  return expandOutput;
}

/// This function transforms the output. The data layout of the input value is
/// (alphaH, alphaW, tileH, tileW, N, F). The transformation matrix is
/// two-dimensional, so we first extract an alphaH x alphaW tile from the
/// value, generating 4 levels of loops to iterate over tileH, tileW, N, and
/// F. After the transformation, we get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %f = 0 to F step 1
///         %extracted = extract <alphaH x alphaW> from
///                      %input<alphaH x alphaW x tileH x tileW x N x F>
///                      at [0, 0, %h, %w, %n, %f]
///         %ret = linalg.matmul AT, %extracted
///         %ret = linalg.matmul %ret, A
///         %inserted = insert %ret<m x m> into
///                     output<N x H x W x F>
///                     at [%n, (%h x m), (%w x m), %f]
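///
/// For F(4 x 4, 3 x 3) and F(2 x 2, 5 x 5), the AT/A tables above are stored
/// pre-scaled, so the transformed tile is multiplied back by scalarFactor
/// (e.g. 32 x 32 = 1024 for F(4 x 4, 3 x 3)) while the initial output value
/// is added, using a linalg.generic over the m x m tile.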
Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
                      Value output, WinogradConv2DFmr fmr,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to AT transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      ATMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
      };

  // Map from (m, r) to A transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      AMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
      };

  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  auto valueType = cast<ShapedType>(value.getType());
  Type elementType = valueType.getElementType();
  auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];
  int64_t valueN = valueShape[4];
  int64_t valueF = valueShape[5];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  if (valueH != alphaH && valueH != 1)
    return Value();
  if (valueW != alphaW && valueW != 1)
    return Value();

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    auto *context = builder.getContext();
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value FIter = ivs[3];

    // Extract (H, W) from (H, W, tileH, tileW, N, F).
    auto extractValue =
        extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter,
                            FIter, 2, 3, /*loopNorFIdx=*/4,
                            /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1);

    const TransformMatrix &AMatrix = AMatrices.at(fmr);
    const TransformMatrix &ATMatrix = ATMatrices.at(fmr);
    int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) *
                           (leftTransform ? ATMatrix.scalarFactor : 1);
    int64_t retCols = rightTransform ? AMatrix.cols : 1;
    int64_t retRows = leftTransform ? ATMatrix.rows : 1;

    Value matmulRetValue = extractValue;
    Value zero = arith::ConstantOp::create(builder, loc,
                                           rewriter.getZeroAttr(elementType));

    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    Value heightOffset = affine::AffineApplyOp::create(
        builder, loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
    Value widthOffset = affine::AffineApplyOp::create(
        builder, loc, rightTransform ? affineMap : identityAffineMap,
        tileWIter);

    Value outInitVal =
        extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
                            widthOffset, retRows, retCols,
                            /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1,
                            /*widthIdx=*/2);
    if (leftTransform) {
      auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
      Value init = outInitVal;
      if (rightTransform || scalarFactor != 1) {
        auto empty = tensor::EmptyOp::create(builder, loc,
                                             matmulType.getShape(), elementType)
                         .getResult();
        init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
      }

      Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
      // Multiply AT x m.
      auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                               ValueRange{AT, matmulRetValue},
                                               ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      auto matmulType =
          RankedTensorType::get({retRows, AMatrix.cols}, elementType);
      Value init = outInitVal;
      if (scalarFactor != 1) {
        auto empty = tensor::EmptyOp::create(builder, loc,
                                             matmulType.getShape(), elementType)
                         .getResult();
        init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
      }

      Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
      // Multiply y = (AT x m) x A.
      auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                               ValueRange{matmulRetValue, A},
                                               ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (scalarFactor != 1) {
      // Multiply by scalar factor and add outInitVal.
      Value scalarFactorValue = arith::ConstantOp::create(
          builder, loc, FloatAttr::get(elementType, scalarFactor));
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
      SmallVector<AffineMap> affineMaps = {
          AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap};

      matmulRetValue =
          linalg::GenericOp::create(
              rewriter, loc, matmulType,
              ValueRange{scalarFactorValue, matmulRetValue},
              ValueRange{outInitVal}, affineMaps,
              llvm::ArrayRef<utils::IteratorType>{
                  utils::IteratorType::parallel, utils::IteratorType::parallel},
              [&](OpBuilder &nestedBuilder, Location nestedLoc,
                  ValueRange args) {
                auto mulf = arith::MulFOp::create(nestedBuilder, nestedLoc,
                                                  args[0], args[1]);
                auto addf = arith::AddFOp::create(nestedBuilder, nestedLoc,
                                                  mulf.getResult(), args[2]);
                linalg::YieldOp::create(nestedBuilder, nestedLoc,
                                        addf.getResult());
              })
              .getResult(0);
    }

    // Insert (H, W) to (N, H, W, F).
    Value combinedVal =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
                         heightOffset, widthOffset, retRows, retCols,
                         /*loopNorFIdx=*/0,
                         /*loopCorFIdx=*/3, /*heightIdx=*/1,
                         /*widthIdx=*/2);

    return {combinedVal};
  };

  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  auto zeroIdx = arith::ConstantIndexOp::create(rewriter, loc, 0);
  auto tileHBound = arith::ConstantIndexOp::create(rewriter, loc, tileH);
  auto tileWBound = arith::ConstantIndexOp::create(rewriter, loc, tileW);
  auto nUpperBound = arith::ConstantIndexOp::create(rewriter, loc, valueN);
  auto fUpperBound = arith::ConstantIndexOp::create(rewriter, loc, valueF);
  auto oneStep = arith::ConstantIndexOp::create(rewriter, loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, fUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody);
  return loops.results[0];
}

/// Pad the value with zeros to create a tensor with the aligned shape.
static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
                                Value value, ArrayRef<int64_t> alignedShape) {
  auto valueType = cast<ShapedType>(value.getType());
  Type elementType = valueType.getElementType();
  auto alignedType = RankedTensorType::get(alignedShape, elementType);
  Value padValue = arith::ConstantOp::create(rewriter, loc, elementType,
                                             rewriter.getZeroAttr(elementType));

  return linalg::makeComposedPadHighOp(rewriter, loc, alignedType, value,
                                       padValue, false);
}

/// Extract sub-tensor with extractedType from value.
static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
                                      Value value,
                                      RankedTensorType extractedType) {
  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
  OpFoldResult oneIndex = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
  SmallVector<OpFoldResult, 4> strides(4, oneIndex);

  ArrayRef<int64_t> extractedShape = extractedType.getShape();
  SmallVector<OpFoldResult> sizes =
      getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape));

  return tensor::ExtractSliceOp::create(rewriter, loc, extractedType, value,
                                        offsets, sizes, strides);
}

/// Utility function to check that all values in the attribute are 1.
static bool hasAllOneValues(DenseIntElementsAttr attr) {
  return llvm::all_of(
      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
}

/// A helper function to convert linalg.conv_2d_nhwc_fhwc to
/// linalg.winograd_*_transform ops.
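///
/// At a high level, the emitted sequence is (shapes illustrative):
///
///   %tf = linalg.winograd_filter_transform %filter
///       : filter (F, r, r, C) -> (alphaH, alphaW, C, F)
///   %ti = linalg.winograd_input_transform %input
///       : input (N, H, W, C) -> (alphaH, alphaW, tileH, tileW, N, C)
///   %m  = linalg.batch_matmul on the collapsed %ti and %tf
///   %o  = linalg.winograd_output_transform %m into the conv output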
static FailureOr<Operation *>
winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
                     WinogradConv2DFmr fmr) {
  if (!convOp.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(
        convOp, "expected pure tensor semantics for linalg.conv_2d_nhwc_fhwc");

  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];
  auto inputType = cast<ShapedType>(input.getType());
  auto filterType = cast<ShapedType>(filter.getType());
  auto outputType = cast<ShapedType>(output.getType());

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  if (!hasAllOneValues(convOp.getStrides()))
    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");

  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputN = inputShape[0];
  int64_t inputH = inputShape[1];
  int64_t inputW = inputShape[2];
  int64_t inputC = inputShape[3];
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t outputN = outputShape[0];
  int64_t outputH = outputShape[1];
  int64_t outputW = outputShape[2];
  int64_t outputF = outputShape[3];

  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
  bool isSupportedFilter = false;
  if (filterH == filterW && filterH == r)
    isSupportedFilter = true;
  if (filterH == r && filterW == 1)
    isSupportedFilter = true;
  if (filterH == 1 && filterW == r)
    isSupportedFilter = true;

  if (!isSupportedFilter)
    return rewriter.notifyMatchFailure(
        convOp, "only support filter (r x r), (r x 1) or (1 x r)");

  // All the criteria are satisfied. We can do Winograd Conv2D.
  Location loc = convOp.getLoc();

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = filterH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = filterW != 1;
  int64_t heightM = leftTransform ? m : 1;
  int64_t widthM = rightTransform ? m : 1;
  int64_t heightR = leftTransform ? r : 1;
  int64_t widthR = rightTransform ? r : 1;

  // --- Create operation for filter transform ---
  Type filterElementType = filterType.getElementType();
  int64_t alphaH = heightM + heightR - 1;
  int64_t alphaW = widthM + widthR - 1;
  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
  auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
                                       filterElementType);
  Value retValue = tensor::EmptyOp::create(rewriter, loc, retType.getShape(),
                                           filterElementType);
  auto transformedFilter = linalg::WinogradFilterTransformOp::create(
      rewriter, loc, retType, filter, retValue, fmr);

  // --- Create operation for input transform ---

  // When the input size minus (r - 1) is not aligned with the output tile
  // size m, we need to pad the input data so that full tiles can be formed
  // during tiling.
  Type inputElementType = inputType.getElementType();
  int64_t alignedInputH = tileH * heightM + (heightR - 1);
  int64_t alignedInputW = tileW * widthM + (widthR - 1);
  if (alignedInputH != inputH || alignedInputW != inputW) {
    input = padToAlignedTensor(rewriter, loc, input,
                               {inputN, alignedInputH, alignedInputW, inputC});
  }
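
  // For example, with F(4 x 4, 3 x 3) and outputH = 10: tileH = ceil(10 / 4)
  // = 3 and alignedInputH = 3 * 4 + 2 = 14, so an input with inputH = 12 is
  // padded up to 14 rows before tiling.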

  retType = RankedTensorType::get(
      {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
  retValue = tensor::EmptyOp::create(rewriter, loc, retType.getShape(),
                                     inputElementType);
  auto transformedInput = linalg::WinogradInputTransformOp::create(
      rewriter, loc, retType, input, retValue, fmr);

  Type outputElementType = outputType.getElementType();
  Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
                                   transformedInput, outputElementType);

  // --- Create operation for output transform ---

  // When the output size is not aligned with the output tile size, we need to
  // pad the output buffer so that full tiles can be inserted after tiling.
  int64_t alignedOutputH = tileH * heightM;
  int64_t alignedOutputW = tileW * widthM;
  bool isOutputUnaligned =
      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
  if (isOutputUnaligned) {
    auto alignedOutputType = RankedTensorType::get(
        {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
    output =
        padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape());
    outputType = alignedOutputType;
  }

  Value transformedOutput = linalg::WinogradOutputTransformOp::create(
      rewriter, loc, outputType, matmulRet, output, fmr);

  // When the output size is not aligned with the output tile size, extract
  // the valid region from the padded buffer.
  if (isOutputUnaligned) {
    transformedOutput = extractFromAlignedTensor(
        rewriter, loc, transformedOutput,
        RankedTensorType::get({outputN, outputH, outputW, outputF},
                              outputElementType));
  }

  rewriter.replaceOp(convOp, transformedOutput);

  return transformedOutput.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_filter_transform.
FailureOr<Operation *>
decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradFilterTransformOp op) {
  Location loc = op.getLoc();
  Value filter = op.getFilter();
  auto filterType = cast<ShapedType>(filter.getType());
  auto filterShape = filterType.getShape();
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];

  // For F(m x 1, r x 1), we only need to do left side transform.
  bool leftTransform = filterH != 1;
  // For F(1 x m, 1 x r), we only need to do right side transform.
  bool rightTransform = filterW != 1;
  Value transformedFilter =
      filterTransform(rewriter, loc, filter, op.getOutput(), op.getFmr(),
                      leftTransform, rightTransform);
  if (!transformedFilter)
    return failure();

  rewriter.replaceOp(op, transformedFilter);

  return transformedFilter.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_input_transform.
FailureOr<Operation *>
decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
                                      linalg::WinogradInputTransformOp op) {
  Location loc = op.getLoc();
  Value output = op.getOutput();
  auto outputType = cast<ShapedType>(output.getType());
  auto outputShape = outputType.getShape();

  int64_t outputH = outputShape[0];
  int64_t outputW = outputShape[1];

  // For F(m x 1, r x 1), we only need to do left side transform.
  bool leftTransform = outputH != 1;
  // For F(1 x m, 1 x r), we only need to do right side transform.
  bool rightTransform = outputW != 1;
  Value transformedInput =
      inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getFmr(),
                     leftTransform, rightTransform);
  if (!transformedInput)
    return failure();

  rewriter.replaceOp(op, transformedInput);

  return transformedInput.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_output_transform.
FailureOr<Operation *>
decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradOutputTransformOp op) {
  Location loc = op.getLoc();
  Value value = op.getValue();
  auto valueType = cast<ShapedType>(value.getType());
  auto valueShape = valueType.getShape();
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];

  // For F(m x 1, r x 1), we only need to do left side transform.
  bool leftTransform = valueH != 1;
  // For F(1 x m, 1 x r), we only need to do right side transform.
  bool rightTransform = valueW != 1;
  Value transformedOutput =
      outputTransform(rewriter, loc, value, op.getOutput(), op.getFmr(),
                      leftTransform, rightTransform);
  if (!transformedOutput)
    return failure();

  rewriter.replaceOp(op, transformedOutput);

  return transformedOutput.getDefiningOp();
}

/// A rewrite pattern to decompose linalg.winograd_filter_transform operations.
class DecomposeWinogradFilterTransform final
    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradFilterTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern to decompose linalg.winograd_input_transform operations.
class DecomposeWinogradInputTransform final
    : public OpRewritePattern<linalg::WinogradInputTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradInputTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern to decompose linalg.winograd_output_transform operations.
class DecomposeWinogradOutputTransform final
    : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradOutputTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern for the Winograd Conv2D algorithm.
class WinogradConv2DNhwcFhwc final
    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, WinogradConv2DFmr fmr)
      : OpRewritePattern(context), fmr(fmr) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(winogradConv2DHelper(rewriter, convOp, fmr)))
      return failure();

    return success();
  }

private:
  WinogradConv2DFmr fmr;
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
                                      linalg::Conv2DNhwcFhwcOp op,
                                      linalg::WinogradConv2DFmr fmr) {
  return winogradConv2DHelper(rewriter, op, fmr);
}

FailureOr<Operation *>
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradFilterTransformOp op) {
  return decomposeWinogradFilterTransformHelper(rewriter, op);
}

FailureOr<Operation *>
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
                                  linalg::WinogradInputTransformOp op) {
  return decomposeWinogradInputTransformHelper(rewriter, op);
}

FailureOr<Operation *>
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradOutputTransformOp op) {
  return decomposeWinogradOutputTransformHelper(rewriter, op);
}

void populateWinogradConv2DPatterns(RewritePatternSet &patterns,
                                    WinogradConv2DFmr fmr) {
  MLIRContext *context = patterns.getContext();
  // TODO: Support more Conv2D data layouts, e.g., conv_2d_nchw_fchw.
  patterns.insert<WinogradConv2DNhwcFhwc>(context, fmr);
}

void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns
      .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform,
              DecomposeWinogradOutputTransform>(context);
}
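
// A typical use of these populate functions, sketched with a hypothetical
// pass body (the exact rewrite-driver entry point may differ across MLIR
// versions):
//
//   RewritePatternSet patterns(ctx);
//   populateWinogradConv2DPatterns(patterns, WinogradConv2DFmr::F_4_3);
//   populateDecomposeWinogradOpsPatterns(patterns);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();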

} // end namespace linalg
} // end namespace mlir