//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implement the Winograd Conv2D algorithm. The implementation is based on the
// paper: Fast Algorithms for Convolutional Neural Networks
// (https://arxiv.org/abs/1509.09308).
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace linalg {

namespace {

// clang-format off
/// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
/// result. The formula of the minimal 2D filtering algorithm F(m x m, r x r),
/// where m is the output dimension and r is the filter dimension, is
///
/// Y = A^T x [ (G x g x G^T) (*) (B^T x d x B) ] x A
///
/// where g is the filter, d is the input data, and (*) denotes element-wise
/// multiplication. We need to prepare 6 constant transformation matrices,
/// G, G^T, B^T, B, A^T, and A for this formula.
///
/// The following tables define these constant transformation matrices for
/// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5).
///
/// To add more transformation matrices, we need to add the following
/// items:
/// 1. Add the constant transformation matrix to the corresponding
///    G, GT, BT, B, AT, or A array.
/// 2. Add the corresponding TransformMatrix to the GMatrices, GTMatrices,
///    BTMatrices, BMatrices, ATMatrices, or AMatrices map.
/// 3. Add an enum value F_m_r to the WinogradConv2DFmr enum.
///
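/// For example, for F(2 x 2, 3 x 3) below, the input tile d is 4 x 4
/// (m + r - 1 = 4) and the filter g is 3 x 3, so G is 4 x 3, B^T is 4 x 4,
/// A^T is 2 x 4, and the result Y is 2 x 2.
///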
constexpr float G_2x2_3x3[] = {
  -1,    0,     0,
  1./2,  -1./2, 1./2,
  1./2,  1./2,  1./2,
  0,     0,     1
};

constexpr float GT_2x2_3x3[] = {
  -1, 1./2,  1./2, 0,
  0,  -1./2, 1./2, 0,
  0,  1./2,  1./2, 1
};

constexpr float BT_2x2_3x3[] = {
  -1, 0,  1, 0,
  0,  -1, 1, 0,
  0,  1,  1, 0,
  0,  -1, 0, 1
};

constexpr float B_2x2_3x3[] = {
  -1, 0,  0, 0,
  0,  -1, 1, -1,
  1,  1,  1, 0,
  0,  0,  0, 1
};

constexpr float AT_2x2_3x3[] = {
  1, 1,  1, 0,
  0, -1, 1, 1
};

constexpr float A_2x2_3x3[] = {
  1, 0,
  1, -1,
  1, 1,
  0, 1
};

constexpr float G_4x4_3x3[] = {
  1,     0,     0,
  -1./3, 1./3,  -1./3,
  -1./3, -1./3, -1./3,
  1./12, -1./6, 1./3,
  1./12, 1./6,  1./3,
  0,     0,     1
};

constexpr float GT_4x4_3x3[] = {
  1, -1./3, -1./3, 1./12, 1./12, 0,
  0, 1./3,  -1./3, -1./6, 1./6,  0,
  0, -1./3, -1./3, 1./3,  1./3,  1
};

constexpr float BT_4x4_3x3[] = {
  1./4, 0,     -5./16, 0,      1./16, 0,
  0,    1./4,  -1./4,  -1./16, 1./16, 0,
  0,    -1./4, -1./4,  1./16,  1./16, 0,
  0,    1./4,  -1./8,  -1./4,  1./8,  0,
  0,    -1./4, -1./8,  1./4,   1./8,  0,
  0,    1./4,  0,      -5./16, 0,     1./16
};

constexpr float B_4x4_3x3[] = {
  1./4,   0,      0,     0,     0,     0,
  0,      1./4,   -1./4, 1./4,  -1./4, 1./4,
  -5./16, -1./4,  -1./4, -1./8, -1./8, 0,
  0,      -1./16, 1./16, -1./4, 1./4,  -5./16,
  1./16,  1./16,  1./16, 1./8,  1./8,  0,
  0,      0,      0,     0,     0,     1./16
};

constexpr float AT_4x4_3x3[] = {
  1./8, 1./4,  1./4, 1./8,  1./8, 0,
  0,    -1./4, 1./4, -1./4, 1./4, 0,
  0,    1./4,  1./4, 1./2,  1./2, 0,
  0,    -1./4, 1./4, -1,    1,    1./2
};

constexpr float A_4x4_3x3[] = {
  1./8, 0,     0,    0,
  1./4, -1./4, 1./4, -1./4,
  1./4, 1./4,  1./4, 1./4,
  1./8, -1./4, 1./2, -1,
  1./8, 1./4,  1./2, 1,
  0,    0,     0,    1./2
};

constexpr float G_2x2_5x5[] = {
  1,      0,     0,      0,     0,
  1./6,   -1./6, 1./6,   -1./6, 1./6,
  -1./6,  -1./6, -1./6,  -1./6, -1./6,
  -4./15, 2./15, -1./15, 1./30, -1./60,
  1./60,  1./30, 1./15,  2./15, 4./15,
  0,      0,     0,      0,     1
};

constexpr float GT_2x2_5x5[] = {
  1, 1./6,  -1./6, -4./15, 1./60, 0,
  0, -1./6, -1./6, 2./15,  1./30, 0,
  0, 1./6,  -1./6, -1./15, 1./15, 0,
  0, -1./6, -1./6, 1./30,  2./15, 0,
  0, 1./6,  -1./6, -1./60, 4./15, 1
};

constexpr float BT_2x2_5x5[] = {
  1./8, 3./16, -1./4,  -3./16, 1./8,   0,
  0,    1./8,  1./16,  -5./16, 1./8,   0,
  0,    -1./8, -5./16, -1./16, 1./8,   0,
  0,    1./4,  -1./8,  -1./4,  1./8,   0,
  0,    -1./8, -1./4,  1./8,   1./4,   0,
  0,    1./8,  3./16,  -1./4,  -3./16, 1./8
};

constexpr float B_2x2_5x5[] = {
  1./8,   0,      0,      0,     0,     0,
  3./16,  1./8,   -1./8,  1./4,  -1./8, 1./8,
  -1./4,  1./16,  -5./16, -1./8, -1./4, 3./16,
  -3./16, -5./16, -1./16, -1./4, 1./8,  -1./4,
  1./8,   1./8,   1./8,   1./8,  1./4,  -3./16,
  0,      0,      0,      0,     0,     1./8
};

constexpr float AT_2x2_5x5[] = {
  1./2, 1,  1, 2,  1, 0,
  0,    -1, 1, -1, 2, 1./2
};

constexpr float A_2x2_5x5[] = {
  1./2, 0,
  1,    -1,
  1,    1,
  2,    -1,
  1,    2,
  0,    1./2
};
// clang-format on

/// Structure to keep the information of constant transform matrices.
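/// For example, TransformMatrix(G_2x2_3x3, 4, 3) views the flat G_2x2_3x3
/// array above as a 4 x 3 matrix.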
struct TransformMatrix {
  TransformMatrix(ArrayRef<float> table, int64_t rows, int64_t cols,
                  int64_t scalarFactor = 1)
      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}

  ArrayRef<float> table;
  int64_t rows;
  int64_t cols;
  int64_t scalarFactor;
};

/// Utility function to convert a constant array to an arith.constant Value.
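/// For example, the 4 x 3 G_2x2_3x3 table above materializes roughly as
///   %G = arith.constant dense<[[-1.0, 0.0, 0.0], ...]> : tensor<4x3xf32>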
Value create2DTransformMatrix(OpBuilder &builder, Location loc,
                              TransformMatrix transform, Type type) {
  assert(transform.table.size() ==
         static_cast<size_t>(transform.rows * transform.cols));
  assert(type.isFloat() && "Only floats are supported by Winograd");
  ArrayRef<float> constVec(transform.table.data(),
                           transform.rows * transform.cols);
  auto constAttrVec =
      llvm::map_to_vector<>(constVec, [&](const float v) -> Attribute {
        return builder.getFloatAttr(type, v);
      });
  SmallVector<int64_t, 2> shape{transform.rows, transform.cols};
  return arith::ConstantOp::create(
      builder, loc,
      DenseFPElementsAttr::get(RankedTensorType::get(shape, type),
                               constAttrVec));
}

/// Extract height x width data from 4D tensors.
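/// For a tensor<N x H x W x C> source this is roughly the rank-reducing slice
///   tensor.extract_slice %src[%n, %hOff, %wOff, %c] [1, eH, eW, 1]
///       [1, 1, 1, 1] : tensor<NxHxWxCxf32> to tensor<eHxeWxf32>
/// with the N/C dimension positions configurable via the *Idx arguments.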
Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
                          Value loopNorFIndex, Value loopCorFIndex,
                          Value heightOffset, Value widthOffset,
                          int64_t extractHeight, int64_t extractWidth,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto sourceType = cast<ShapedType>(source.getType());
  Type elementType = sourceType.getElementType();
  int64_t srcSize = sourceType.getRank();

  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets;
  offsets.resize(srcSize);
  offsets[loopNorFIdx] = loopNorFIndex;
  offsets[loopCorFIdx] = loopCorFIndex;
  offsets[heightIdx] = heightOffset;
  offsets[widthIdx] = widthOffset;
  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
  sizes[heightIdx] = builder.getIndexAttr(extractHeight);
  sizes[widthIdx] = builder.getIndexAttr(extractWidth);
  SmallVector<OpFoldResult> strides(srcSize, oneIndex);

  auto extractFilterType =
      RankedTensorType::get({extractHeight, extractWidth}, elementType);
  auto extractFilterOp = tensor::ExtractSliceOp::create(
      builder, loc, extractFilterType, source, offsets, sizes, strides);

  return extractFilterOp;
}

/// Extract height x width data from 6D tensors.
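/// For a tensor<aH x aW x tileH x tileW x N x C> source this is roughly
///   tensor.extract_slice %src[0, 0, %tH, %tW, %n, %c] [aH, aW, 1, 1, 1, 1]
///       [1, 1, 1, 1, 1, 1] : ... to tensor<aHxaWxf32>
/// with the dimension positions configurable via the *Idx arguments.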
Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
                          Value tileHIndex, Value tileWIndex,
                          Value loopNorFIndex, Value loopCorFIndex,
                          int64_t tileHIdx, int64_t tileWIdx,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto sourceType = cast<ShapedType>(source.getType());
  Type elementType = sourceType.getElementType();
  auto sourceShape = sourceType.getShape();
  int64_t srcSize = sourceType.getRank();
  int64_t height = sourceShape[heightIdx];
  int64_t width = sourceShape[widthIdx];

  auto zeroIndex = builder.getIndexAttr(0);
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets(srcSize, zeroIndex);
  offsets.resize(srcSize);
  offsets[tileHIdx] = tileHIndex;
  offsets[tileWIdx] = tileWIndex;
  offsets[loopNorFIdx] = loopNorFIndex;
  offsets[loopCorFIdx] = loopCorFIndex;
  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
  sizes[heightIdx] = builder.getIndexAttr(height);
  sizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(srcSize, oneIndex);

  auto extractFilterType = RankedTensorType::get({height, width}, elementType);
  auto extractFilterOp = tensor::ExtractSliceOp::create(
      builder, loc, extractFilterType, source, offsets, sizes, strides);

  return extractFilterOp;
}

/// Insert transformed height x width data into the 4D tensor from which it
/// was extracted.
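/// For a tensor<N x H x W x C> destination this is roughly
///   tensor.insert_slice %val into %dst[%n, %hOff, %wOff, %c] [1, h, w, 1]
///       [1, 1, 1, 1] : tensor<hxwxf32> into tensor<NxHxWxCxf32>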
Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value loopNorFIndex, Value loopCorFIndex,
                       Value heightOffset, Value widthOffset, int64_t height,
                       int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> retOffsets;
  retOffsets.resize(destSize);
  retOffsets[loopNorFIdx] = loopNorFIndex;
  retOffsets[loopCorFIdx] = loopCorFIndex;
  retOffsets[heightIdx] = heightOffset;
  retOffsets[widthIdx] = widthOffset;
  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
  retSizes[heightIdx] = builder.getIndexAttr(height);
  retSizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(destSize, oneIndex);

  auto insertSliceOp = tensor::InsertSliceOp::create(
      builder, loc, source, dest, retOffsets, retSizes, strides);

  return insertSliceOp;
}

/// Insert transformed height x width data into the 6D tensor from which it
/// was extracted.
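/// Roughly the inverse of extract2DDataFrom6D:
///   tensor.insert_slice %val into %dst[0, 0, %tH, %tW, %n, %c]
///       [h, w, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]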
Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value tileHIndex, Value tileWIndex,
                       Value loopNorFIndex, Value loopCorFIndex, int64_t height,
                       int64_t width, int64_t tileHIdx, int64_t tileWIdx,
                       int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
  auto zeroIndex = builder.getIndexAttr(0);
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex);
  retOffsets.resize(destSize);
  retOffsets[tileHIdx] = tileHIndex;
  retOffsets[tileWIdx] = tileWIndex;
  retOffsets[loopNorFIdx] = loopNorFIndex;
  retOffsets[loopCorFIdx] = loopCorFIndex;
  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
  retSizes[heightIdx] = builder.getIndexAttr(height);
  retSizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(destSize, oneIndex);

  auto insertSliceOp = tensor::InsertSliceOp::create(
      builder, loc, source, dest, retOffsets, retSizes, strides);

  return insertSliceOp;
}

/// This function transforms the filter. The data layout of the filter is
/// FHWC. The transformation matrix is two-dimensional, so we first extract an
/// H x W tile from the FHWC filter and generate two levels of loops to
/// iterate over F and C. After the transformation, we get
///
/// scf.for %f = lo_f to hi_f step 1
///   scf.for %c = lo_c to hi_c step 1
///     %extracted = extract filter<h x w> from filter<f x h x w x c>
///     %ret = linalg.matmul G, %extracted
///     %ret = linalg.matmul %ret, GT
///     %inserted = insert %ret into filter<h x w x c x f>
Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
                      Value retValue, WinogradConv2DFmr fmr,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to G transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      GMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
      };

  // Map from (m, r) to GT transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      GTMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
      };

  auto filterType = cast<ShapedType>(filter.getType());
  Type elementType = filterType.getElementType();
  auto filterShape = filterType.getShape(); // F, H, W, C
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];

  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  if (filterH != r && filterH != 1)
    return Value();
  if (filterW != r && filterW != 1)
    return Value();

  Value zeroIdx = arith::ConstantIndexOp::create(rewriter, loc, 0);
  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value FIter = ivs[0];
    Value CIter = ivs[1];

    // Extract (H, W) from (F, H, W, C).
    auto extractFilter =
        extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
                            zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    int64_t retRows = 1;
    Value matmulRetValue = extractFilter;
    Value zero = arith::ConstantOp::create(builder, loc,
                                           rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get constant transform matrix G.
      auto it = GMatrices.find(fmr);
      if (it == GMatrices.end())
        return {};
      const TransformMatrix &GMatrix = it->second;

      retRows = GMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
      auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
                                           elementType)
                       .getResult();
      auto init =
          linalg::FillOp::create(builder, loc, zero, empty).getResult(0);

      Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
      // Multiply G x g.
      auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                               ValueRange{G, extractFilter},
                                               ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get constant transform matrix GT.
      auto it = GTMatrices.find(fmr);
      if (it == GTMatrices.end())
        return {};
      const TransformMatrix &GTMatrix = it->second;

      auto matmulType =
          RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
      auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
                                           elementType)
                       .getResult();
      auto init =
          linalg::FillOp::create(builder, loc, zero, empty).getResult(0);

      Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
      // Multiply u = (G x g) x GT.
      auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                               ValueRange{matmulRetValue, GT},
                                               ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) to (H, W, C, F).
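    // The transformed tile is (m + r - 1) x (m + r - 1); a dimension that is
    // not transformed keeps size 1.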
    int64_t retHeight = leftTransform ? m + r - 1 : 1;
    int64_t retWidth = rightTransform ? m + r - 1 : 1;

    auto insertSliceOp =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
                         zeroIdx, zeroIdx, retHeight, retWidth,
                         /*loopNorFIdx=*/3, /*loopCorFIdx=*/2,
                         /*heightIdx=*/0, /*widthIdx=*/1);

    return {insertSliceOp};
  };

  auto fUpperBound = arith::ConstantIndexOp::create(rewriter, loc, filterF);
  auto cUpperBound = arith::ConstantIndexOp::create(rewriter, loc, filterC);
  auto oneStep = arith::ConstantIndexOp::create(rewriter, loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
      {oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}

/// This function transforms the input. The data layout of the input is NHWC.
/// The transformation matrix is two-dimensional, so we first extract an
/// alphaH x alphaW tile from the NHWC input and generate four levels of loops
/// to iterate over the tile indices, N, and C. After the transformation, we
/// get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %c = 0 to C step 1
///         %extracted = extract %extracted<alphaH x alphaW> from
///                        %input<N x H x W x C>
///                        at [%n, (%h x m), (%w x m), %c]
///         %ret = linalg.matmul BT, %extracted
///         %ret = linalg.matmul %ret, B
///         %inserted = insert %ret<alphaH x alphaW> into
///                       %output<alphaH x alphaW x tileH x tileW x N x C>
///                       at [0, 0, %h, %w, %n, %c]
Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
                     Value retValue, WinogradConv2DFmr fmr,
                     bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to BT transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      BTMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
      };

  // Map from (m, r) to B transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      BMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
      };

  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  auto inputType = cast<ShapedType>(input.getType());
  Type elementType = inputType.getElementType();
  auto inputShape = inputType.getShape(); // N, H, W, C
  int64_t inputN = inputShape[0];
  int64_t inputC = inputShape[3];
  auto valueType = cast<ShapedType>(retValue.getType());
  auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C
  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value CIter = ivs[3];

    auto *context = builder.getContext();

    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    Value heightOffset = affine::AffineApplyOp::create(
        builder, loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
    Value widthOffset = affine::AffineApplyOp::create(
        builder, loc, rightTransform ? affineMap : identityAffineMap,
        tileWIter);

    // Extract (H, W) from (N, H, W, C).
    auto extractInput =
        extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset,
                            widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    int64_t retRows = 1;
    int64_t retCols = 1;
    Value matmulRetValue = extractInput;
    Value zero = arith::ConstantOp::create(builder, loc,
                                           rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get constant transform matrix BT.
      auto it = BTMatrices.find(fmr);
      if (it == BTMatrices.end())
        return {};
      const TransformMatrix &BTMatrix = it->second;

      retRows = BTMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
      auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
                                           elementType)
                       .getResult();
      auto init =
          linalg::FillOp::create(builder, loc, zero, empty).getResult(0);

      Value BT = create2DTransformMatrix(builder, loc, BTMatrix, elementType);
      // Multiply BT x d.
      auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                               ValueRange{BT, matmulRetValue},
                                               ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get constant transform matrix B.
      auto it = BMatrices.find(fmr);
      if (it == BMatrices.end())
        return {};
      const TransformMatrix &BMatrix = it->second;

      retCols = BMatrix.cols;
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
                                           elementType)
                       .getResult();
      auto init =
          linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
      Value B = create2DTransformMatrix(builder, loc, BMatrix, elementType);
      // Multiply v = (BT x d) x B.
      auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                               ValueRange{matmulRetValue, B},
                                               ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) to (H, W, tileH, tileW, N, C).
    auto combinedVal = insert2DDataTo6D(
        builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter,
        CIter, retRows, retCols, 2, 3, /*loopNorFIdx=*/4, /*loopCorFIdx=*/5,
        /*heightIdx=*/0, /*widthIdx=*/1);

    return {combinedVal};
  };

  auto zeroIdx = arith::ConstantIndexOp::create(rewriter, loc, 0);
  auto tileHBound = arith::ConstantIndexOp::create(rewriter, loc, tileH);
  auto tileWBound = arith::ConstantIndexOp::create(rewriter, loc, tileW);
  auto nUpperBound = arith::ConstantIndexOp::create(rewriter, loc, inputN);
  auto cUpperBound = arith::ConstantIndexOp::create(rewriter, loc, inputC);
  auto oneStep = arith::ConstantIndexOp::create(rewriter, loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, cUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}

/// This function generates linalg.batch_matmul to multiply the input with the
/// filter. linalg.batch_matmul only supports 3-dimensional inputs, so we
/// treat the alphaH x alphaW transform dimensions as one batch dimension and
/// collapse [alphaH, alphaW, tileH, tileW, N, C] to
/// [alphaH x alphaW, tileH x tileW x N, C]. In this way, we can convert the
/// 6-dimensional inputs to a 3-dimensional representation that is suitable
/// for linalg.batch_matmul.
///
/// The batched matmul does the matrix multiply with the reduction on channel.
///
/// We get
///
/// %collapsed_input = tensor.collapse_shape %input
/// %collapsed_filter = tensor.collapse_shape %filter
/// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
/// %expanded_ret = tensor.expand_shape %ret
///
/// After this function, we get the return value with data layout
/// (alphaH, alphaW, tileH, tileW, N, F).
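/// For example, with F(4 x 4, 3 x 3) (alpha = 6), tileH = tileW = 2, N = 1,
/// C = 5, and F = 2: the filter collapses from tensor<6x6x5x2xf32> to
/// tensor<36x5x2xf32>, the input from tensor<6x6x2x2x1x5xf32> to
/// tensor<36x4x5xf32>, the batch matmul produces tensor<36x4x2xf32>, and the
/// result expands back to tensor<6x6x2x2x1x2xf32>.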
static Value matrixMultiply(RewriterBase &rewriter, Location loc,
                            Value transformedFilter, Value transformedInput,
                            Type outputElementType) {
  // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter.
  auto filterType = cast<ShapedType>(transformedFilter.getType());
  assert(filterType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> filterShape = filterType.getShape();
  Type filterElementType = filterType.getElementType();
  auto filterReassocType = RankedTensorType::get(
      {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
      filterElementType);
  SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
  Value collapseFilter = tensor::CollapseShapeOp::create(
      rewriter, loc, filterReassocType, transformedFilter, filterReassoc);

  // Convert (alphaH, alphaW, tileH, tileW, N, C) to
  // (alphaH x alphaW, tileH x tileW x N, C) for input.
  auto inputType = cast<ShapedType>(transformedInput.getType());
  assert(inputType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> inputShape = inputType.getShape();
  Type inputElementType = inputType.getElementType();
  auto inputReassocType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
      inputElementType);
  SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  Value collapseInput = tensor::CollapseShapeOp::create(
      rewriter, loc, inputReassocType, transformedInput, inputReassoc);

  // Batched matrix multiply.
  auto matmulType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
      outputElementType);
  Value empty = tensor::EmptyOp::create(rewriter, loc, matmulType.getShape(),
                                        outputElementType)
                    .getResult();
  Value zero = arith::ConstantOp::create(
      rewriter, loc, rewriter.getZeroAttr(outputElementType));
  Value init = linalg::FillOp::create(rewriter, loc, zero, empty).getResult(0);

  auto matmulOp = linalg::BatchMatmulOp::create(
      rewriter, loc, matmulType, ValueRange({collapseInput, collapseFilter}),
      ValueRange{init});

  // The result shape of batch matmul is (alphaH x alphaW, tileH x tileW x N,
  // F). Expand the matmul result to (alphaH, alphaW, tileH, tileW, N, F).
  SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  auto outputReassocType =
      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
                             inputShape[3], inputShape[4], filterShape[3]},
                            outputElementType);
  auto expandOutput = tensor::ExpandShapeOp::create(
      rewriter, loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
  return expandOutput;
}

/// This function transforms the output. The data layout of the input value is
/// (alphaH, alphaW, tileH, tileW, N, F). The transformation matrix is
/// two-dimensional, so we first extract an alphaH x alphaW tile and generate
/// four levels of loops to iterate over the tile indices, N, and F. After the
/// transformation, we get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %f = 0 to F step 1
///         %extracted = extract %extracted<alphaH x alphaW> from
///                        %input<alphaH x alphaW x tileH x tileW x N x F>
///                        at [0, 0, %h, %w, %n, %f]
///         %ret = linalg.matmul AT, %extracted
///         %ret = linalg.matmul %ret, A
///         %inserted = insert %ret<m x m> into
///                       %output<N x H x W x F>
///                       at [%n, (%h x m), (%w x m), %f]
Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
                      Value output, WinogradConv2DFmr fmr,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to AT transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      ATMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
      };

  // Map from (m, r) to A transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      AMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
      };

  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  auto valueType = cast<ShapedType>(value.getType());
  Type elementType = valueType.getElementType();
  auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];
  int64_t valueN = valueShape[4];
  int64_t valueF = valueShape[5];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  if (valueH != alphaH && valueH != 1)
    return Value();
  if (valueW != alphaW && valueW != 1)
    return Value();

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    auto *context = builder.getContext();
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value FIter = ivs[3];

    // Extract (H, W) from (H, W, tileH, tileW, N, F).
    auto extractValue =
        extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter,
                            FIter, 2, 3, /*loopNorFIdx=*/4,
                            /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1);

    const TransformMatrix &AMatrix = AMatrices.at(fmr);
    const TransformMatrix &ATMatrix = ATMatrices.at(fmr);
    int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) *
                           (leftTransform ? ATMatrix.scalarFactor : 1);
    int64_t retCols = rightTransform ? AMatrix.cols : 1;
    int64_t retRows = leftTransform ? ATMatrix.rows : 1;

    Value matmulRetValue = extractValue;
    Value zero = arith::ConstantOp::create(builder, loc,
                                           rewriter.getZeroAttr(elementType));

    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    Value heightOffset = affine::AffineApplyOp::create(
        builder, loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
    Value widthOffset = affine::AffineApplyOp::create(
        builder, loc, rightTransform ? affineMap : identityAffineMap,
        tileWIter);

    Value outInitVal =
        extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
                            widthOffset, retRows, retCols,
                            /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1,
                            /*widthIdx=*/2);
    if (leftTransform) {
      auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
      Value init = outInitVal;
      if (rightTransform || scalarFactor != 1) {
        auto empty = tensor::EmptyOp::create(builder, loc,
                                             matmulType.getShape(), elementType)
                         .getResult();
        init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
      }

      Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
      // Multiply AT x m.
      auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                               ValueRange{AT, matmulRetValue},
                                               ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      auto matmulType =
          RankedTensorType::get({retRows, AMatrix.cols}, elementType);
      Value init = outInitVal;
      if (scalarFactor != 1) {
        auto empty = tensor::EmptyOp::create(builder, loc,
                                             matmulType.getShape(), elementType)
                         .getResult();
        init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
      }

      Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
      // Multiply y = (AT x m) x A.
      auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                               ValueRange{matmulRetValue, A},
                                               ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (scalarFactor != 1) {
      // Multiply by the scalar factor and add outInitVal.
      Value scalarFactorValue = arith::ConstantOp::create(
          builder, loc, FloatAttr::get(elementType, scalarFactor));
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
      SmallVector<AffineMap> affineMaps = {
          AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap};

      matmulRetValue =
          linalg::GenericOp::create(
              rewriter, loc, matmulType,
              ValueRange{scalarFactorValue, matmulRetValue},
              ValueRange{outInitVal}, affineMaps,
              llvm::ArrayRef<utils::IteratorType>{
                  utils::IteratorType::parallel, utils::IteratorType::parallel},
              [&](OpBuilder &nestedBuilder, Location nestedLoc,
                  ValueRange args) {
                auto mulf = arith::MulFOp::create(nestedBuilder, nestedLoc,
                                                  args[0], args[1]);
                auto addf = arith::AddFOp::create(nestedBuilder, nestedLoc,
                                                  mulf.getResult(), args[2]);
                linalg::YieldOp::create(nestedBuilder, nestedLoc,
                                        addf.getResult());
              })
              .getResult(0);
    }

    // Insert (H, W) to (N, H, W, F).
    Value combinedVal =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
                         heightOffset, widthOffset, retRows, retCols,
                         /*loopNorFIdx=*/0,
                         /*loopCorFIdx=*/3, /*heightIdx=*/1,
                         /*widthIdx=*/2);

    return {combinedVal};
  };

  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  auto zeroIdx = arith::ConstantIndexOp::create(rewriter, loc, 0);
  auto tileHBound = arith::ConstantIndexOp::create(rewriter, loc, tileH);
  auto tileWBound = arith::ConstantIndexOp::create(rewriter, loc, tileW);
  auto nUpperBound = arith::ConstantIndexOp::create(rewriter, loc, valueN);
  auto fUpperBound = arith::ConstantIndexOp::create(rewriter, loc, valueF);
  auto oneStep = arith::ConstantIndexOp::create(rewriter, loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, fUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody);
  return loops.results[0];
}

/// Zero-pad value at the high end of each dimension up to alignedShape.
static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
                                Value value, ArrayRef<int64_t> alignedShape) {
  auto valueType = cast<ShapedType>(value.getType());
  Type elementType = valueType.getElementType();
  auto alignedType = RankedTensorType::get(alignedShape, elementType);
  Value padValue = arith::ConstantOp::create(rewriter, loc, elementType,
                                             rewriter.getZeroAttr(elementType));

  return linalg::makeComposedPadHighOp(rewriter, loc, alignedType, value,
                                       padValue, false);
}

/// Extract sub-tensor with extractedType from value.
static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
                                      Value value,
                                      RankedTensorType extractedType) {
  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
  OpFoldResult oneIndex = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
  SmallVector<OpFoldResult, 4> strides(4, oneIndex);

  ArrayRef<int64_t> extractedShape = extractedType.getShape();
  SmallVector<OpFoldResult> sizes =
      getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape));

  return tensor::ExtractSliceOp::create(rewriter, loc, extractedType, value,
                                        offsets, sizes, strides);
}

/// Utility function to check all values in the attribute are 1.
static bool hasAllOneValues(DenseIntElementsAttr attr) {
  return llvm::all_of(
      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
}

/// A helper function to convert linalg.conv_2d_nhwc_fhwc to
/// linalg.winograd_*_transform ops.
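///
/// The generated structure is roughly (types depend on fmr and the conv
/// shapes; see matrixMultiply for the middle step):
///
///   %tf = linalg.winograd_filter_transform ins(%filter) outs(%f_init)
///   %ti = linalg.winograd_input_transform ins(%input) outs(%i_init)
///   %mm = batch matmul of the collapsed %ti and %tf
///   %out = linalg.winograd_output_transform ins(%mm expanded) outs(%output)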
static FailureOr<Operation *>
winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
                     WinogradConv2DFmr fmr) {
  if (!convOp.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(
        convOp, "expected pure tensor semantics for linalg.conv_2d_nhwc_fhwc");

  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];
  auto inputType = cast<ShapedType>(input.getType());
  auto filterType = cast<ShapedType>(filter.getType());
  auto outputType = cast<ShapedType>(output.getType());

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  if (!hasAllOneValues(convOp.getStrides()))
    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");

  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputN = inputShape[0];
  int64_t inputH = inputShape[1];
  int64_t inputW = inputShape[2];
  int64_t inputC = inputShape[3];
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t outputN = outputShape[0];
  int64_t outputH = outputShape[1];
  int64_t outputW = outputShape[2];
  int64_t outputF = outputShape[3];

  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
  bool isSupportedFilter = false;
  if (filterH == filterW && filterH == r)
    isSupportedFilter = true;
  if (filterH == r && filterW == 1)
    isSupportedFilter = true;
  if (filterH == 1 && filterW == r)
    isSupportedFilter = true;

  if (!isSupportedFilter)
    return rewriter.notifyMatchFailure(
        convOp, "only support filter (r x r), (r x 1) or (1 x r)");

  // All the criteria are satisfied. We can do Winograd Conv2D.
  Location loc = convOp.getLoc();

  // For F(m x 1, r x 1), we only need to do left side transform.
  bool leftTransform = filterH != 1;
  // For F(1 x m, 1 x r), we only need to do right side transform.
  bool rightTransform = filterW != 1;
  int64_t heightM = leftTransform ? m : 1;
  int64_t widthM = rightTransform ? m : 1;
  int64_t heightR = leftTransform ? r : 1;
  int64_t widthR = rightTransform ? r : 1;

  // --- Create operation for filter transform ---
  Type filterElementType = filterType.getElementType();
  int64_t alphaH = heightM + heightR - 1;
  int64_t alphaW = widthM + widthR - 1;
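  // Decompose the output into m x m tiles, rounding up when the output size
  // is not a multiple of m.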
  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
  auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
                                       filterElementType);
  Value retValue = tensor::EmptyOp::create(rewriter, loc, retType.getShape(),
                                           filterElementType);
  auto transformedFilter = linalg::WinogradFilterTransformOp::create(
      rewriter, loc, retType, filter, retValue, fmr);

  // --- Create operation for input transform ---

  // When the input size minus (r - 1) is not aligned with the output tile
  // size, we need to pad the input data to create full tiles for tiling.
  Type inputElementType = inputType.getElementType();
  int64_t alignedInputH = tileH * heightM + (heightR - 1);
  int64_t alignedInputW = tileW * widthM + (widthR - 1);
  if (alignedInputH != inputH || alignedInputW != inputW) {
    input = padToAlignedTensor(rewriter, loc, input,
                               {inputN, alignedInputH, alignedInputW, inputC});
  }

  retType = RankedTensorType::get(
      {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
  retValue = tensor::EmptyOp::create(rewriter, loc, retType.getShape(),
                                     inputElementType);
  auto transformedInput = linalg::WinogradInputTransformOp::create(
      rewriter, loc, retType, input, retValue, fmr);

  Type outputElementType = outputType.getElementType();
  Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
                                   transformedInput, outputElementType);

  // --- Create operation for output transform ---

  // When the output size is not aligned with the output tile size, we need to
  // pad the output buffer to insert the full tiles after tiling.
  int64_t alignedOutputH = tileH * heightM;
  int64_t alignedOutputW = tileW * widthM;
  bool isOutputUnaligned =
      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
  if (isOutputUnaligned) {
    auto alignedOutputType = RankedTensorType::get(
        {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
    output =
        padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape());
    outputType = alignedOutputType;
  }

  Value transformedOutput = linalg::WinogradOutputTransformOp::create(
      rewriter, loc, outputType, matmulRet, output, fmr);

  // When the output size is not aligned with the output tile size, extract
  // the value from the padded buffer.
  if (isOutputUnaligned) {
    transformedOutput = extractFromAlignedTensor(
        rewriter, loc, transformedOutput,
        RankedTensorType::get({outputN, outputH, outputW, outputF},
                              outputElementType));
  }

  rewriter.replaceOp(convOp, transformedOutput);

  return transformedOutput.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_filter_transform.
FailureOr<Operation *>
decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradFilterTransformOp op) {
  Location loc = op.getLoc();
  Value filter = op.getFilter();
  auto filterType = cast<ShapedType>(filter.getType());
  auto filterShape = filterType.getShape();
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];

  // For F(m x 1, r x 1), we only need to do left side transform.
  bool leftTransform = filterH != 1;
  // For F(1 x m, 1 x r), we only need to do right side transform.
  bool rightTransform = filterW != 1;
  Value transformedFilter =
      filterTransform(rewriter, loc, filter, op.getOutput(), op.getFmr(),
                      leftTransform, rightTransform);
  if (!transformedFilter)
    return failure();

  rewriter.replaceOp(op, transformedFilter);

  return transformedFilter.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_input_transform.
FailureOr<Operation *>
decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
                                      linalg::WinogradInputTransformOp op) {
  Location loc = op.getLoc();
  Value output = op.getOutput();
  auto outputType = cast<ShapedType>(output.getType());
  auto outputShape = outputType.getShape();

  int64_t outputH = outputShape[0];
  int64_t outputW = outputShape[1];

  // For F(m x 1, r x 1), we only need to do left side transform.
  bool leftTransform = outputH != 1;
  // For F(1 x m, 1 x r), we only need to do right side transform.
  bool rightTransform = outputW != 1;
  Value transformedInput =
      inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getFmr(),
                     leftTransform, rightTransform);
  if (!transformedInput)
    return failure();

  rewriter.replaceOp(op, transformedInput);

  return transformedInput.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_output_transform.
FailureOr<Operation *>
decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradOutputTransformOp op) {
  Location loc = op.getLoc();
  Value value = op.getValue();
  auto valueType = cast<ShapedType>(value.getType());
  auto valueShape = valueType.getShape();
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];

  // For F(m x 1, r x 1), we only need to do left side transform.
  bool leftTransform = valueH != 1;
  // For F(1 x m, 1 x r), we only need to do right side transform.
  bool rightTransform = valueW != 1;
  Value transformedOutput =
      outputTransform(rewriter, loc, value, op.getOutput(), op.getFmr(),
                      leftTransform, rightTransform);
  if (!transformedOutput)
    return failure();

  rewriter.replaceOp(op, transformedOutput);

  return transformedOutput.getDefiningOp();
}

/// A rewrite pattern to decompose linalg.winograd_filter_transform operations.
class DecomposeWinogradFilterTransform final
    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradFilterTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern to decompose linalg.winograd_input_transform operations.
class DecomposeWinogradInputTransform final
    : public OpRewritePattern<linalg::WinogradInputTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradInputTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern to decompose linalg.winograd_output_transform operations.
class DecomposeWinogradOutputTransform final
    : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradOutputTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern for the Winograd Conv2D algorithm.
class WinogradConv2DNhwcFhwc final
    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, WinogradConv2DFmr fmr)
      : OpRewritePattern(context), fmr(fmr) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(winogradConv2DHelper(rewriter, convOp, fmr)))
      return failure();

    return success();
  }

private:
  WinogradConv2DFmr fmr;
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
                                      linalg::Conv2DNhwcFhwcOp op,
                                      linalg::WinogradConv2DFmr fmr) {
  return winogradConv2DHelper(rewriter, op, fmr);
}

FailureOr<Operation *>
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradFilterTransformOp op) {
  return decomposeWinogradFilterTransformHelper(rewriter, op);
}

FailureOr<Operation *>
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
                                  linalg::WinogradInputTransformOp op) {
  return decomposeWinogradInputTransformHelper(rewriter, op);
}

FailureOr<Operation *>
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradOutputTransformOp op) {
  return decomposeWinogradOutputTransformHelper(rewriter, op);
}

void populateWinogradConv2DPatterns(RewritePatternSet &patterns,
                                    WinogradConv2DFmr fmr) {
  MLIRContext *context = patterns.getContext();
  // TODO: Support more Conv2D data layouts, e.g., conv_2d_nchw_fchw.
  patterns.insert<WinogradConv2DNhwcFhwc>(context, fmr);
}

void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns
      .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform,
              DecomposeWinogradOutputTransform>(context);
}
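
// A minimal usage sketch (hypothetical pass body; `funcOp` and the greedy
// driver invocation are assumptions, not part of this file's API):
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateWinogradConv2DPatterns(patterns, WinogradConv2DFmr::F_4_3);
//   linalg::populateDecomposeWinogradOpsPatterns(patterns);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));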

} // end namespace linalg
} // end namespace mlir