//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implement Winograd Conv2D algorithm. The implementation is based on the
// paper: Fast Algorithms for Convolutional Neural Networks
// (https://arxiv.org/abs/1509.09308)
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace linalg {

namespace {

// clang-format off
/// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
/// result. The formula for the minimal 2D filtering algorithm F(m x m, r x r),
/// where m is the output dimension and r is the filter dimension, is
///
/// Y = A^T x [ (G x g x G^T) . (B^T x d x B) ] x A
///
/// where g is the filter, d is the input data, and "." denotes element-wise
/// multiplication. We need to prepare 6 constant transformation matrices,
/// G, G^T, B^T, B, A^T, and A for this formula.
///
/// The following tables define these constant transformation matrices for
/// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5).
constexpr float G_2x2_3x3[] = {
  -1, 0, 0,
  1./2, -1./2, 1./2,
  1./2, 1./2, 1./2,
  0, 0, 1
};

constexpr float GT_2x2_3x3[] = {
  -1, 1./2, 1./2, 0,
  0, -1./2, 1./2, 0,
  0, 1./2, 1./2, 1
};

constexpr float BT_2x2_3x3[] = {
  -1, 0, 1, 0,
  0, -1, 1, 0,
  0, 1, 1, 0,
  0, -1, 0, 1
};

constexpr float B_2x2_3x3[] = {
  -1, 0, 0, 0,
  0, -1, 1, -1,
  1, 1, 1, 0,
  0, 0, 0, 1
};

constexpr float AT_2x2_3x3[] = {
  1, 1, 1, 0,
  0, -1, 1, 1
};

constexpr float A_2x2_3x3[] = {
  1, 0,
  1, -1,
  1, 1,
  0, 1
};

constexpr float G_4x4_3x3[] = {
  1, 0, 0,
  -1./3, 1./3, -1./3,
  -1./3, -1./3, -1./3,
  1./12, -1./6, 1./3,
  1./12, 1./6, 1./3,
  0, 0, 1
};

constexpr float GT_4x4_3x3[] = {
  1, -1./3, -1./3, 1./12, 1./12, 0,
  0, 1./3, -1./3, -1./6, 1./6, 0,
  0, -1./3, -1./3, 1./3, 1./3, 1
};

constexpr float BT_4x4_3x3[] = {
  1./4, 0, -5./16, 0, 1./16, 0,
  0, 1./4, -1./4, -1./16, 1./16, 0,
  0, -1./4, -1./4, 1./16, 1./16, 0,
  0, 1./4, -1./8, -1./4, 1./8, 0,
  0, -1./4, -1./8, 1./4, 1./8, 0,
  0, 1./4, 0, -5./16, 0, 1./16
};

constexpr float B_4x4_3x3[] = {
  1./4, 0, 0, 0, 0, 0,
  0, 1./4, -1./4, 1./4, -1./4, 1./4,
  -5./16, -1./4, -1./4, -1./8, -1./8, 0,
  0, -1./16, 1./16, -1./4, 1./4, -5./16,
  1./16, 1./16, 1./16, 1./8, 1./8, 0,
  0, 0, 0, 0, 0, 1./16
};

constexpr float AT_4x4_3x3[] = {
  1./8, 1./4, 1./4, 1./8, 1./8, 0,
  0, -1./4, 1./4, -1./4, 1./4, 0,
  0, 1./4, 1./4, 1./2, 1./2, 0,
  0, -1./4, 1./4, -1, 1, 1./2
};

constexpr float A_4x4_3x3[] = {
  1./8, 0, 0, 0,
  1./4, -1./4, 1./4, -1./4,
  1./4, 1./4, 1./4, 1./4,
  1./8, -1./4, 1./2, -1,
  1./8, 1./4, 1./2, 1,
  0, 0, 0, 1./2
};

constexpr float G_2x2_5x5[] = {
  1, 0, 0, 0, 0,
  1./6, -1./6, 1./6, -1./6, 1./6,
  -1./6, -1./6, -1./6, -1./6, -1./6,
  -4./15, 2./15, -1./15, 1./30, -1./60,
  1./60, 1./30, 1./15, 2./15, 4./15,
  0, 0, 0, 0, 1
};

constexpr float GT_2x2_5x5[] = {
  1, 1./6, -1./6, -4./15, 1./60, 0,
  0, -1./6, -1./6, 2./15, 1./30, 0,
  0, 1./6, -1./6, -1./15, 1./15, 0,
  0, -1./6, -1./6, 1./30, 2./15, 0,
  0, 1./6, -1./6, -1./60, 4./15, 1
};

constexpr float BT_2x2_5x5[] = {
  1./8, 3./16, -1./4, -3./16, 1./8, 0,
  0, 1./8, 1./16, -5./16, 1./8, 0,
  0, -1./8, -5./16, -1./16, 1./8, 0,
  0, 1./4, -1./8, -1./4, 1./8, 0,
  0, -1./8, -1./4, 1./8, 1./4, 0,
  0, 1./8, 3./16, -1./4, -3./16, 1./8
};

constexpr float B_2x2_5x5[] = {
  1./8, 0, 0, 0, 0, 0,
  3./16, 1./8, -1./8, 1./4, -1./8, 1./8,
  -1./4, 1./16, -5./16, -1./8, -1./4, 3./16,
  -3./16, -5./16, -1./16, -1./4, 1./8, -1./4,
  1./8, 1./8, 1./8, 1./8, 1./4, -3./16,
  0, 0, 0, 0, 0, 1./8
};

constexpr float AT_2x2_5x5[] = {
  1./2, 1, 1, 2, 1, 0,
  0, -1, 1, -1, 2, 1./2
};

constexpr float A_2x2_5x5[] = {
  1./2, 0,
  1, -1,
  1, 1,
  2, -1,
  1, 2,
  0, 1./2
};
// clang-format on
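
/// A standalone numeric sanity check of the identity above for
/// F(2 x 2, 3 x 3) (an illustrative sketch only; matmulRM and
/// verifyWinogradF2x2_3x3 are hypothetical helpers invented for this example
/// and are not used by the pass). It compares
/// Y = A^T x [(G x g x G^T) . (B^T x d x B)] x A, with "." the element-wise
/// product, against a direct 2D convolution.
[[maybe_unused]] static void matmulRM(const float *a, const float *b, float *c,
                                      int m, int k, int n) {
  // Row-major (m x k) times (k x n) matrix product.
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      float acc = 0;
      for (int p = 0; p < k; ++p)
        acc += a[i * k + p] * b[p * n + j];
      c[i * n + j] = acc;
    }
}

[[maybe_unused]] static bool verifyWinogradF2x2_3x3() {
  const float g[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};   // 3x3 filter g.
  const float d[16] = {1, 0, 2, 1, 3, 1, 0, 2,
                       1, 2, 1, 0, 0, 1, 3, 1};     // 4x4 input tile d.
  float t[16], u[16], v[16], had[16], s[8], y[4];
  matmulRM(G_2x2_3x3, g, t, 4, 3, 3);    // G x g         : 4x3
  matmulRM(t, GT_2x2_3x3, u, 4, 3, 4);   // (G x g) x G^T : 4x4
  matmulRM(BT_2x2_3x3, d, t, 4, 4, 4);   // B^T x d       : 4x4
  matmulRM(t, B_2x2_3x3, v, 4, 4, 4);    // (B^T x d) x B : 4x4
  for (int i = 0; i < 16; ++i)           // Element-wise product.
    had[i] = u[i] * v[i];
  matmulRM(AT_2x2_3x3, had, s, 2, 4, 4); // A^T x M       : 2x4
  matmulRM(s, A_2x2_3x3, y, 2, 4, 2);    // (A^T x M) x A : 2x2
  // Direct 2D convolution (cross-correlation) as the reference.
  for (int h = 0; h < 2; ++h)
    for (int w = 0; w < 2; ++w) {
      float acc = 0;
      for (int i = 0; i < 3; ++i)
        for (int j = 0; j < 3; ++j)
          acc += g[i * 3 + j] * d[(h + i) * 4 + (w + j)];
      float diff = y[h * 2 + w] - acc;
      if (diff < -1e-3f || diff > 1e-3f)
        return false;
    }
  return true;
}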

using TransformMapKeyTy = std::pair<int, int>;

/// We use F(m, r) to define the size of minimal filtering algorithms.
/// m is the output dimension and r is the filter dimension. We can get
/// the input dimension, alpha, from the formula alpha = m + r - 1.
///
/// For example, when m = 2 and r = 3, the input size is 4. The Conv2D
/// will operate on 4x4 input data with a 3x3 filter and produce a 2x2
/// output result.
constexpr TransformMapKeyTy F_2_3{2, 3};
constexpr TransformMapKeyTy F_4_3{4, 3};
constexpr TransformMapKeyTy F_2_5{2, 5};

/// Structure to keep information of constant transform matrices.
struct TransformMatrix {
  TransformMatrix(const float *table, int64_t rows, int64_t cols,
                  int64_t scalarFactor = 1)
      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}

  const float *table;
  int64_t rows;
  int64_t cols;
  int64_t scalarFactor;
};

/// Utility function to convert a constant array to an arith.constant Value.
Value create2DTransformMatrix(OpBuilder &builder, Location loc,
                              TransformMatrix transform, Type type) {
  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);

  return builder.create<arith::ConstantOp>(
      loc, DenseFPElementsAttr::get(
               RankedTensorType::get(
                   SmallVector<int64_t>{transform.rows, transform.cols}, type),
               constVec));
}

/// Extract height x width data from 4D tensors.
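///
/// For example (illustrative, hypothetical shapes), extracting a 3 x 3 tile
/// from an FHWC filter at loop indices %f and %c emits roughly:
///
///   %slice = tensor.extract_slice %src[%f, 0, 0, %c] [1, 3, 3, 1]
///            [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<3x3xf32>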
Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
                          Value loopNorFIndex, Value loopCorFIndex,
                          Value heightOffset, Value widthOffset,
                          int64_t extractHeight, int64_t extractWidth,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto sourceType = cast<ShapedType>(source.getType());
  Type elementType = sourceType.getElementType();
  int64_t srcSize = sourceType.getRank();

  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets;
  offsets.resize(srcSize);
  offsets[loopNorFIdx] = loopNorFIndex;
  offsets[loopCorFIdx] = loopCorFIndex;
  offsets[heightIdx] = heightOffset;
  offsets[widthIdx] = widthOffset;
  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
  sizes[heightIdx] = builder.getIndexAttr(extractHeight);
  sizes[widthIdx] = builder.getIndexAttr(extractWidth);
  SmallVector<OpFoldResult> strides(srcSize, oneIndex);

  auto extractFilterType =
      RankedTensorType::get({extractHeight, extractWidth}, elementType);
  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
      loc, extractFilterType, source, offsets, sizes, strides);

  return extractFilterOp;
}

/// Extract height x width data from 6D tensors.
Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
                          Value tileHIndex, Value tileWIndex,
                          Value loopNorFIndex, Value loopCorFIndex,
                          int64_t tileHIdx, int64_t tileWIdx,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto sourceType = cast<ShapedType>(source.getType());
  Type elementType = sourceType.getElementType();
  auto sourceShape = sourceType.getShape();
  int64_t srcSize = sourceType.getRank();
  int64_t height = sourceShape[heightIdx];
  int64_t width = sourceShape[widthIdx];

  auto zeroIndex = builder.getIndexAttr(0);
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets(srcSize, zeroIndex);
  offsets.resize(srcSize);
  offsets[tileHIdx] = tileHIndex;
  offsets[tileWIdx] = tileWIndex;
  offsets[loopNorFIdx] = loopNorFIndex;
  offsets[loopCorFIdx] = loopCorFIndex;
  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
  sizes[heightIdx] = builder.getIndexAttr(height);
  sizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(srcSize, oneIndex);

  auto extractFilterType = RankedTensorType::get({height, width}, elementType);
  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
      loc, extractFilterType, source, offsets, sizes, strides);

  return extractFilterOp;
}

/// Insert transformed height x width data into the 4D tensor it was
/// extracted from.
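///
/// For example (illustrative, hypothetical shapes), inserting a transformed
/// 4 x 4 tile back at loop indices %f and %c (layout H, W, C, F) emits
/// roughly:
///
///   %r = tensor.insert_slice %tile into %dst[0, 0, %c, %f] [4, 4, 1, 1]
///        [1, 1, 1, 1] : tensor<4x4xf32> into tensor<4x4x5x2xf32>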
Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value loopNorFIndex, Value loopCorFIndex,
                       Value heightOffset, Value widthOffset, int64_t height,
                       int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> retOffsets;
  retOffsets.resize(destSize);
  retOffsets[loopNorFIdx] = loopNorFIndex;
  retOffsets[loopCorFIdx] = loopCorFIndex;
  retOffsets[heightIdx] = heightOffset;
  retOffsets[widthIdx] = widthOffset;
  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
  retSizes[heightIdx] = builder.getIndexAttr(height);
  retSizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(destSize, oneIndex);

  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
      loc, source, dest, retOffsets, retSizes, strides);

  return insertSliceOp;
}

/// Insert transformed height x width data into the 6D tensor it was
/// extracted from.
Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value tileHIndex, Value tileWIndex,
                       Value loopNorFIndex, Value loopCorFIndex, int64_t height,
                       int64_t width, int64_t tileHIdx, int64_t tileWIdx,
                       int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
  auto zeroIndex = builder.getIndexAttr(0);
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex);
  retOffsets.resize(destSize);
  retOffsets[tileHIdx] = tileHIndex;
  retOffsets[tileWIdx] = tileWIndex;
  retOffsets[loopNorFIdx] = loopNorFIndex;
  retOffsets[loopCorFIdx] = loopCorFIndex;
  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
  retSizes[heightIdx] = builder.getIndexAttr(height);
  retSizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(destSize, oneIndex);

  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
      loc, source, dest, retOffsets, retSizes, strides);

  return insertSliceOp;
}

/// This function transforms the filter. The data layout of the filter is
/// FHWC. The transformation matrix is 2-dimensional. We need to extract
/// H x W from FHWC first, and we generate 2 levels of loops to iterate on
/// F and C. After the transformation, we get
///
/// scf.for %f = lo_f to hi_f step 1
///   scf.for %c = lo_c to hi_c step 1
///     %extracted = extract filter<h x w> from filter<f x h x w x c>
///     %ret = linalg.matmul G, %extracted
///     %ret = linalg.matmul %ret, GT
///     %inserted = insert %ret into filter<h x w x c x f>
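///
/// For example (illustrative, hypothetical shapes), for F(2 x 2, 3 x 3) each
/// iteration takes a 3 x 3 tile g from a tensor<2x3x3x5xf32> filter and
/// computes
///
///   %t = linalg.matmul G, g    // (4 x 3) x (3 x 3) -> 4 x 3
///   %u = linalg.matmul %t, GT  // (4 x 3) x (3 x 4) -> 4 x 4
///
/// before inserting %u into the tensor<4x4x5x2xf32> result (H, W, C, F).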
Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
                      Value retValue, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to G transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      GMatrices = {
          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
      };

  // Map from (m, r) to GT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      GTMatrices = {
          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
      };

  auto filterType = cast<ShapedType>(filter.getType());
  Type elementType = filterType.getElementType();
  auto filterShape = filterType.getShape(); // F, H, W, C
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];

  if (filterH != r && filterH != 1)
    return Value();
  if (filterW != r && filterW != 1)
    return Value();

  Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value FIter = ivs[0];
    Value CIter = ivs[1];

    // Extract (H, W) from (F, H, W, C).
    auto extractFilter =
        extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
                            zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    TransformMapKeyTy key = {m, r};
    int64_t retRows = 1;
    Value matmulRetValue = extractFilter;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get constant transform matrix G.
      auto it = GMatrices.find(key);
      if (it == GMatrices.end())
        return {};
      const TransformMatrix &GMatrix = it->second;

      retRows = GMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
      // Multiply G x g.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get constant transform matrix GT.
      auto it = GTMatrices.find(key);
      if (it == GTMatrices.end())
        return {};
      const TransformMatrix &GTMatrix = it->second;

      auto matmulType =
          RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
      // Multiply u = (G x g) x GT.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) to (H, W, C, F).
    int64_t retHeight = leftTransform ? m + r - 1 : 1;
    int64_t retWidth = rightTransform ? m + r - 1 : 1;

    auto insertSliceOp =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
                         zeroIdx, zeroIdx, retHeight, retWidth,
                         /*loopNorFIdx=*/3, /*loopCorFIdx=*/2,
                         /*heightIdx=*/0, /*widthIdx=*/1);

    return {insertSliceOp};
  };

  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
      {oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}

/// This function transforms the input. The data layout of the input is NHWC.
/// The transformation matrix is 2-dimensional. We need to extract H x W from
/// NHWC first, and we generate 4 levels of loops to iterate on tileH, tileW,
/// N, and C. After the transformation, we get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %c = 0 to C step 1
///         %extracted = extract %extracted<alphaH x alphaW> from
///                              %input<N x H x W x C>
///                              at [%n, (%h x m), (%w x m), %c]
///         %ret = linalg.matmul BT, %extracted
///         %ret = linalg.matmul %ret, B
///         %inserted = insert %ret<alphaH x alphaW> into
///                            %output<alphaH x alphaW x tileH x tileW x N x C>
///                            at [0, 0, %h, %w, %n, %c]
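///
/// For example (illustrative, hypothetical shapes), with F(4 x 4, 3 x 3) a
/// padded tensor<2x14x14x5xf32> input (NHWC) gives tileH = tileW = 3 and a
/// transformed result of type tensor<6x6x3x3x2x5xf32>; each iteration
/// computes
///
///   %t = linalg.matmul BT, d   // (6 x 6) x (6 x 6) -> 6 x 6
///   %v = linalg.matmul %t, B   // (6 x 6) x (6 x 6) -> 6 x 6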
Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
                     Value retValue, int64_t m, int64_t r,
                     bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to BT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      BTMatrices = {
          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
      };

  // Map from (m, r) to B transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      BMatrices = {
          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
      };

  auto inputType = cast<ShapedType>(input.getType());
  Type elementType = inputType.getElementType();
  auto inputShape = inputType.getShape(); // N, H, W, C
  int64_t inputN = inputShape[0];
  int64_t inputC = inputShape[3];
  auto valueType = cast<ShapedType>(retValue.getType());
  auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C
  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value CIter = ivs[3];

    auto context = builder.getContext();

    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    Value heightOffset = builder.create<affine::AffineApplyOp>(
        loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
    Value widthOffset = builder.create<affine::AffineApplyOp>(
        loc, rightTransform ? affineMap : identityAffineMap, tileWIter);

    // Extract (H, W) from (N, H, W, C).
    auto extractInput =
        extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset,
                            widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    TransformMapKeyTy key = {m, r};
    int64_t retRows = 1;
    int64_t retCols = 1;
    Value matmulRetValue = extractInput;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get constant transform matrix BT.
      auto it = BTMatrices.find(key);
      if (it == BTMatrices.end())
        return {};
      const TransformMatrix &BTMatrix = it->second;

      retRows = BTMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value BT =
          create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
      // Multiply BT x d.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get constant transform matrix B.
      auto it = BMatrices.find(key);
      if (it == BMatrices.end())
        return {};
      const TransformMatrix &BMatrix = it->second;

      retCols = BMatrix.cols;
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
      Value B =
          create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
      // Multiply v = (BT x d) x B.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) to (H, W, tileH, tileW, N, C).
    auto combinedVal = insert2DDataTo6D(
        builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter,
        CIter, retRows, retCols, 2, 3, /*loopNorFIdx=*/4, /*loopCorFIdx=*/5,
        /*heightIdx=*/0, /*widthIdx=*/1);

    return {combinedVal};
  };

  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, cUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}

/// This function generates linalg.batch_matmul to multiply the input with the
/// filter. linalg.batch_matmul only supports 3-dimensional inputs, so we
/// collapse the leading dimensions: (alphaH, alphaW) becomes the batch
/// dimension and (tileH, tileW, N) becomes the row dimension of each matrix.
/// That is, the input [alphaH, alphaW, tileH, tileW, N, C] is collapsed to
/// [alphaH x alphaW, tileH x tileW x N, C] and the filter
/// [alphaH, alphaW, C, F] to [alphaH x alphaW, C, F], both suitable for
/// linalg.batch_matmul.
///
/// The batched matmul does the matrix multiply with the reduction on channel.
///
/// We get
///
/// %collapsed_input = tensor.collapse_shape %input
/// %collapsed_filter = tensor.collapse_shape %filter
/// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
/// %expanded_ret = tensor.expand_shape %ret
///
/// After this function, we get the return value with data layout
/// (alphaH, alphaW, tileH, tileW, N, F).
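///
/// For example (illustrative, continuing the hypothetical F(4 x 4, 3 x 3)
/// shapes above; exact assembly syntax elided):
///
///   %ci = tensor.collapse_shape %i [[0, 1], [2, 3, 4], [5]]
///       : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32>
///   %cf = tensor.collapse_shape %f [[0, 1], [2], [3]]
///       : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
///   %bm = linalg.batch_matmul ins(%ci, %cf) outs(...) -> tensor<36x18x2xf32>
///   %e = tensor.expand_shape %bm [[0, 1], [2, 3, 4], [5]]
///       : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32>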
static Value matrixMultiply(RewriterBase &rewriter, Location loc,
                            Value transformedFilter, Value transformedInput,
                            Type outputElementType) {
  // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter.
  auto filterType = cast<ShapedType>(transformedFilter.getType());
  assert(filterType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> filterShape = filterType.getShape();
  Type filterElementType = filterType.getElementType();
  auto filterReassocType = RankedTensorType::get(
      {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
      filterElementType);
  SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
  Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, filterReassocType, transformedFilter, filterReassoc);

  // Convert (alphaH, alphaW, tileH, tileW, N, C) to
  // (alphaH x alphaW, tileH x tileW x N, C) for input.
  auto inputType = cast<ShapedType>(transformedInput.getType());
  assert(inputType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> inputShape = inputType.getShape();
  Type inputElementType = inputType.getElementType();
  auto inputReassocType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
      inputElementType);
  SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
      loc, inputReassocType, transformedInput, inputReassoc);

  // Batched matrix multiply.
  auto matmulType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
      outputElementType);
  Value empty = rewriter
                    .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                             outputElementType)
                    .getResult();
  Value zero = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(outputElementType));
  Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);

  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
      ValueRange{init});

  // The result shape of batch matmul is (alphaH x alphaW, tileH x tileW x N,
  // F). Expand the matmul result to (alphaH, alphaW, tileH, tileW, N, F).
  SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  auto outputReassocType =
      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
                             inputShape[3], inputShape[4], filterShape[3]},
                            outputElementType);
  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
  return expandOutput;
}

/// This function transforms the output. The data layout of the final output
/// is NHWF, while the input value has layout
/// (alphaH, alphaW, tileH, tileW, N, F). The transformation matrix is
/// 2-dimensional. We extract alphaH x alphaW tiles from the value first, and
/// we generate 4 levels of loops to iterate on tileH, tileW, N, and F. After
/// the transformation, we get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %f = 0 to F step 1
///         %extracted = extract %extracted<alphaH x alphaW> from
///                              %input<alphaH x alphaW x tileH x tileW x N x F>
///                              at [0, 0, %h, %w, %n, %f]
///         %ret = linalg.matmul AT, %extracted
///         %ret = linalg.matmul %ret, A
///         %inserted = insert %ret<m x m> into
///                            %output<N x H x W x F>
///                            at [%n, (%h x m), (%w x m), %f]
Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
                      Value output, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to AT transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      ATMatrices = {
          {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
          {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
          {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
      };

  // Map from (m, r) to A transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      AMatrices = {
          {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
          {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
          {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
      };

  auto valueType = cast<ShapedType>(value.getType());
  Type elementType = valueType.getElementType();
  auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];
  int64_t valueN = valueShape[4];
  int64_t valueF = valueShape[5];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  if (valueH != alphaH && valueH != 1)
    return Value();
  if (valueW != alphaW && valueW != 1)
    return Value();

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    auto context = builder.getContext();
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value FIter = ivs[3];

    // Extract (H, W) from (H, W, tileH, tileW, N, F).
    auto extractValue =
        extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter,
                            FIter, 2, 3, /*loopNorFIdx=*/4,
                            /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1);

    const TransformMapKeyTy key = {m, r};
    const TransformMatrix &AMatrix = AMatrices.at(key);
    const TransformMatrix &ATMatrix = ATMatrices.at(key);
    int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) *
                           (leftTransform ? ATMatrix.scalarFactor : 1);
    int64_t retCols = rightTransform ? AMatrix.cols : 1;
    int64_t retRows = leftTransform ? ATMatrix.rows : 1;

    Value matmulRetValue = extractValue;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));

    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    Value heightOffset = builder.create<affine::AffineApplyOp>(
        loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
    Value widthOffset = builder.create<affine::AffineApplyOp>(
        loc, rightTransform ? affineMap : identityAffineMap, tileWIter);

    Value outInitVal =
        extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
                            widthOffset, retRows, retCols,
                            /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1,
                            /*widthIdx=*/2);
    if (leftTransform) {
      auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
      Value init = outInitVal;
      if (rightTransform || scalarFactor != 1) {
        auto empty = builder
                         .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                  elementType)
                         .getResult();
        init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
      }

      Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
      // Multiply AT x m.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      auto matmulType =
          RankedTensorType::get({retRows, AMatrix.cols}, elementType);
      Value init = outInitVal;
      if (scalarFactor != 1) {
        auto empty = builder
                         .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                  elementType)
                         .getResult();
        init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
      }

      Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
      // Multiply y = (AT x m) x A.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (scalarFactor != 1) {
      // Multiply by the scalar factor and add outInitVal.
      Value scalarFactorValue = builder.create<arith::ConstantOp>(
          loc, FloatAttr::get(elementType, scalarFactor));
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
      SmallVector<AffineMap> affineMaps = {
          AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap};

      matmulRetValue =
          rewriter
              .create<linalg::GenericOp>(
                  loc, matmulType,
                  ValueRange{scalarFactorValue, matmulRetValue},
                  ValueRange{outInitVal}, affineMaps,
                  llvm::ArrayRef<utils::IteratorType>{
                      utils::IteratorType::parallel,
                      utils::IteratorType::parallel},
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    auto mulf = nestedBuilder.create<arith::MulFOp>(
                        nestedLoc, args[0], args[1]);
                    auto addf = nestedBuilder.create<arith::AddFOp>(
                        nestedLoc, mulf.getResult(), args[2]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc,
                                                          addf.getResult());
                  })
              .getResult(0);
    }

    // Insert (H, W) to (N, H, W, F).
    Value combinedVal =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
                         heightOffset, widthOffset, retRows, retCols,
                         /*loopNorFIdx=*/0,
                         /*loopCorFIdx=*/3, /*heightIdx=*/1,
                         /*widthIdx=*/2);

    return {combinedVal};
  };

  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, fUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody);
  return loops.results[0];
}

/// Pad value with zeros up to the aligned shape, using a high-padding
/// tensor.pad op.
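///
/// For example (illustrative, hypothetical shapes), aligning a
/// tensor<2x11x11x5xf32> input to {2, 14, 14, 5} emits roughly:
///
///   %cst = arith.constant 0.000000e+00 : f32
///   %padded = tensor.pad %input low[0, 0, 0, 0] high[0, 3, 3, 0] {
///   ^bb0(%i: index, %j: index, %k: index, %l: index):
///     tensor.yield %cst : f32
///   } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32>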
static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
                                Value value, ArrayRef<int64_t> alignedShape) {
  auto valueType = cast<ShapedType>(value.getType());
  Type elementType = valueType.getElementType();
  auto alignedType = RankedTensorType::get(alignedShape, elementType);
  Value padValue = rewriter.create<arith::ConstantOp>(
      loc, elementType, rewriter.getZeroAttr(elementType));

  return linalg::makeComposedPadHighOp(rewriter, loc, alignedType, value,
                                       padValue, false);
}

/// Extract sub-tensor with extractedType from value.
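///
/// For example (illustrative, hypothetical shapes), extracting the valid
/// region of a padded result emits roughly:
///
///   %res = tensor.extract_slice %value[0, 0, 0, 0] [2, 9, 9, 2]
///          [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>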
static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
                                      Value value,
                                      RankedTensorType extractedType) {
  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
  OpFoldResult oneIndex = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
  SmallVector<OpFoldResult, 4> strides(4, oneIndex);

  ArrayRef<int64_t> extractedShape = extractedType.getShape();
  SmallVector<OpFoldResult> sizes =
      getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape));

  return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
                                                 offsets, sizes, strides);
}

/// Utility function to check all values in the attribute are 1.
static bool hasAllOneValues(DenseIntElementsAttr attr) {
  return llvm::all_of(
      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
}

/// A helper function to convert linalg.conv_2d_nhwc_fhwc to
/// linalg.winograd_*_transform ops.
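///
/// For example (an illustrative sketch with hypothetical shapes; exact
/// assembly syntax elided), lowering a conv_2d_nhwc_fhwc with input
/// tensor<2x11x11x5xf32>, filter tensor<2x3x3x5xf32>, output
/// tensor<2x9x9x2xf32>, and F(4 x 4, 3 x 3) produces roughly:
///
///   %f = linalg.winograd_filter_transform m(4) r(3)
///            -> tensor<6x6x5x2xf32>
///   %pad_in = tensor.pad %input ... to tensor<2x14x14x5xf32>
///   %i = linalg.winograd_input_transform m(4) r(3)
///            -> tensor<6x6x3x3x2x5xf32>
///   %mm = collapse/batch_matmul/expand (see matrixMultiply above)
///            -> tensor<6x6x3x3x2x2xf32>
///   %pad_out = tensor.pad %output ... to tensor<2x12x12x2xf32>
///   %o = linalg.winograd_output_transform m(4) r(3)
///            -> tensor<2x12x12x2xf32>
///   %res = tensor.extract_slice %o ... to tensor<2x9x9x2xf32>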
static FailureOr<Operation *>
winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
                     int64_t m, int64_t r) {
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];
  auto inputType = cast<ShapedType>(input.getType());
  auto filterType = cast<ShapedType>(filter.getType());
  auto outputType = cast<ShapedType>(output.getType());

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  if (!hasAllOneValues(convOp.getStrides()))
    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");

  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputN = inputShape[0];
  int64_t inputH = inputShape[1];
  int64_t inputW = inputShape[2];
  int64_t inputC = inputShape[3];
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t outputN = outputShape[0];
  int64_t outputH = outputShape[1];
  int64_t outputW = outputShape[2];
  int64_t outputF = outputShape[3];

  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
  bool isSupportedFilter = false;
  if (filterH == filterW && filterH == r)
    isSupportedFilter = true;
  if (filterH == r && filterW == 1)
    isSupportedFilter = true;
  if (filterH == 1 && filterW == r)
    isSupportedFilter = true;

  if (!isSupportedFilter)
    return rewriter.notifyMatchFailure(
        convOp, "only support filter (r x r), (r x 1) or (1 x r)");

  // Currently, we support (m, r) = (2, 3), (4, 3) or (2, 5).
  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
      F_2_3, F_4_3, F_2_5};

  TransformMapKeyTy key = {m, r};
  auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
  // If we cannot find the constant transformation matrix, it means we do
  // not support this configuration yet.
  if (it == validConfigs.end())
    return failure();

  // All the criteria are satisfied. We can do Winograd Conv2D.
  Location loc = convOp.getLoc();

  // For F(m x 1, r x 1), we only need to do left side transform.
  bool leftTransform = filterH != 1;
  // For F(1 x m, 1 x r), we only need to do right side transform.
  bool rightTransform = filterW != 1;
  int64_t heightM = leftTransform ? m : 1;
  int64_t widthM = rightTransform ? m : 1;
  int64_t heightR = leftTransform ? r : 1;
  int64_t widthR = rightTransform ? r : 1;

  // --- Create operation for filter transform ---
  Type filterElementType = filterType.getElementType();
  int64_t alphaH = heightM + heightR - 1;
  int64_t alphaW = widthM + widthR - 1;
  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
  auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
                                       filterElementType);
  Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
                                                    filterElementType);
  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
      loc, retType, filter, retValue, m, r);

  // --- Create operation for input transform ---

  // When the input size minus (r - 1) is not aligned with the output tile
  // size, pad the input data so that tiling produces full tiles.
  Type inputElementType = inputType.getElementType();
  int64_t alignedInputH = tileH * heightM + (heightR - 1);
  int64_t alignedInputW = tileW * widthM + (widthR - 1);
  if (alignedInputH != inputH || alignedInputW != inputW) {
    input = padToAlignedTensor(rewriter, loc, input,
                               {inputN, alignedInputH, alignedInputW, inputC});
  }

  retType = RankedTensorType::get(
      {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
  retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
                                              inputElementType);
  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
      loc, retType, input, retValue, m, r);

  Type outputElementType = outputType.getElementType();
  Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
                                   transformedInput, outputElementType);

  // --- Create operation for output transform ---

  // When the output size is not aligned with the output tile size, pad the
  // output buffer so that full tiles can be inserted after tiling.
  int64_t alignedOutputH = tileH * heightM;
  int64_t alignedOutputW = tileW * widthM;
  bool isOutputUnaligned =
      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
  if (isOutputUnaligned) {
    auto alignedOutputType = RankedTensorType::get(
        {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
    output =
        padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape());
    outputType = alignedOutputType;
  }

  Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
      loc, outputType, matmulRet, output, m, r);

  // When the output size is not aligned with the output tile size, extract
  // the valid region from the padded buffer.
  if (isOutputUnaligned) {
    transformedOutput = extractFromAlignedTensor(
        rewriter, loc, transformedOutput,
        RankedTensorType::get({outputN, outputH, outputW, outputF},
                              outputElementType));
  }

  rewriter.replaceOp(convOp, transformedOutput);

  return transformedOutput.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_filter_transform.
FailureOr<Operation *>
decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradFilterTransformOp op) {
  Location loc = op.getLoc();
  Value filter = op.getFilter();
  auto filterType = cast<ShapedType>(filter.getType());
  auto filterShape = filterType.getShape();
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];

  // For F(m x 1, r x 1), we only need to do left side transform.
  bool leftTransform = filterH != 1;
  // For F(1 x m, 1 x r), we only need to do right side transform.
  bool rightTransform = filterW != 1;
  Value transformedFilter =
      filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(),
                      op.getR(), leftTransform, rightTransform);
  if (!transformedFilter)
    return failure();

  rewriter.replaceOp(op, transformedFilter);

  return transformedFilter.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_input_transform.
FailureOr<Operation *>
decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
                                      linalg::WinogradInputTransformOp op) {
  Location loc = op.getLoc();
  Value output = op.getOutput();
  auto outputType = cast<ShapedType>(output.getType());
  auto outputShape = outputType.getShape();

  int64_t outputH = outputShape[0];
  int64_t outputW = outputShape[1];

  // For F(m x 1, r x 1), we only need to do left side transform.
  bool leftTransform = outputH != 1;
  // For F(1 x m, 1 x r), we only need to do right side transform.
  bool rightTransform = outputW != 1;
  Value transformedInput =
      inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
                     op.getR(), leftTransform, rightTransform);
  if (!transformedInput)
    return failure();

  rewriter.replaceOp(op, transformedInput);

  return transformedInput.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_output_transform.
FailureOr<Operation *>
decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradOutputTransformOp op) {
  Location loc = op.getLoc();
  Value value = op.getValue();
  auto valueType = cast<ShapedType>(value.getType());
  auto valueShape = valueType.getShape();
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];

  // For F(m x 1, r x 1), we only need to do left side transform.
  bool leftTransform = valueH != 1;
  // For F(1 x m, 1 x r), we only need to do right side transform.
  bool rightTransform = valueW != 1;
  Value transformedOutput =
      outputTransform(rewriter, loc, value, op.getOutput(), op.getM(),
                      op.getR(), leftTransform, rightTransform);
  if (!transformedOutput)
    return failure();

  rewriter.replaceOp(op, transformedOutput);

  return transformedOutput.getDefiningOp();
}

/// A rewrite pattern to decompose linalg.winograd_filter_transform operations.
class DecomposeWinogradFilterTransform final
    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
public:
  using OpRewritePattern<linalg::WinogradFilterTransformOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradFilterTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern to decompose linalg.winograd_input_transform operations.
class DecomposeWinogradInputTransform final
    : public OpRewritePattern<linalg::WinogradInputTransformOp> {
public:
  using OpRewritePattern<linalg::WinogradInputTransformOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradInputTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern to decompose linalg.winograd_output_transform operations.
class DecomposeWinogradOutputTransform final
    : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
public:
  using OpRewritePattern<linalg::WinogradOutputTransformOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradOutputTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern for Winograd Conv2D algorithm.
class WinogradConv2DNhwcFhwc final
    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
      : OpRewritePattern(context), m(m), r(r) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
      return failure();

    return success();
  }

private:
  int64_t m;
  int64_t r;
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
                                      int64_t r) {
  return winogradConv2DHelper(rewriter, op, m, r);
}

FailureOr<Operation *>
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradFilterTransformOp op) {
  return decomposeWinogradFilterTransformHelper(rewriter, op);
}

FailureOr<Operation *>
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
                                  linalg::WinogradInputTransformOp op) {
  return decomposeWinogradInputTransformHelper(rewriter, op);
}

FailureOr<Operation *>
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradOutputTransformOp op) {
  return decomposeWinogradOutputTransformHelper(rewriter, op);
}

void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                    int64_t r) {
  MLIRContext *context = patterns.getContext();
  // TODO: Support more Conv2D data layouts, e.g., conv_2d_nchw_fchw.
  patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
}

void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns
      .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform,
              DecomposeWinogradOutputTransform>(context);
}
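
/// Hedged usage sketch (illustrative only, not part of this file): wiring the
/// patterns into a greedy rewrite driver. The driver entry point name has
/// varied across MLIR releases (applyPatternsAndFoldGreedily vs.
/// applyPatternsGreedily), so treat this as pseudocode.
///
///   void runWinograd(ModuleOp module, int64_t m, int64_t r) {
///     RewritePatternSet patterns(module.getContext());
///     linalg::populateWinogradConv2DPatterns(patterns, m, r);
///     linalg::populateDecomposeWinogradOpsPatterns(patterns);
///     (void)applyPatternsGreedily(module, std::move(patterns));
///   }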

} // end namespace linalg
} // end namespace mlir