MLIR  20.0.0git
WinogradConv2D.cpp
Go to the documentation of this file.
1 //===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implement Winograd Conv2D algorithm. The implementation is based on the
10 // paper: Fast Algorithms for Convolutional Neural Networks
11 // (https://arxiv.org/abs/1509.09308)
12 //
13 //===----------------------------------------------------------------------===//
14 
22 #include "llvm/Support/MathExtras.h"
23 
24 namespace mlir {
25 namespace linalg {
26 
27 namespace {
28 
29 // clang-format off
/// Winograd Conv2D computes its result with a minimal 2D filtering algorithm.
/// For F(m x m, r x r), where m is the output tile size and r is the filter
/// size, the formula is
///
/// Y = A^T x [ (G x g x G^T) (.) (B^T x d x B) ] x A
///
/// where g is the filter, d is the input tile, "x" denotes matrix
/// multiplication and "(.)" denotes element-wise multiplication. The formula
/// needs 6 constant transformation matrices: G, G^T, B^T, B, A^T, and A.
///
/// The following tables define these constant transformation matrices for
/// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5).
// ---- F(2 x 2, 3 x 3): tile size alpha = 2 + 3 - 1 = 4 ----------------------
// All tables are row-major.

// G: alpha x r (4 x 3), left factor of the filter transform.
constexpr float G_2x2_3x3[] = {
     -1,      0,      0,
   1./2,  -1./2,   1./2,
   1./2,   1./2,   1./2,
      0,      0,      1
};

// G^T: r x alpha (3 x 4), the transpose of G.
constexpr float GT_2x2_3x3[] = {
     -1,   1./2,   1./2,      0,
      0,  -1./2,   1./2,      0,
      0,   1./2,   1./2,      1
};

// B^T: alpha x alpha (4 x 4), left factor of the input transform.
constexpr float BT_2x2_3x3[] = {
     -1,      0,      1,      0,
      0,     -1,      1,      0,
      0,      1,      1,      0,
      0,     -1,      0,      1
};

// B: alpha x alpha (4 x 4), the transpose of B^T.
constexpr float B_2x2_3x3[] = {
     -1,      0,      0,      0,
      0,     -1,      1,     -1,
      1,      1,      1,      0,
      0,      0,      0,      1
};

// A^T: m x alpha (2 x 4), left factor of the output transform.
constexpr float AT_2x2_3x3[] = {
      1,      1,      1,      0,
      0,     -1,      1,      1
};

// A: alpha x m (4 x 2), the transpose of A^T.
constexpr float A_2x2_3x3[] = {
      1,      0,
      1,     -1,
      1,      1,
      0,      1
};
79 
// ---- F(4 x 4, 3 x 3): tile size alpha = 4 + 3 - 1 = 6 ----------------------
// All tables are row-major. The A^T / A tables are registered with a scalar
// factor of 32 by the output transform.

// G: alpha x r (6 x 3), left factor of the filter transform.
constexpr float G_4x4_3x3[] = {
      1,      0,      0,
  -1./3,   1./3,  -1./3,
  -1./3,  -1./3,  -1./3,
  1./12,  -1./6,   1./3,
  1./12,   1./6,   1./3,
      0,      0,      1
};

// G^T: r x alpha (3 x 6), the transpose of G.
constexpr float GT_4x4_3x3[] = {
      1,  -1./3,  -1./3,  1./12,  1./12,      0,
      0,   1./3,  -1./3,  -1./6,   1./6,      0,
      0,  -1./3,  -1./3,   1./3,   1./3,      1
};

// B^T: alpha x alpha (6 x 6), left factor of the input transform.
constexpr float BT_4x4_3x3[] = {
   1./4,      0, -5./16,      0,  1./16,      0,
      0,   1./4,  -1./4, -1./16,  1./16,      0,
      0,  -1./4,  -1./4,  1./16,  1./16,      0,
      0,   1./4,  -1./8,  -1./4,   1./8,      0,
      0,  -1./4,  -1./8,   1./4,   1./8,      0,
      0,   1./4,      0, -5./16,      0,  1./16
};

// B: alpha x alpha (6 x 6), the transpose of B^T.
constexpr float B_4x4_3x3[] = {
   1./4,      0,      0,      0,      0,      0,
      0,   1./4,  -1./4,   1./4,  -1./4,   1./4,
 -5./16,  -1./4,  -1./4,  -1./8,  -1./8,      0,
      0, -1./16,  1./16,  -1./4,   1./4, -5./16,
  1./16,  1./16,  1./16,   1./8,   1./8,      0,
      0,      0,      0,      0,      0,  1./16
};

// A^T: m x alpha (4 x 6), left factor of the output transform.
constexpr float AT_4x4_3x3[] = {
   1./8,   1./4,   1./4,   1./8,   1./8,      0,
      0,  -1./4,   1./4,  -1./4,   1./4,      0,
      0,   1./4,   1./4,   1./2,   1./2,      0,
      0,  -1./4,   1./4,     -1,      1,   1./2
};

// A: alpha x m (6 x 4), the transpose of A^T.
constexpr float A_4x4_3x3[] = {
   1./8,      0,      0,      0,
   1./4,  -1./4,   1./4,  -1./4,
   1./4,   1./4,   1./4,   1./4,
   1./8,  -1./4,   1./2,     -1,
   1./8,   1./4,   1./2,      1,
      0,      0,      0,   1./2
};
128 
// ---- F(2 x 2, 5 x 5): tile size alpha = 2 + 5 - 1 = 6 ----------------------
// All tables are row-major. The A^T / A tables are registered with a scalar
// factor of 16 by the output transform.

// G: alpha x r (6 x 5), left factor of the filter transform.
constexpr float G_2x2_5x5[] = {
       1,       0,       0,       0,       0,
    1./6,   -1./6,    1./6,   -1./6,    1./6,
   -1./6,   -1./6,   -1./6,   -1./6,   -1./6,
  -4./15,   2./15,  -1./15,   1./30,  -1./60,
   1./60,   1./30,   1./15,   2./15,   4./15,
       0,       0,       0,       0,       1
};

// G^T: r x alpha (5 x 6), the transpose of G.
constexpr float GT_2x2_5x5[] = {
       1,    1./6,   -1./6,  -4./15,   1./60,       0,
       0,   -1./6,   -1./6,   2./15,   1./30,       0,
       0,    1./6,   -1./6,  -1./15,   1./15,       0,
       0,   -1./6,   -1./6,   1./30,   2./15,       0,
       0,    1./6,   -1./6,  -1./60,   4./15,       1
};

// B^T: alpha x alpha (6 x 6), left factor of the input transform.
constexpr float BT_2x2_5x5[] = {
    1./8,   3./16,   -1./4,  -3./16,    1./8,       0,
       0,    1./8,   1./16,  -5./16,    1./8,       0,
       0,   -1./8,  -5./16,  -1./16,    1./8,       0,
       0,    1./4,   -1./8,   -1./4,    1./8,       0,
       0,   -1./8,   -1./4,    1./8,    1./4,       0,
       0,    1./8,   3./16,   -1./4,  -3./16,    1./8
};

// B: alpha x alpha (6 x 6), the transpose of B^T.
constexpr float B_2x2_5x5[] = {
    1./8,       0,       0,       0,       0,       0,
   3./16,    1./8,   -1./8,    1./4,   -1./8,    1./8,
   -1./4,   1./16,  -5./16,   -1./8,   -1./4,   3./16,
  -3./16,  -5./16,  -1./16,   -1./4,    1./8,   -1./4,
    1./8,    1./8,    1./8,    1./8,    1./4,  -3./16,
       0,       0,       0,       0,       0,    1./8
};

// A^T: m x alpha (2 x 6), left factor of the output transform.
constexpr float AT_2x2_5x5[] = {
    1./2,       1,       1,       2,       1,       0,
       0,      -1,       1,      -1,       2,    1./2
};

// A: alpha x m (6 x 2), the transpose of A^T.
constexpr float A_2x2_5x5[] = {
    1./2,       0,
       1,      -1,
       1,       1,
       2,      -1,
       1,       2,
       0,    1./2
};
177 // clang-format on
178 
/// Lookup key for the transform-matrix tables: the pair (m, r) of F(m, r).
using TransformMapKeyTy = std::pair<int, int>;

/// Minimal filtering algorithms are written F(m, r): m is the output size and
/// r is the filter size. The input tile size follows as alpha = m + r - 1.
///
/// Example: F(2, 3) has alpha = 4, so the Conv2D reads 4x4 input tiles, applies
/// a 3x3 filter, and produces 2x2 output tiles.
constexpr TransformMapKeyTy F_2_3 = {2, 3};
constexpr TransformMapKeyTy F_4_3 = {4, 3};
constexpr TransformMapKeyTy F_2_5 = {2, 5};
191 
/// A constant transform matrix viewed as a row-major `rows x cols` table of
/// floats. The table is referenced, not copied; every use in this file points
/// at one of the file-level constexpr tables.
///
/// `scalarFactor` defaults to 1. The output transform multiplies its result by
/// the product of the factors of the matrices it applies.
struct TransformMatrix {
  const float *table;   // Row-major entries (not owned).
  int64_t rows;         // Number of rows.
  int64_t cols;         // Number of columns.
  int64_t scalarFactor; // Factor applied to the transformed result.

  TransformMatrix(const float *table, int64_t rows, int64_t cols,
                  int64_t scalarFactor = 1)
      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
};
203 
204 /// Utility function to convert constant array to arith.constant Value.
205 Value create2DTransformMatrix(OpBuilder &builder, Location loc,
206  TransformMatrix transform, Type type) {
207  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
208 
209  return builder.create<arith::ConstantOp>(
212  SmallVector<int64_t>{transform.rows, transform.cols}, type),
213  constVec));
214 }
215 
216 /// Extract height x width data from 4D tensors.
217 Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
218  Value loopNorFIndex, Value loopCorFIndex,
219  Value heightOffset, Value widthOffset,
220  int64_t extractHeight, int64_t extractWidth,
221  int64_t loopNorFIdx, int64_t loopCorFIdx,
222  int64_t heightIdx, int64_t widthIdx) {
223  auto sourceType = cast<ShapedType>(source.getType());
224  Type elementType = sourceType.getElementType();
225  int64_t srcSize = sourceType.getRank();
226 
227  auto oneIndex = builder.getIndexAttr(1);
228  SmallVector<OpFoldResult> offsets;
229  offsets.resize(srcSize);
230  offsets[loopNorFIdx] = loopNorFIndex;
231  offsets[loopCorFIdx] = loopCorFIndex;
232  offsets[heightIdx] = heightOffset;
233  offsets[widthIdx] = widthOffset;
234  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
235  sizes[heightIdx] = builder.getIndexAttr(extractHeight);
236  sizes[widthIdx] = builder.getIndexAttr(extractWidth);
237  SmallVector<OpFoldResult> strides(srcSize, oneIndex);
238 
239  auto extractFilterType =
240  RankedTensorType::get({extractHeight, extractWidth}, elementType);
241  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
242  loc, extractFilterType, source, offsets, sizes, strides);
243 
244  return extractFilterOp;
245 }
246 
247 /// Extract height x width data from 6D tensors.
248 Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
249  Value tileHIndex, Value tileWIndex,
250  Value loopNorFIndex, Value loopCorFIndex,
251  int64_t tileHIdx, int64_t tileWIdx,
252  int64_t loopNorFIdx, int64_t loopCorFIdx,
253  int64_t heightIdx, int64_t widthIdx) {
254  auto sourceType = cast<ShapedType>(source.getType());
255  Type elementType = sourceType.getElementType();
256  auto sourceShape = sourceType.getShape();
257  int64_t srcSize = sourceType.getRank();
258  int64_t height = sourceShape[heightIdx];
259  int64_t width = sourceShape[widthIdx];
260 
261  auto zeroIndex = builder.getIndexAttr(0);
262  auto oneIndex = builder.getIndexAttr(1);
263  SmallVector<OpFoldResult> offsets(srcSize, zeroIndex);
264  offsets.resize(srcSize);
265  offsets[tileHIdx] = tileHIndex;
266  offsets[tileWIdx] = tileWIndex;
267  offsets[loopNorFIdx] = loopNorFIndex;
268  offsets[loopCorFIdx] = loopCorFIndex;
269  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
270  sizes[heightIdx] = builder.getIndexAttr(height);
271  sizes[widthIdx] = builder.getIndexAttr(width);
272  SmallVector<OpFoldResult> strides(srcSize, oneIndex);
273 
274  auto extractFilterType = RankedTensorType::get({height, width}, elementType);
275  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
276  loc, extractFilterType, source, offsets, sizes, strides);
277 
278  return extractFilterOp;
279 }
280 
281 /// Insert transformed height x width data to 4D tensors which it is
282 /// extracted from.
283 Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
284  Value dest, Value loopNorFIndex, Value loopCorFIndex,
285  Value heightOffset, Value widthOffset, int64_t height,
286  int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
287  int64_t heightIdx, int64_t widthIdx) {
288  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
289  auto oneIndex = builder.getIndexAttr(1);
290  SmallVector<OpFoldResult> retOffsets;
291  retOffsets.resize(destSize);
292  retOffsets[loopNorFIdx] = loopNorFIndex;
293  retOffsets[loopCorFIdx] = loopCorFIndex;
294  retOffsets[heightIdx] = heightOffset;
295  retOffsets[widthIdx] = widthOffset;
296  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
297  retSizes[heightIdx] = builder.getIndexAttr(height);
298  retSizes[widthIdx] = builder.getIndexAttr(width);
299  SmallVector<OpFoldResult> strides(destSize, oneIndex);
300 
301  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
302  loc, source, dest, retOffsets, retSizes, strides);
303 
304  return insertSliceOp;
305 }
306 
307 /// Insert transformed height x width data to 6D tensors which it is
308 /// extracted from.
309 Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
310  Value dest, Value tileHIndex, Value tileWIndex,
311  Value loopNorFIndex, Value loopCorFIndex, int64_t height,
312  int64_t width, int64_t tileHIdx, int64_t tileWIdx,
313  int64_t loopNorFIdx, int64_t loopCorFIdx,
314  int64_t heightIdx, int64_t widthIdx) {
315  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
316  auto zeroIndex = builder.getIndexAttr(0);
317  auto oneIndex = builder.getIndexAttr(1);
318  SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex);
319  retOffsets.resize(destSize);
320  retOffsets[tileHIdx] = tileHIndex;
321  retOffsets[tileWIdx] = tileWIndex;
322  retOffsets[loopNorFIdx] = loopNorFIndex;
323  retOffsets[loopCorFIdx] = loopCorFIndex;
324  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
325  retSizes[heightIdx] = builder.getIndexAttr(height);
326  retSizes[widthIdx] = builder.getIndexAttr(width);
327  SmallVector<OpFoldResult> strides(destSize, oneIndex);
328 
329  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
330  loc, source, dest, retOffsets, retSizes, strides);
331 
332  return insertSliceOp;
333 }
334 
335 /// This function transforms the filter. The data layout of the filter is FHWC.
336 /// The transformation matrix is 2-dimension. We need to extract H x W from
337 /// FHWC first. We need to generate 2 levels of loops to iterate on F and C.
338 /// After the transformation, we get
339 ///
340 /// scf.for %f = lo_f to hi_f step 1
341 /// scf.for %c = lo_c to hi_c step 1
342 /// %extracted = extract filter<h x w> from filter<f x h x w x c>
343 /// %ret = linalg.matmul G, %extracted
344 /// %ret = linalg.matmul %ret, GT
345 /// %inserted = insert %ret into filter<h x w x c x f>
346 Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
347  Value retValue, int64_t m, int64_t r,
348  bool leftTransform = true, bool rightTransform = true) {
349  // Map from (m, r) to G transform matrix.
350  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
351  GMatrices = {
352  {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
353  {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
354  {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
355  };
356 
357  // Map from (m, r) to GT transform matrix.
358  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
359  GTMatrices = {
360  {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
361  {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
362  {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
363  };
364 
365  auto filterType = cast<ShapedType>(filter.getType());
366  Type elementType = filterType.getElementType();
367  auto filterShape = filterType.getShape(); // F, H, W, C
368  int64_t filterF = filterShape[0];
369  int64_t filterH = filterShape[1];
370  int64_t filterW = filterShape[2];
371  int64_t filterC = filterShape[3];
372 
373  if (filterH != r && filterH != 1)
374  return Value();
375  if (filterW != r && filterW != 1)
376  return Value();
377 
378  Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
379  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
380  ValueRange args) -> scf::ValueVector {
381  Value FIter = ivs[0];
382  Value CIter = ivs[1];
383 
384  // Extract (H, W) from (F, H, W, C).
385  auto extractFilter =
386  extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
387  zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
388  /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
389 
390  TransformMapKeyTy key = {m, r};
391  int64_t retRows = 1;
392  Value matmulRetValue = extractFilter;
393  Value zero = builder.create<arith::ConstantOp>(
394  loc, rewriter.getZeroAttr(elementType));
395  if (leftTransform) {
396  // Get constant transform matrix G.
397  auto it = GMatrices.find(key);
398  if (it == GMatrices.end())
399  return {};
400  const TransformMatrix &GMatrix = it->second;
401 
402  retRows = GMatrix.rows;
403  auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
404  auto empty =
405  builder
406  .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
407  .getResult();
408  auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
409 
410  Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
411  // Multiply G x g.
412  auto matmulOp = builder.create<linalg::MatmulOp>(
413  loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
414  matmulRetValue = matmulOp.getResult(0);
415  }
416 
417  if (rightTransform) {
418  // Get constant transform matrix GT.
419  auto it = GTMatrices.find(key);
420  if (it == GTMatrices.end())
421  return {};
422  const TransformMatrix &GTMatrix = it->second;
423 
424  auto matmulType =
425  RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
426  auto empty =
427  builder
428  .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
429  .getResult();
430  auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
431 
432  Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
433  // Multiply u = (G x g) x GT.
434  auto matmulOp = builder.create<linalg::MatmulOp>(
435  loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
436  matmulRetValue = matmulOp.getResult(0);
437  }
438 
439  // Insert (H, W) to (H, W, C, F).
440  int64_t retHeight = leftTransform ? m + r - 1 : 1;
441  int64_t retWidth = rightTransform ? m + r - 1 : 1;
442 
443  auto insertSliceOp =
444  insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
445  zeroIdx, zeroIdx, retHeight, retWidth,
446  /*loopNorFIdx=*/3, /*loopCorFIdx=*/2,
447  /*heightIdx=*/0, /*widthIdx=*/1);
448 
449  return {insertSliceOp};
450  };
451 
452  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
453  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
454  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
455  scf::LoopNest loops = scf::buildLoopNest(
456  rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
457  {oneStep, oneStep}, {retValue}, buildBody);
458  return loops.results[0];
459 }
460 
461 /// This function transforms the input. The data layout of the input is NHWC.
462 /// The transformation matrix is 2-dimension. We need to extract H x W from
463 /// NHWC first. We need to generate 2 levels of loops to iterate on N and C.
464 /// After the transformation, we get
465 ///
466 /// scf.for %h = 0 to tileH step 1
467 /// scf.for %w = 0 to tileW step 1
468 /// scf.for %n = 0 to N step 1
469 /// scf.for %c = 0 to C step 1
470 /// %extracted = extract %extracted<alphaH x alphaW> from
471 /// %input<N x H x W x C>
472 /// at [%n, (%h x m), (%w x m), %c]
473 /// %ret = linalg.matmul BT, %extracted
474 /// %ret = linalg.matmul %ret, B
475 /// %inserted = insert %ret<alphaH x alphaW> into
476 /// %output<alphaH x alphaW x tileH x tileW x N x C>
477 /// at [0, 0, %h, %w, %n, %c]
478 Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
479  Value retValue, int64_t m, int64_t r,
480  bool leftTransform = true, bool rightTransform = true) {
481  // Map from (m, r) to BT transform matrix.
482  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
483  BTMatrices = {
484  {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
485  {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
486  {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
487  };
488 
489  // Map from (m, r) to B transform matrix.
490  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
491  BMatrices = {
492  {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
493  {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
494  {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
495  };
496 
497  auto inputType = cast<ShapedType>(input.getType());
498  Type elementType = inputType.getElementType();
499  auto inputShape = inputType.getShape(); // N, H, W, C
500  int64_t inputN = inputShape[0];
501  int64_t inputC = inputShape[3];
502  auto valueType = cast<ShapedType>(retValue.getType());
503  auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C
504  int64_t tileH = valueShape[2];
505  int64_t tileW = valueShape[3];
506  int64_t alphaH = leftTransform ? m + r - 1 : 1;
507  int64_t alphaW = rightTransform ? m + r - 1 : 1;
508 
509  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
510  ValueRange args) -> scf::ValueVector {
511  Value tileHIter = ivs[0];
512  Value tileWIter = ivs[1];
513  Value NIter = ivs[2];
514  Value CIter = ivs[3];
515 
516  auto context = builder.getContext();
517  auto affineMap =
518  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
519  Value heightOffset =
520  builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
521  Value widthOffset =
522  builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
523 
524  // Extract (H, W) from (N, H, W, C).
525  auto extractInput =
526  extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset,
527  widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
528  /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
529 
530  TransformMapKeyTy key = {m, r};
531  int64_t retRows = 1;
532  int64_t retCols = 1;
533  Value matmulRetValue = extractInput;
534  Value zero = builder.create<arith::ConstantOp>(
535  loc, rewriter.getZeroAttr(elementType));
536  if (leftTransform) {
537  // Get constant transform matrix BT.
538  auto it = BTMatrices.find(key);
539  if (it == BTMatrices.end())
540  return {};
541  const TransformMatrix &BTMatrix = it->second;
542 
543  retRows = BTMatrix.rows;
544  auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
545  auto empty =
546  builder
547  .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
548  .getResult();
549  auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
550 
551  Value BT =
552  create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
553  // Multiply BT x d.
554  auto matmulOp = builder.create<linalg::MatmulOp>(
555  loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
556  matmulRetValue = matmulOp.getResult(0);
557  }
558 
559  if (rightTransform) {
560  // Get constant transform matrix B.
561  auto it = BMatrices.find(key);
562  if (it == BMatrices.end())
563  return {};
564  const TransformMatrix &BMatrix = it->second;
565 
566  retCols = BMatrix.cols;
567  auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
568  auto empty =
569  builder
570  .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
571  .getResult();
572  auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
573  Value B =
574  create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
575  // Multiply v = (BT x d) x B.
576  auto matmulOp = builder.create<linalg::MatmulOp>(
577  loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
578  matmulRetValue = matmulOp.getResult(0);
579  }
580 
581  // Insert (H, W) to (H, W, tileH, tileW, N, C).
582  auto combinedVal = insert2DDataTo6D(
583  builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter,
584  CIter, retRows, retCols, 2, 3, /*loopNorFIdx=*/4, /*loopCorFIdx=*/5,
585  /*heightIdx=*/0, /*widthIdx=*/1);
586 
587  return {combinedVal};
588  };
589 
590  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
591  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
592  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
593  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
594  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
595  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
596  scf::LoopNest loops = scf::buildLoopNest(
597  rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
598  {tileHBound, tileWBound, nUpperBound, cUpperBound},
599  {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody);
600  return loops.results[0];
601 }
602 
603 /// This function generates linalg.batch_matmul to multiply input with filter.
604 /// linalg.batch_matmul only supports 3-dimensional inputs. We can treat
605 /// tileH x tileW x H x W data as the 1-dimensional data array. That is to
606 /// convert [tileH, tileW, H, W, N, C] to [tileH x tileW x H x W, N, C]. In this
607 /// way, we can convert 6-dimensional inputs to 3-dimensional representation
608 /// that is suitable for linalg.batch_matmul.
609 ///
610 /// Batched matmul will do the matrix multiply with the reduction on channel.
611 ///
612 /// We get
613 ///
614 /// %collapsed_input = tensor.collapse_shape %input
615 /// %collapsed_filter = tensor.collapse_shape %filter
616 /// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
617 /// %expanded_ret = tensor.expand_shape %ret
618 ///
619 /// After this function, we get return value with data layout
620 /// (tileH, tileW, H, W, N, F).
621 static Value matrixMultiply(RewriterBase &rewriter, Location loc,
622  Value transformedFilter, Value transformedInput,
623  Type outputElementType) {
624  // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter.
625  auto filterType = cast<ShapedType>(transformedFilter.getType());
626  assert(filterType.hasStaticShape() && "only support static shapes.");
627  ArrayRef<int64_t> filterShape = filterType.getShape();
628  Type filterElementType = filterType.getElementType();
629  auto filterReassocType = RankedTensorType::get(
630  {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
631  filterElementType);
632  SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
633  Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
634  loc, filterReassocType, transformedFilter, filterReassoc);
635 
636  // Convert (alphaH, alphaW, tileH, tileW, N, C) to
637  // (alphaH x alphaW, tileH x tileW x N, C) for input.
638  auto inputType = cast<ShapedType>(transformedInput.getType());
639  assert(inputType.hasStaticShape() && "only support static shapes.");
640  ArrayRef<int64_t> inputShape = inputType.getShape();
641  Type inputElementType = inputType.getElementType();
642  auto inputReassocType = RankedTensorType::get(
643  {inputShape[0] * inputShape[1],
644  inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
645  inputElementType);
646  SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
647  Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
648  loc, inputReassocType, transformedInput, inputReassoc);
649 
650  // Batched matrix multiply.
651  auto matmulType = RankedTensorType::get(
652  {inputShape[0] * inputShape[1],
653  inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
654  outputElementType);
655  Value empty = rewriter
656  .create<tensor::EmptyOp>(loc, matmulType.getShape(),
657  outputElementType)
658  .getResult();
659  Value zero = rewriter.create<arith::ConstantOp>(
660  loc, rewriter.getZeroAttr(outputElementType));
661  Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);
662 
663  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
664  loc, matmulType, ValueRange({collapseInput, collapseFilter}),
665  ValueRange{init});
666 
667  // The result shape of batch matmul is (alphaH x alphaW, tileH x tileW x N, F)
668  // Expand matmul result to (alphaH, alphaW, tileH, tileW, N, F).
669  SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
670  auto outputReassocType =
671  RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
672  inputShape[3], inputShape[4], filterShape[3]},
673  outputElementType);
674  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
675  loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
676  return expandOutput;
677 }
678 
679 /// This function transforms the output. The data layout of the output is HWNF.
680 /// The transformation matrix is 2-dimension. We need to extract H x W from
681 /// HWNF first. We need to generate 2 levels of loops to iterate on N and F.
682 /// After the transformation, we get
683 ///
684 /// scf.for %h = 0 to tileH step 1
685 /// scf.for %w = 0 to tileW step 1
686 /// scf.for %n = 0 to N step 1
687 /// scf.for %f = 0 to F step 1
688 /// %extracted = extract %extracted<alphaH x alphaW> from
689 /// %input<alphaH x alphaW x tileH x tileW x N x F>
690 /// at [0, 0, %h, %w, %n, %f]
691 /// %ret = linalg.matmul AT, %extracted
692 /// %ret = linalg.matmul %ret, A
693 /// %inserted = insert %ret<alphaH x alphaW> into
694 /// output<N x H x W x F>
695 /// at [%n, (%h x m), (%w x m), %f]
696 Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
697  Value output, int64_t m, int64_t r,
698  bool leftTransform = true, bool rightTransform = true) {
699  // Map from (m, r) to AT transform matrix.
700  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
701  ATMatrices = {
702  {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
703  {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
704  {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
705  };
706 
707  // Map from (m, r) to A transform matrix.
708  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
709  AMatrices = {
710  {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
711  {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
712  {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
713  };
714 
715  auto valueType = cast<ShapedType>(value.getType());
716  Type elementType = valueType.getElementType();
717  auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
718  int64_t valueH = valueShape[0];
719  int64_t valueW = valueShape[1];
720  int64_t valueN = valueShape[4];
721  int64_t valueF = valueShape[5];
722  int64_t alphaH = leftTransform ? m + r - 1 : 1;
723  int64_t alphaW = rightTransform ? m + r - 1 : 1;
724 
725  if (valueH != alphaH && valueH != 1)
726  return Value();
727  if (valueW != alphaW && valueW != 1)
728  return Value();
729 
730  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
731  ValueRange args) -> scf::ValueVector {
732  auto context = builder.getContext();
733  Value tileHIter = ivs[0];
734  Value tileWIter = ivs[1];
735  Value NIter = ivs[2];
736  Value FIter = ivs[3];
737 
738  // Extract (H, W) from (H, W, tileH, tileW, N, F).
739  auto extractValue =
740  extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter,
741  FIter, 2, 3, /*loopNorFIdx=*/4,
742  /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1);
743 
744  const TransformMapKeyTy key = {m, r};
745  const TransformMatrix &AMatrix = AMatrices.at(key);
746  const TransformMatrix &ATMatrix = ATMatrices.at(key);
747  int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) *
748  (leftTransform ? ATMatrix.scalarFactor : 1);
749  int64_t retCols = rightTransform ? AMatrix.cols : 1;
750  int64_t retRows = leftTransform ? ATMatrix.rows : 1;
751 
752  Value matmulRetValue = extractValue;
753  Value zero = builder.create<arith::ConstantOp>(
754  loc, rewriter.getZeroAttr(elementType));
755 
756  auto affineMap =
757  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
758  Value heightOffset =
759  builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
760  Value widthOffset =
761  builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
762 
763  Value outInitVal =
764  extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
765  widthOffset, retRows, retCols,
766  /*loopNorFIdx=*/0,
767  /*loopCorFIdx=*/3, /*heightIdx=*/1,
768  /*widthIdx=*/2);
769  if (leftTransform) {
770  auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
771  Value init = outInitVal;
772  if (rightTransform || scalarFactor != 1) {
773  auto empty = builder
774  .create<tensor::EmptyOp>(loc, matmulType.getShape(),
775  elementType)
776  .getResult();
777  init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
778  }
779 
780  Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
781  // Multiply AT x m.
782  auto matmulOp = builder.create<linalg::MatmulOp>(
783  loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
784  matmulRetValue = matmulOp.getResult(0);
785  }
786 
787  if (rightTransform) {
788  auto matmulType =
789  RankedTensorType::get({retRows, AMatrix.cols}, elementType);
790  Value init = outInitVal;
791  if (scalarFactor != 1) {
792  auto empty = builder
793  .create<tensor::EmptyOp>(loc, matmulType.getShape(),
794  elementType)
795  .getResult();
796  init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
797  }
798 
799  Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
800  // Multiply y = (AT x m) x A.
801  auto matmulOp = builder.create<linalg::MatmulOp>(
802  loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
803  matmulRetValue = matmulOp.getResult(0);
804  }
805 
806  if (scalarFactor != 1) {
807  // Multiply by scalar factor and add outInitVal.
808  Value scalarFactorValue = builder.create<arith::ConstantOp>(
809  loc, FloatAttr::get(elementType, scalarFactor));
810  auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
811  auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
812  SmallVector<AffineMap> affineMaps = {
813  AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap};
814 
815  matmulRetValue =
816  rewriter
817  .create<linalg::GenericOp>(
818  loc, matmulType,
819  ValueRange{scalarFactorValue, matmulRetValue},
820  ValueRange{outInitVal}, affineMaps,
822  utils::IteratorType::parallel,
823  utils::IteratorType::parallel},
824  [&](OpBuilder &nestedBuilder, Location nestedLoc,
825  ValueRange args) {
826  auto mulf = nestedBuilder.create<arith::MulFOp>(
827  nestedLoc, args[0], args[1]);
828  auto addf = nestedBuilder.create<arith::AddFOp>(
829  nestedLoc, mulf.getResult(), args[2]);
830  nestedBuilder.create<linalg::YieldOp>(nestedLoc,
831  addf.getResult());
832  })
833  .getResult(0);
834  }
835 
836  // Insert (H, W) to (N, H, W, F).
837  Value combinedVal =
838  insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
839  heightOffset, widthOffset, retRows, retCols,
840  /*loopNorFIdx=*/0,
841  /*loopCorFIdx=*/3, /*heightIdx=*/1,
842  /*widthIdx=*/2);
843 
844  return {combinedVal};
845  };
846 
847  int64_t tilwH = valueShape[2];
848  int64_t tileW = valueShape[3];
849  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
850  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tilwH);
851  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
852  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
853  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
854  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
855  scf::LoopNest loops = scf::buildLoopNest(
856  rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
857  {tileHBound, tileWBound, nUpperBound, fUpperBound},
858  {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody);
859  return loops.results[0];
860 }
861 
862 /// Create an empty tensor with alignedType and insert the value into the
863 /// created empty tensor with aligned size.
864 static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
865  Value value, ArrayRef<int64_t> alignedShape) {
866  auto valueType = cast<ShapedType>(value.getType());
867  Type elementType = valueType.getElementType();
868  auto alignedType = RankedTensorType::get(alignedShape, elementType);
869  Value padValue = rewriter.create<arith::ConstantOp>(
870  loc, elementType, rewriter.getZeroAttr(elementType));
871 
872  return linalg::makeComposedPadHighOp(rewriter, loc, alignedType, value,
873  padValue, false);
874 }
875 
876 /// Extract sub-tensor with extractedType from value.
877 static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
878  Value value,
879  RankedTensorType extractedType) {
880  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
881  OpFoldResult oneIndex = rewriter.getIndexAttr(1);
882  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
883  SmallVector<OpFoldResult, 4> strides(4, oneIndex);
884 
885  ArrayRef<int64_t> extractedShape = extractedType.getShape();
886  SmallVector<OpFoldResult> sizes =
887  getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape));
888 
889  return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
890  offsets, sizes, strides);
891 }
892 
893 /// Utility function to check all values in the attribute are 1.
894 static bool hasAllOneValues(DenseIntElementsAttr attr) {
895  return llvm::all_of(
896  attr, [](const APInt &element) { return element.getSExtValue() == 1; });
897 }
898 
899 /// A helper function to convert linalg.conv_2d_nhwc_fhwc to
900 /// linalg.winograd_*_transform ops.
901 static FailureOr<Operation *>
902 winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
903  int64_t m, int64_t r) {
904  Value input = convOp.getInputs()[0];
905  Value filter = convOp.getInputs()[1];
906  Value output = convOp.getOutputs()[0];
907  auto inputType = cast<ShapedType>(input.getType());
908  auto filterType = cast<ShapedType>(filter.getType());
909  auto outputType = cast<ShapedType>(output.getType());
910 
911  if (!inputType.hasStaticShape())
912  return rewriter.notifyMatchFailure(convOp,
913  "expected a static shape for the input");
914 
915  if (!filterType.hasStaticShape())
916  return rewriter.notifyMatchFailure(
917  convOp, "expected a static shape for the filter");
918 
919  if (!hasAllOneValues(convOp.getDilations()))
920  return rewriter.notifyMatchFailure(convOp,
921  "expected all ones for dilations");
922 
923  if (!hasAllOneValues(convOp.getStrides()))
924  return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");
925 
926  ArrayRef<int64_t> filterShape = filterType.getShape();
927  int64_t filterF = filterShape[0];
928  int64_t filterH = filterShape[1];
929  int64_t filterW = filterShape[2];
930  int64_t filterC = filterShape[3];
931  ArrayRef<int64_t> inputShape = inputType.getShape();
932  int64_t inputN = inputShape[0];
933  int64_t inputH = inputShape[1];
934  int64_t inputW = inputShape[2];
935  int64_t inputC = inputShape[3];
936  ArrayRef<int64_t> outputShape = outputType.getShape();
937  int64_t outputN = outputShape[0];
938  int64_t outputH = outputShape[1];
939  int64_t outputW = outputShape[2];
940  int64_t outputF = outputShape[3];
941 
942  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
943  bool isSupportedFilter = false;
944  if (filterH == filterW && filterH == r)
945  isSupportedFilter = true;
946  if (filterH == r && filterW == 1)
947  isSupportedFilter = true;
948  if (filterH == 1 && filterW == r)
949  isSupportedFilter = true;
950 
951  if (!isSupportedFilter)
952  return rewriter.notifyMatchFailure(
953  convOp, "only support filter (r x r), (r x 1) or (1 x r)");
954 
955  // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5).
956  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
957  F_2_3, F_4_3, F_2_5};
958 
959  TransformMapKeyTy key = {m, r};
960  auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
961  // If we cannot find the constant transformation matrix, it means we do
962  // not support this configuration yet.
963  if (it == validConfigs.end())
964  return failure();
965 
966  // All the criterias are satisfied. We can do Winograd Conv2D.
967  Location loc = convOp.getLoc();
968 
969  // For F(m x 1, r x 1), we only need to do left side transform.
970  bool leftTransform = filterH != 1;
971  // For F(1 x m, 1 x r), we only need to do right side transform.
972  bool rightTransform = filterW != 1;
973  int64_t heightM = leftTransform ? m : 1;
974  int64_t widthM = rightTransform ? m : 1;
975  int64_t heightR = leftTransform ? r : 1;
976  int64_t widthR = rightTransform ? r : 1;
977 
978  // --- Create operation for filter transform ---
979  Type filterElementType = filterType.getElementType();
980  int64_t alphaH = heightM + heightR - 1;
981  int64_t alphaW = widthM + widthR - 1;
982  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
983  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
984  auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
985  filterElementType);
986  Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
987  filterElementType);
988  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
989  loc, retType, filter, retValue, m, r);
990 
991  // --- Create operation for input transform ---
992 
993  // When input size - (r - 1) is not aligned with output tile size, we need to
994  // pad the input data to create the full tiles as tiling.
995  Type inputElementType = inputType.getElementType();
996  int64_t alignedInputH = tileH * heightM + (heightR - 1);
997  int64_t alignedInputW = tileW * widthM + (widthR - 1);
998  if (alignedInputH != inputH || alignedInputW != inputW) {
999  input = padToAlignedTensor(rewriter, loc, input,
1000  {inputN, alignedInputH, alignedInputW, inputC});
1001  }
1002 
1003  retType = RankedTensorType::get(
1004  {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
1005  retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
1006  inputElementType);
1007  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
1008  loc, retType, input, retValue, m, r);
1009 
1010  Type outputElementType = outputType.getElementType();
1011  Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
1012  transformedInput, outputElementType);
1013 
1014  // --- Create operation for output transform ---
1015 
1016  // When output size is not aligned with output tile size, we need to pad the
1017  // output buffer to insert the full tiles after tiling.
1018  int64_t alignedOutputH = tileH * heightM;
1019  int64_t alignedOutputW = tileW * widthM;
1020  bool isOutputUnaligned =
1021  ((alignedOutputH != outputH) || (alignedOutputW != outputW));
1022  if (isOutputUnaligned) {
1023  auto alignedOutputType = RankedTensorType::get(
1024  {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
1025  output =
1026  padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape());
1027  outputType = alignedOutputType;
1028  }
1029 
1030  Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
1031  loc, outputType, matmulRet, output, m, r);
1032 
1033  // When output size is not aligned with output tile size, extract the
1034  // value from the padded buffer.
1035  if (isOutputUnaligned) {
1036  transformedOutput = extractFromAlignedTensor(
1037  rewriter, loc, transformedOutput,
1038  RankedTensorType::get({outputN, outputH, outputW, outputF},
1039  outputElementType));
1040  }
1041 
1042  rewriter.replaceOp(convOp, transformedOutput);
1043 
1044  return transformedOutput.getDefiningOp();
1045 }
1046 
1047 /// A helper function to decompose linalg.winograd_filter_transform.
1048 FailureOr<Operation *>
1049 decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
1050  linalg::WinogradFilterTransformOp op) {
1051  Location loc = op.getLoc();
1052  Value filter = op.getFilter();
1053  auto filterType = cast<ShapedType>(filter.getType());
1054  auto filterShape = filterType.getShape();
1055  int64_t filterH = filterShape[1];
1056  int64_t filterW = filterShape[2];
1057 
1058  // For F(m x 1, r x 1), we only need to do left side transform.
1059  bool leftTransform = filterH != 1;
1060  // For F(1 x m, 1 x r), we only need to do right side transform.
1061  bool rightTransform = filterW != 1;
1062  Value transformedFilter =
1063  filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(),
1064  op.getR(), leftTransform, rightTransform);
1065  if (!transformedFilter)
1066  return failure();
1067 
1068  rewriter.replaceOp(op, transformedFilter);
1069 
1070  return transformedFilter.getDefiningOp();
1071 }
1072 
1073 /// A helper function to decompose linalg.winograd_input_transform.
1074 FailureOr<Operation *>
1075 decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
1076  linalg::WinogradInputTransformOp op) {
1077  Location loc = op.getLoc();
1078  Value input = op.getInput();
1079  auto inputType = cast<ShapedType>(input.getType());
1080  auto inputShape = inputType.getShape();
1081  int64_t inputH = inputShape[1];
1082  int64_t inputW = inputShape[2];
1083 
1084  // For F(m x 1, r x 1), we only need to do left side transform.
1085  bool leftTransform = inputH != 1;
1086  // For F(1 x m, 1 x r), we only need to do right side transform.
1087  bool rightTransform = inputW != 1;
1088  Value transformedInput =
1089  inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
1090  op.getR(), leftTransform, rightTransform);
1091  if (!transformedInput)
1092  return failure();
1093 
1094  rewriter.replaceOp(op, transformedInput);
1095 
1096  return transformedInput.getDefiningOp();
1097 }
1098 
1099 /// A helper function to decompose linalg.winograd_output_transform.
1100 FailureOr<Operation *>
1101 decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
1102  linalg::WinogradOutputTransformOp op) {
1103  Location loc = op.getLoc();
1104  Value value = op.getValue();
1105  auto valueType = cast<ShapedType>(value.getType());
1106  auto valueShape = valueType.getShape();
1107  int64_t valueH = valueShape[0];
1108  int64_t valueW = valueShape[1];
1109 
1110  // For F(m x 1, r x 1), we only need to do left side transform.
1111  bool leftTransform = valueH != 1;
1112  // For F(1 x m, 1 x r), we only need to do right side transform.
1113  bool rightTransform = valueW != 1;
1114  Value transformedOutput =
1115  outputTransform(rewriter, loc, value, op.getOutput(), op.getM(),
1116  op.getR(), leftTransform, rightTransform);
1117  if (!transformedOutput)
1118  return failure();
1119 
1120  rewriter.replaceOp(op, transformedOutput);
1121 
1122  return transformedOutput.getDefiningOp();
1123 }
1124 
1125 /// A rewrite pattern to decompose linalg.winograd_filter_transform operations.
1126 class DecomposeWinogradFilterTransform final
1127  : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
1128 public:
1130 
1131  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
1132  PatternRewriter &rewriter) const override {
1133  return decomposeWinogradFilterTransformHelper(rewriter, op);
1134  }
1135 };
1136 
1137 /// A rewrite pattern to decompose linalg.winograd_input_transform operations.
1138 class DecomposeWinogradInputTransform final
1139  : public OpRewritePattern<linalg::WinogradInputTransformOp> {
1140 public:
1142 
1143  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
1144  PatternRewriter &rewriter) const override {
1145  return decomposeWinogradInputTransformHelper(rewriter, op);
1146  }
1147 };
1148 
1149 /// A rewrite pattern to decompose linalg.winograd_output_transform operations.
1150 class DecomposeWinogradOutputTransform final
1151  : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
1152 public:
1154 
1155  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
1156  PatternRewriter &rewriter) const override {
1157  return decomposeWinogradOutputTransformHelper(rewriter, op);
1158  }
1159 };
1160 
1161 /// A rewrite pattern for Winograd Conv2D algorithm.
1162 class WinogradConv2DNhwcFhwc final
1163  : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
1164 public:
1166  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
1167  : OpRewritePattern(context), m(m), r(r) {}
1168 
1169  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
1170  PatternRewriter &rewriter) const override {
1171  if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
1172  return failure();
1173 
1174  return success();
1175  }
1176 
1177 private:
1178  int64_t m;
1179  int64_t r;
1180 };
1181 } // end anonymous namespace
1182 
1183 //===----------------------------------------------------------------------===//
1184 FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
1185  linalg::Conv2DNhwcFhwcOp op, int64_t m,
1186  int64_t r) {
1187  return winogradConv2DHelper(rewriter, op, m, r);
1188 }
1189 
1190 FailureOr<Operation *>
1192  linalg::WinogradFilterTransformOp op) {
1193  return decomposeWinogradFilterTransformHelper(rewriter, op);
1194 }
1195 
1196 FailureOr<Operation *>
1198  linalg::WinogradInputTransformOp op) {
1199  return decomposeWinogradInputTransformHelper(rewriter, op);
1200 }
1201 
1202 FailureOr<Operation *>
1204  linalg::WinogradOutputTransformOp op) {
1205  return decomposeWinogradOutputTransformHelper(rewriter, op);
1206 }
1207 
1209  int64_t r) {
1210  MLIRContext *context = patterns.getContext();
1211  // TODO: Support more Conv2D data layout, e.g., conv_2d_nchw_fchw
1212  patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
1213 }
1214 
1216  MLIRContext *context = patterns.getContext();
1217  patterns
1218  .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform,
1219  DecomposeWinogradOutputTransform>(context);
1220 }
1221 
1222 } // end namespace linalg
1223 } // end namespace mlir
int64_t scalarFactor
const float * table
int64_t cols
int64_t rows
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m, int64_t r)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns)
Patterns to decompose Winograd operators.
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold)
Create a tensor::PadOp that pads source to the size of the statically sized type whose static sizes a...
Definition: Utils.cpp:192
static bool hasAllOneValues(DenseIntElementsAttr attr)
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, int64_t r)
Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
@ Type
An inlay hint for a type annotation.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition: SCF.cpp:687
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:70
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
ValueVector results
Definition: SCF.h:74