//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implement Winograd Conv2D algorithm. The implementation is based on the
// paper: Fast Algorithms for Convolutional Neural Networks
// (https://arxiv.org/abs/1509.09308)
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace linalg {

namespace {

// clang-format off
/// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
/// result. The formula of the minimal 2D filtering algorithm F(m x m, r x r),
/// where m is the output dimension and r is the filter dimension, is
///
/// Y = A^T x [ (G x g x G^T) . (B^T x d x B) ] x A
///
/// where g is the filter, d is the input data, and "." denotes element-wise
/// (Hadamard) multiplication. We need to prepare 6 constant transformation
/// matrices, G, G^T, B^T, B, A^T, and A for this formula.
///
/// The following tables define these constant transformation matrices for
/// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5).
///
/// To add more transformation matrices, we need to add the following
/// items:
/// 1. Add the constant transformation matrix to the corresponding
///    G, GT, BT, B, AT, or A array.
/// 2. Add the corresponding TransformMatrix to the GMatrices, GTMatrices,
///    BTMatrices, BMatrices, ATMatrices, or AMatrices map.
/// 3. Add an enum value F_m_r to the WinogradConv2DFmr enum.
///
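/// As a concrete illustration, for F(2 x 2, 3 x 3) we have m = 2 and r = 3,
/// so each transformed tile has alpha = m + r - 1 = 4 samples per dimension
/// and the matrices below have the shapes
///
///   G: 4 x 3, G^T: 3 x 4, B^T: 4 x 4, B: 4 x 4, A^T: 2 x 4, A: 4 x 2
///
/// so that one 2 x 2 output tile is computed from a 4 x 4 input tile d and
/// a 3 x 3 filter g as Y = A^T x [ (G x g x G^T) . (B^T x d x B) ] x A.
///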
constexpr float G_2x2_3x3[] = {
    -1,     0,    0,
  1./2, -1./2, 1./2,
  1./2,  1./2, 1./2,
     0,     0,    1
};

constexpr float GT_2x2_3x3[] = {
 -1,  1./2, 1./2, 0,
  0, -1./2, 1./2, 0,
  0,  1./2, 1./2, 1
};

constexpr float BT_2x2_3x3[] = {
 -1,  0, 1, 0,
  0, -1, 1, 0,
  0,  1, 1, 0,
  0, -1, 0, 1
};

constexpr float B_2x2_3x3[] = {
 -1,  0, 0,  0,
  0, -1, 1, -1,
  1,  1, 1,  0,
  0,  0, 0,  1
};

constexpr float AT_2x2_3x3[] = {
 1,  1, 1, 0,
 0, -1, 1, 1
};

constexpr float A_2x2_3x3[] = {
 1,  0,
 1, -1,
 1,  1,
 0,  1
};

constexpr float G_4x4_3x3[] = {
     1,     0,     0,
 -1./3,  1./3, -1./3,
 -1./3, -1./3, -1./3,
 1./12, -1./6,  1./3,
 1./12,  1./6,  1./3,
     0,     0,     1
};

constexpr float GT_4x4_3x3[] = {
 1, -1./3, -1./3, 1./12, 1./12, 0,
 0,  1./3, -1./3, -1./6,  1./6, 0,
 0, -1./3, -1./3,  1./3,  1./3, 1
};

constexpr float BT_4x4_3x3[] = {
 1./4,     0, -5./16,      0, 1./16,     0,
    0,  1./4,  -1./4, -1./16, 1./16,     0,
    0, -1./4,  -1./4,  1./16, 1./16,     0,
    0,  1./4,  -1./8,  -1./4,  1./8,     0,
    0, -1./4,  -1./8,   1./4,  1./8,     0,
    0,  1./4,      0, -5./16,     0, 1./16
};

constexpr float B_4x4_3x3[] = {
   1./4,      0,     0,     0,     0,      0,
      0,   1./4, -1./4,  1./4, -1./4,   1./4,
 -5./16,  -1./4, -1./4, -1./8, -1./8,      0,
      0, -1./16, 1./16, -1./4,  1./4, -5./16,
  1./16,  1./16, 1./16,  1./8,  1./8,      0,
      0,      0,     0,     0,     0,  1./16
};

constexpr float AT_4x4_3x3[] = {
 1./8,  1./4, 1./4,  1./8, 1./8,    0,
    0, -1./4, 1./4, -1./4, 1./4,    0,
    0,  1./4, 1./4,  1./2, 1./2,    0,
    0, -1./4, 1./4,    -1,    1, 1./2
};

constexpr float A_4x4_3x3[] = {
 1./8,     0,    0,     0,
 1./4, -1./4, 1./4, -1./4,
 1./4,  1./4, 1./4,  1./4,
 1./8, -1./4, 1./2,    -1,
 1./8,  1./4, 1./2,     1,
    0,     0,    0,  1./2
};

constexpr float G_2x2_5x5[] = {
      1,     0,      0,     0,      0,
   1./6, -1./6,   1./6, -1./6,   1./6,
  -1./6, -1./6,  -1./6, -1./6,  -1./6,
 -4./15, 2./15, -1./15, 1./30, -1./60,
  1./60, 1./30,  1./15, 2./15,  4./15,
      0,     0,      0,     0,      1
};

constexpr float GT_2x2_5x5[] = {
 1,  1./6, -1./6, -4./15, 1./60, 0,
 0, -1./6, -1./6,  2./15, 1./30, 0,
 0,  1./6, -1./6, -1./15, 1./15, 0,
 0, -1./6, -1./6,  1./30, 2./15, 0,
 0,  1./6, -1./6, -1./60, 4./15, 1
};

constexpr float BT_2x2_5x5[] = {
 1./8, 3./16,  -1./4, -3./16,   1./8,    0,
    0,  1./8,  1./16, -5./16,   1./8,    0,
    0, -1./8, -5./16, -1./16,   1./8,    0,
    0,  1./4,  -1./8,  -1./4,   1./8,    0,
    0, -1./8,  -1./4,   1./8,   1./4,    0,
    0,  1./8,  3./16,  -1./4, -3./16, 1./8
};

constexpr float B_2x2_5x5[] = {
   1./8,      0,      0,     0,     0,      0,
  3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
  -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
 -3./16, -5./16, -1./16, -1./4,  1./8,  -1./4,
   1./8,   1./8,   1./8,  1./8,  1./4, -3./16,
      0,      0,      0,     0,     0,   1./8
};

constexpr float AT_2x2_5x5[] = {
 1./2,  1, 1,  2, 1,    0,
    0, -1, 1, -1, 2, 1./2
};

constexpr float A_2x2_5x5[] = {
 1./2,    0,
    1,   -1,
    1,    1,
    2,   -1,
    1,    2,
    0, 1./2
};
// clang-format on

/// Structure to keep information of constant transform matrices.
struct TransformMatrix {
  TransformMatrix(const float *table, int64_t rows, int64_t cols,
                  int64_t scalarFactor = 1)
      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}

  const float *table;
  int64_t rows;
  int64_t cols;
  int64_t scalarFactor;
};
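
// Note on scalarFactor (an interpretation, see outputTransform below): the
// A^T/A tables for F(4 x 4, 3 x 3) and F(2 x 2, 5 x 5) store scaled-down
// entries, and outputTransform re-applies the product of the left and right
// factors once after both matmuls, e.g. 32 x 32 = 1024 for F(4 x 4, 3 x 3).
// A minimal construction sketch, using the tables above:
//
//   TransformMatrix AT(AT_4x4_3x3, /*rows=*/4, /*cols=*/6,
//                      /*scalarFactor=*/32);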

/// Utility function to convert constant array to arith.constant Value.
Value create2DTransformMatrix(OpBuilder &builder, Location loc,
                              TransformMatrix transform, Type type) {
  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);

  return arith::ConstantOp::create(
      builder, loc,
      DenseFPElementsAttr::get(
          RankedTensorType::get(
              SmallVector<int64_t>{transform.rows, transform.cols}, type),
          constVec));
}
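
// For example (a sketch, assuming f32 and the F(2 x 2, 3 x 3) G matrix
// above), this helper emits IR along the lines of:
//
//   %G = arith.constant dense<[[-1.0,  0.0, 0.0],
//                              [ 0.5, -0.5, 0.5],
//                              [ 0.5,  0.5, 0.5],
//                              [ 0.0,  0.0, 1.0]]> : tensor<4x3xf32>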

/// Extract height x width data from 4D tensors.
Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
                          Value loopNorFIndex, Value loopCorFIndex,
                          Value heightOffset, Value widthOffset,
                          int64_t extractHeight, int64_t extractWidth,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto sourceType = cast<ShapedType>(source.getType());
  Type elementType = sourceType.getElementType();
  int64_t srcSize = sourceType.getRank();

  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets;
  offsets.resize(srcSize);
  offsets[loopNorFIdx] = loopNorFIndex;
  offsets[loopCorFIdx] = loopCorFIndex;
  offsets[heightIdx] = heightOffset;
  offsets[widthIdx] = widthOffset;
  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
  sizes[heightIdx] = builder.getIndexAttr(extractHeight);
  sizes[widthIdx] = builder.getIndexAttr(extractWidth);
  SmallVector<OpFoldResult> strides(srcSize, oneIndex);

  auto extractFilterType =
      RankedTensorType::get({extractHeight, extractWidth}, elementType);
  auto extractFilterOp = tensor::ExtractSliceOp::create(
      builder, loc, extractFilterType, source, offsets, sizes, strides);

  return extractFilterOp;
}
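
// For example (a sketch, assuming a filter of type tensor<64x3x3x32xf32> and
// extractHeight = extractWidth = 3), this helper emits:
//
//   %slice = tensor.extract_slice %filter[%f, 0, 0, %c] [1, 3, 3, 1]
//            [1, 1, 1, 1] : tensor<64x3x3x32xf32> to tensor<3x3xf32>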

/// Extract height x width data from 6D tensors.
Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
                          Value tileHIndex, Value tileWIndex,
                          Value loopNorFIndex, Value loopCorFIndex,
                          int64_t tileHIdx, int64_t tileWIdx,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto sourceType = cast<ShapedType>(source.getType());
  Type elementType = sourceType.getElementType();
  auto sourceShape = sourceType.getShape();
  int64_t srcSize = sourceType.getRank();
  int64_t height = sourceShape[heightIdx];
  int64_t width = sourceShape[widthIdx];

  auto zeroIndex = builder.getIndexAttr(0);
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets(srcSize, zeroIndex);
  offsets[tileHIdx] = tileHIndex;
  offsets[tileWIdx] = tileWIndex;
  offsets[loopNorFIdx] = loopNorFIndex;
  offsets[loopCorFIdx] = loopCorFIndex;
  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
  sizes[heightIdx] = builder.getIndexAttr(height);
  sizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(srcSize, oneIndex);

  auto extractOpType = RankedTensorType::get({height, width}, elementType);
  auto extractSliceOp = tensor::ExtractSliceOp::create(
      builder, loc, extractOpType, source, offsets, sizes, strides);

  return extractSliceOp;
}

/// Insert transformed height x width data into the 4D tensor from which it
/// was extracted.
Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value loopNorFIndex, Value loopCorFIndex,
                       Value heightOffset, Value widthOffset, int64_t height,
                       int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> retOffsets;
  retOffsets.resize(destSize);
  retOffsets[loopNorFIdx] = loopNorFIndex;
  retOffsets[loopCorFIdx] = loopCorFIndex;
  retOffsets[heightIdx] = heightOffset;
  retOffsets[widthIdx] = widthOffset;
  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
  retSizes[heightIdx] = builder.getIndexAttr(height);
  retSizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(destSize, oneIndex);

  auto insertSliceOp = tensor::InsertSliceOp::create(
      builder, loc, source, dest, retOffsets, retSizes, strides);

  return insertSliceOp;
}

/// Insert transformed height x width data into the 6D tensor from which it
/// was extracted.
Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value tileHIndex, Value tileWIndex,
                       Value loopNorFIndex, Value loopCorFIndex, int64_t height,
                       int64_t width, int64_t tileHIdx, int64_t tileWIdx,
                       int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
  auto zeroIndex = builder.getIndexAttr(0);
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex);
  retOffsets[tileHIdx] = tileHIndex;
  retOffsets[tileWIdx] = tileWIndex;
  retOffsets[loopNorFIdx] = loopNorFIndex;
  retOffsets[loopCorFIdx] = loopCorFIndex;
  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
  retSizes[heightIdx] = builder.getIndexAttr(height);
  retSizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(destSize, oneIndex);

  auto insertSliceOp = tensor::InsertSliceOp::create(
      builder, loc, source, dest, retOffsets, retSizes, strides);

  return insertSliceOp;
}
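
// For example (a sketch, assuming a destination of type
// tensor<4x4x2x2x1x32xf32> and height = width = 4), this helper emits:
//
//   %r = tensor.insert_slice %source into %dest[0, 0, %h, %w, %n, %c]
//        [4, 4, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
//        : tensor<4x4xf32> into tensor<4x4x2x2x1x32xf32>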

/// This function transforms the filter. The data layout of the filter is
/// FHWC. The transformation matrix is 2-dimensional, so we first extract an
/// H x W slice from FHWC and generate 2 levels of loops to iterate over F
/// and C. After the transformation, we get
///
/// scf.for %f = lo_f to hi_f step 1
///   scf.for %c = lo_c to hi_c step 1
///     %extracted = extract filter<h x w> from filter<f x h x w x c>
///     %ret = linalg.matmul G, %extracted
///     %ret = linalg.matmul %ret, GT
///     %inserted = insert %ret into filter<h x w x c x f>
Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
                      Value retValue, WinogradConv2DFmr fmr,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to G transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      GMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
      };

  // Map from (m, r) to GT transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      GTMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
      };

  auto filterType = cast<ShapedType>(filter.getType());
  Type elementType = filterType.getElementType();
  auto filterShape = filterType.getShape(); // F, H, W, C
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];

  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  if (filterH != r && filterH != 1)
    return Value();
  if (filterW != r && filterW != 1)
    return Value();

  Value zeroIdx = arith::ConstantIndexOp::create(rewriter, loc, 0);
  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value FIter = ivs[0];
    Value CIter = ivs[1];

    // Extract (H, W) from (F, H, W, C).
    auto extractFilter =
        extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
                            zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    int64_t retRows = 1;
    Value matmulRetValue = extractFilter;
    Value zero = arith::ConstantOp::create(builder, loc,
                                           rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get constant transform matrix G.
      auto it = GMatrices.find(fmr);
      if (it == GMatrices.end())
        return {};
      const TransformMatrix &GMatrix = it->second;

      retRows = GMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
      auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
                                           elementType)
                       .getResult();
      auto init =
          linalg::FillOp::create(builder, loc, zero, empty).getResult(0);

      Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
      // Multiply G x g.
      auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                               ValueRange{G, extractFilter},
                                               ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get constant transform matrix GT.
      auto it = GTMatrices.find(fmr);
      if (it == GTMatrices.end())
        return {};
      const TransformMatrix &GTMatrix = it->second;

      auto matmulType =
          RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
      auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
                                           elementType)
                       .getResult();
      auto init =
          linalg::FillOp::create(builder, loc, zero, empty).getResult(0);

      Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
      // Multiply u = (G x g) x GT.
      auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                               ValueRange{matmulRetValue, GT},
                                               ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) into (H, W, C, F).
    int64_t retHeight = leftTransform ? m + r - 1 : 1;
    int64_t retWidth = rightTransform ? m + r - 1 : 1;

    auto insertSliceOp =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
                         zeroIdx, zeroIdx, retHeight, retWidth,
                         /*loopNorFIdx=*/3, /*loopCorFIdx=*/2,
                         /*heightIdx=*/0, /*widthIdx=*/1);

    return {insertSliceOp};
  };

  auto fUpperBound = arith::ConstantIndexOp::create(rewriter, loc, filterF);
  auto cUpperBound = arith::ConstantIndexOp::create(rewriter, loc, filterC);
  auto oneStep = arith::ConstantIndexOp::create(rewriter, loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
      {oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}
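
// As a concrete example: for fmr = F_4_3 and a filter of type
// tensor<64x3x3x32xf32> (F, H, W, C), the loop nest runs 64 x 32 times and
// produces a transformed filter of type tensor<6x6x32x64xf32>
// (alphaH, alphaW, C, F), since alpha = m + r - 1 = 4 + 3 - 1 = 6.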

/// This function transforms the input. The data layout of the input is NHWC.
/// The transformation matrix is 2-dimensional, so we first extract an
/// alphaH x alphaW slice from NHWC and generate 4 levels of loops to iterate
/// over the tiles and over N and C. After the transformation, we get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %c = 0 to C step 1
///         %extracted = extract %extracted<alphaH x alphaW> from
///                      %input<N x H x W x C>
///                      at [%n, (%h x m), (%w x m), %c]
///         %ret = linalg.matmul BT, %extracted
///         %ret = linalg.matmul %ret, B
///         %inserted = insert %ret<alphaH x alphaW> into
///                     %output<alphaH x alphaW x tileH x tileW x N x C>
///                     at [0, 0, %h, %w, %n, %c]
Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
                     Value retValue, WinogradConv2DFmr fmr,
                     bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to BT transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      BTMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
      };

  // Map from (m, r) to B transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      BMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
      };

  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  auto inputType = cast<ShapedType>(input.getType());
  Type elementType = inputType.getElementType();
  auto inputShape = inputType.getShape(); // N, H, W, C
  int64_t inputN = inputShape[0];
  int64_t inputC = inputShape[3];
  auto valueType = cast<ShapedType>(retValue.getType());
  auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C
  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value CIter = ivs[3];

    auto context = builder.getContext();

    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    Value heightOffset = affine::AffineApplyOp::create(
        builder, loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
    Value widthOffset = affine::AffineApplyOp::create(
        builder, loc, rightTransform ? affineMap : identityAffineMap,
        tileWIter);

    // Extract (H, W) from (N, H, W, C).
    auto extractInput =
        extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset,
                            widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    int64_t retRows = 1;
    int64_t retCols = 1;
    Value matmulRetValue = extractInput;
    Value zero = arith::ConstantOp::create(builder, loc,
                                           rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get constant transform matrix BT.
      auto it = BTMatrices.find(fmr);
      if (it == BTMatrices.end())
        return {};
      const TransformMatrix &BTMatrix = it->second;

      retRows = BTMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
      auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
                                           elementType)
                       .getResult();
      auto init =
          linalg::FillOp::create(builder, loc, zero, empty).getResult(0);

      Value BT = create2DTransformMatrix(builder, loc, BTMatrix,
                                         builder.getF32Type());
      // Multiply BT x d.
      auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                               ValueRange{BT, matmulRetValue},
                                               ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get constant transform matrix B.
      auto it = BMatrices.find(fmr);
      if (it == BMatrices.end())
        return {};
      const TransformMatrix &BMatrix = it->second;

      retCols = BMatrix.cols;
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
                                           elementType)
                       .getResult();
      auto init =
          linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
      Value B = create2DTransformMatrix(builder, loc, BMatrix,
                                        builder.getF32Type());
      // Multiply v = (BT x d) x B.
      auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                               ValueRange{matmulRetValue, B},
                                               ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert (H, W) into (H, W, tileH, tileW, N, C).
    auto combinedVal = insert2DDataTo6D(
        builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter,
        CIter, retRows, retCols, 2, 3, /*loopNorFIdx=*/4, /*loopCorFIdx=*/5,
        /*heightIdx=*/0, /*widthIdx=*/1);

    return {combinedVal};
  };

  auto zeroIdx = arith::ConstantIndexOp::create(rewriter, loc, 0);
  auto tileHBound = arith::ConstantIndexOp::create(rewriter, loc, tileH);
  auto tileWBound = arith::ConstantIndexOp::create(rewriter, loc, tileW);
  auto nUpperBound = arith::ConstantIndexOp::create(rewriter, loc, inputN);
  auto cUpperBound = arith::ConstantIndexOp::create(rewriter, loc, inputC);
  auto oneStep = arith::ConstantIndexOp::create(rewriter, loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, cUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}
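
// For reference, derived from the tables above: alphaH = alphaW = 4 for
// F_2_3, and 6 for both F_4_3 and F_2_5. For example, for fmr = F_4_3 and an
// input of type tensor<1x58x58x32xf32> with tileH = tileW = 14, the result
// is of type tensor<6x6x14x14x1x32xf32>.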

/// This function generates linalg.batch_matmul to multiply input with filter.
/// linalg.batch_matmul only supports 3-dimensional inputs, so we collapse the
/// (alphaH, alphaW) dimensions into the batch dimension and the
/// (tileH, tileW, N) dimensions into one row dimension. That is, we convert
/// [alphaH, alphaW, tileH, tileW, N, C] to [alphaH x alphaW,
/// tileH x tileW x N, C]. In this way, we can convert the 6-dimensional
/// inputs to a 3-dimensional representation that is suitable for
/// linalg.batch_matmul.
///
/// The batched matmul does the matrix multiply with the reduction on channel.
///
/// We get
///
/// %collapsed_input = tensor.collapse_shape %input
/// %collapsed_filter = tensor.collapse_shape %filter
/// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
/// %expanded_ret = tensor.expand_shape %ret
///
/// After this function, we get a return value with data layout
/// (alphaH, alphaW, tileH, tileW, N, F).
static Value matrixMultiply(RewriterBase &rewriter, Location loc,
                            Value transformedFilter, Value transformedInput,
                            Type outputElementType) {
  // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter.
  auto filterType = cast<ShapedType>(transformedFilter.getType());
  assert(filterType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> filterShape = filterType.getShape();
  Type filterElementType = filterType.getElementType();
  auto filterReassocType = RankedTensorType::get(
      {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
      filterElementType);
  SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
  Value collapseFilter = tensor::CollapseShapeOp::create(
      rewriter, loc, filterReassocType, transformedFilter, filterReassoc);

  // Convert (alphaH, alphaW, tileH, tileW, N, C) to
  // (alphaH x alphaW, tileH x tileW x N, C) for input.
  auto inputType = cast<ShapedType>(transformedInput.getType());
  assert(inputType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> inputShape = inputType.getShape();
  Type inputElementType = inputType.getElementType();
  auto inputReassocType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
      inputElementType);
  SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  Value collapseInput = tensor::CollapseShapeOp::create(
      rewriter, loc, inputReassocType, transformedInput, inputReassoc);

  // Batched matrix multiply.
  auto matmulType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
      outputElementType);
  Value empty = tensor::EmptyOp::create(rewriter, loc, matmulType.getShape(),
                                        outputElementType)
                    .getResult();
  Value zero = arith::ConstantOp::create(
      rewriter, loc, rewriter.getZeroAttr(outputElementType));
  Value init = linalg::FillOp::create(rewriter, loc, zero, empty).getResult(0);

  auto matmulOp = linalg::BatchMatmulOp::create(
      rewriter, loc, matmulType, ValueRange({collapseInput, collapseFilter}),
      ValueRange{init});

  // The result shape of batch matmul is (alphaH x alphaW, tileH x tileW x N,
  // F). Expand the matmul result to (alphaH, alphaW, tileH, tileW, N, F).
  SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  auto outputReassocType =
      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
                             inputShape[3], inputShape[4], filterShape[3]},
                            outputElementType);
  auto expandOutput = tensor::ExpandShapeOp::create(
      rewriter, loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
  return expandOutput;
}
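
// Continuing the example above (a sketch): with a transformed filter of type
// tensor<6x6x32x64xf32> and a transformed input of type
// tensor<6x6x14x14x1x32xf32>, the collapses yield tensor<36x32x64xf32> and
// tensor<36x196x32xf32>, linalg.batch_matmul produces tensor<36x196x64xf32>,
// and the expand restores tensor<6x6x14x14x1x64xf32>.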

/// This function transforms the output. The data layout of the input value is
/// (alphaH, alphaW, tileH, tileW, N, F) and the final output layout is NHWF.
/// The transformation matrix is 2-dimensional, so we first extract an
/// alphaH x alphaW slice and generate 4 levels of loops to iterate over the
/// tiles and over N and F. After the transformation, we get
///
/// scf.for %h = 0 to tileH step 1
///   scf.for %w = 0 to tileW step 1
///     scf.for %n = 0 to N step 1
///       scf.for %f = 0 to F step 1
///         %extracted = extract %extracted<alphaH x alphaW> from
///                      %input<alphaH x alphaW x tileH x tileW x N x F>
///                      at [0, 0, %h, %w, %n, %f]
///         %ret = linalg.matmul AT, %extracted
///         %ret = linalg.matmul %ret, A
///         %inserted = insert %ret<m x m> into
///                     %output<N x H x W x F>
///                     at [%n, (%h x m), (%w x m), %f]
Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
                      Value output, WinogradConv2DFmr fmr,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to AT transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      ATMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
      };

  // Map from (m, r) to A transform matrix.
  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
      AMatrices = {
          {WinogradConv2DFmr::F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
          {WinogradConv2DFmr::F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
          {WinogradConv2DFmr::F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
      };

  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  auto valueType = cast<ShapedType>(value.getType());
  Type elementType = valueType.getElementType();
  auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];
  int64_t valueN = valueShape[4];
  int64_t valueF = valueShape[5];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  if (valueH != alphaH && valueH != 1)
    return Value();
  if (valueW != alphaW && valueW != 1)
    return Value();

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    auto context = builder.getContext();
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value FIter = ivs[3];

    // Extract (H, W) from (H, W, tileH, tileW, N, F).
    auto extractValue =
        extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter,
                            FIter, 2, 3, /*loopNorFIdx=*/4,
                            /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1);

    const TransformMatrix &AMatrix = AMatrices.at(fmr);
    const TransformMatrix &ATMatrix = ATMatrices.at(fmr);
    int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) *
                           (leftTransform ? ATMatrix.scalarFactor : 1);
    int64_t retCols = rightTransform ? AMatrix.cols : 1;
    int64_t retRows = leftTransform ? ATMatrix.rows : 1;

    Value matmulRetValue = extractValue;
    Value zero = arith::ConstantOp::create(builder, loc,
                                           rewriter.getZeroAttr(elementType));

    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    Value heightOffset = affine::AffineApplyOp::create(
        builder, loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
    Value widthOffset = affine::AffineApplyOp::create(
        builder, loc, rightTransform ? affineMap : identityAffineMap,
        tileWIter);

    Value outInitVal =
        extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
                            widthOffset, retRows, retCols,
                            /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1,
                            /*widthIdx=*/2);
    if (leftTransform) {
      auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
      Value init = outInitVal;
      if (rightTransform || scalarFactor != 1) {
        auto empty = tensor::EmptyOp::create(builder, loc,
                                             matmulType.getShape(), elementType)
                         .getResult();
        init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
      }

      Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
      // Multiply AT x m.
      auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                               ValueRange{AT, matmulRetValue},
                                               ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      auto matmulType =
          RankedTensorType::get({retRows, AMatrix.cols}, elementType);
      Value init = outInitVal;
      if (scalarFactor != 1) {
        auto empty = tensor::EmptyOp::create(builder, loc,
                                             matmulType.getShape(), elementType)
                         .getResult();
        init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
      }

      Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
      // Multiply y = (AT x m) x A.
      auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                               ValueRange{matmulRetValue, A},
                                               ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (scalarFactor != 1) {
      // Multiply by the scalar factor and add outInitVal.
      Value scalarFactorValue = arith::ConstantOp::create(
          builder, loc, FloatAttr::get(elementType, scalarFactor));
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
      SmallVector<AffineMap> affineMaps = {
          AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap};

      matmulRetValue =
          linalg::GenericOp::create(
              rewriter, loc, matmulType,
              ValueRange{scalarFactorValue, matmulRetValue},
              ValueRange{outInitVal}, affineMaps,
              llvm::ArrayRef<utils::IteratorType>{
                  utils::IteratorType::parallel, utils::IteratorType::parallel},
              [&](OpBuilder &nestedBuilder, Location nestedLoc,
                  ValueRange args) {
                auto mulf = arith::MulFOp::create(nestedBuilder, nestedLoc,
                                                  args[0], args[1]);
                auto addf = arith::AddFOp::create(nestedBuilder, nestedLoc,
                                                  mulf.getResult(), args[2]);
                linalg::YieldOp::create(nestedBuilder, nestedLoc,
                                        addf.getResult());
              })
              .getResult(0);
    }

    // Insert (H, W) into (N, H, W, F).
    Value combinedVal =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
                         heightOffset, widthOffset, retRows, retCols,
                         /*loopNorFIdx=*/0,
                         /*loopCorFIdx=*/3, /*heightIdx=*/1,
                         /*widthIdx=*/2);

    return {combinedVal};
  };

  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  auto zeroIdx = arith::ConstantIndexOp::create(rewriter, loc, 0);
  auto tileHBound = arith::ConstantIndexOp::create(rewriter, loc, tileH);
  auto tileWBound = arith::ConstantIndexOp::create(rewriter, loc, tileW);
  auto nUpperBound = arith::ConstantIndexOp::create(rewriter, loc, valueN);
  auto fUpperBound = arith::ConstantIndexOp::create(rewriter, loc, valueF);
  auto oneStep = arith::ConstantIndexOp::create(rewriter, loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, fUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody);
  return loops.results[0];
}

/// Create an empty tensor with alignedType and insert the value into the
/// created empty tensor with aligned size.
static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
                                Value value, ArrayRef<int64_t> alignedShape) {
  auto valueType = cast<ShapedType>(value.getType());
  Type elementType = valueType.getElementType();
  auto alignedType = RankedTensorType::get(alignedShape, elementType);
  Value padValue = arith::ConstantOp::create(rewriter, loc, elementType,
                                             rewriter.getZeroAttr(elementType));

  return linalg::makeComposedPadHighOp(rewriter, loc, alignedType, value,
                                       padValue, false);
}

/// Extract a sub-tensor with extractedType from value.
static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
                                      Value value,
                                      RankedTensorType extractedType) {
  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
  OpFoldResult oneIndex = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
  SmallVector<OpFoldResult, 4> strides(4, oneIndex);

  ArrayRef<int64_t> extractedShape = extractedType.getShape();
  SmallVector<OpFoldResult> sizes =
      getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape));

  return tensor::ExtractSliceOp::create(rewriter, loc, extractedType, value,
                                        offsets, sizes, strides);
}

/// Utility function to check all values in the attribute are 1.
static bool hasAllOneValues(DenseIntElementsAttr attr) {
  return llvm::all_of(
      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
}

/// A helper function to convert linalg.conv_2d_nhwc_fhwc to
/// linalg.winograd_*_transform ops.
static FailureOr<Operation *>
winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
                     WinogradConv2DFmr fmr) {
  if (!convOp.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(
        convOp, "expected pure tensor semantics for linalg.conv_2d_nhwc_fhwc");

  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];
  auto inputType = cast<ShapedType>(input.getType());
  auto filterType = cast<ShapedType>(filter.getType());
  auto outputType = cast<ShapedType>(output.getType());

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  if (!hasAllOneValues(convOp.getStrides()))
    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");

  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputN = inputShape[0];
  int64_t inputH = inputShape[1];
  int64_t inputW = inputShape[2];
  int64_t inputC = inputShape[3];
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t outputN = outputShape[0];
  int64_t outputH = outputShape[1];
  int64_t outputW = outputShape[2];
  int64_t outputF = outputShape[3];

  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
  bool isSupportedFilter = false;
  if (filterH == filterW && filterH == r)
    isSupportedFilter = true;
  if (filterH == r && filterW == 1)
    isSupportedFilter = true;
  if (filterH == 1 && filterW == r)
    isSupportedFilter = true;

  if (!isSupportedFilter)
    return rewriter.notifyMatchFailure(
        convOp, "only support filter (r x r), (r x 1) or (1 x r)");

  // All the criteria are satisfied. We can do Winograd Conv2D.
  Location loc = convOp.getLoc();

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = filterH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = filterW != 1;
  int64_t heightM = leftTransform ? m : 1;
  int64_t widthM = rightTransform ? m : 1;
  int64_t heightR = leftTransform ? r : 1;
  int64_t widthR = rightTransform ? r : 1;

  // --- Create operation for filter transform ---
  Type filterElementType = filterType.getElementType();
  int64_t alphaH = heightM + heightR - 1;
  int64_t alphaW = widthM + widthR - 1;
  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
  auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
                                       filterElementType);
  Value retValue = tensor::EmptyOp::create(rewriter, loc, retType.getShape(),
                                           filterElementType);
  auto transformedFilter = linalg::WinogradFilterTransformOp::create(
      rewriter, loc, retType, filter, retValue, fmr);

  // --- Create operation for input transform ---

  // When input size - (r - 1) is not aligned with the output tile size, we
  // need to pad the input data to create full tiles for tiling.
  Type inputElementType = inputType.getElementType();
  int64_t alignedInputH = tileH * heightM + (heightR - 1);
  int64_t alignedInputW = tileW * widthM + (widthR - 1);
  if (alignedInputH != inputH || alignedInputW != inputW) {
    input = padToAlignedTensor(rewriter, loc, input,
                               {inputN, alignedInputH, alignedInputW, inputC});
  }
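
  // For example (a sketch): with fmr = F_4_3 and a 55 x 55 output,
  // tileH = tileW = ceil(55 / 4) = 14, so the 57 x 57 input (55 + r - 1) is
  // padded to 14 * 4 + 2 = 58 x 58 here, and the output below is padded to
  // 56 x 56 before the final extract.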

  retType = RankedTensorType::get(
      {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
  retValue = tensor::EmptyOp::create(rewriter, loc, retType.getShape(),
                                     inputElementType);
  auto transformedInput = linalg::WinogradInputTransformOp::create(
      rewriter, loc, retType, input, retValue, fmr);

  Type outputElementType = outputType.getElementType();
  Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
                                   transformedInput, outputElementType);

  // --- Create operation for output transform ---

  // When the output size is not aligned with the output tile size, we need to
  // pad the output buffer to insert the full tiles after tiling.
  int64_t alignedOutputH = tileH * heightM;
  int64_t alignedOutputW = tileW * widthM;
  bool isOutputUnaligned =
      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
  if (isOutputUnaligned) {
    auto alignedOutputType = RankedTensorType::get(
        {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
    output =
        padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape());
    outputType = alignedOutputType;
  }

  Value transformedOutput = linalg::WinogradOutputTransformOp::create(
      rewriter, loc, outputType, matmulRet, output, fmr);

  // When the output size is not aligned with the output tile size, extract
  // the value from the padded buffer.
  if (isOutputUnaligned) {
    transformedOutput = extractFromAlignedTensor(
        rewriter, loc, transformedOutput,
        RankedTensorType::get({outputN, outputH, outputW, outputF},
                              outputElementType));
  }

  rewriter.replaceOp(convOp, transformedOutput);

  return transformedOutput.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_filter_transform.
FailureOr<Operation *>
decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradFilterTransformOp op) {
  Location loc = op.getLoc();
  Value filter = op.getFilter();
  auto filterType = cast<ShapedType>(filter.getType());
  auto filterShape = filterType.getShape();
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = filterH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = filterW != 1;
  Value transformedFilter =
      filterTransform(rewriter, loc, filter, op.getOutput(), op.getFmr(),
                      leftTransform, rightTransform);
  if (!transformedFilter)
    return failure();

  rewriter.replaceOp(op, transformedFilter);

  return transformedFilter.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_input_transform.
FailureOr<Operation *>
decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
                                      linalg::WinogradInputTransformOp op) {
  Location loc = op.getLoc();
  Value output = op.getOutput();
  auto outputType = cast<ShapedType>(output.getType());
  auto outputShape = outputType.getShape();

  int64_t outputH = outputShape[0];
  int64_t outputW = outputShape[1];

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = outputH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = outputW != 1;
  Value transformedInput =
      inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getFmr(),
                     leftTransform, rightTransform);
  if (!transformedInput)
    return failure();

  rewriter.replaceOp(op, transformedInput);

  return transformedInput.getDefiningOp();
}

/// A helper function to decompose linalg.winograd_output_transform.
FailureOr<Operation *>
decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradOutputTransformOp op) {
  Location loc = op.getLoc();
  Value value = op.getValue();
  auto valueType = cast<ShapedType>(value.getType());
  auto valueShape = valueType.getShape();
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = valueH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = valueW != 1;
  Value transformedOutput =
      outputTransform(rewriter, loc, value, op.getOutput(), op.getFmr(),
                      leftTransform, rightTransform);
  if (!transformedOutput)
    return failure();

  rewriter.replaceOp(op, transformedOutput);

  return transformedOutput.getDefiningOp();
}

/// A rewrite pattern to decompose linalg.winograd_filter_transform operations.
class DecomposeWinogradFilterTransform final
    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradFilterTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern to decompose linalg.winograd_input_transform operations.
class DecomposeWinogradInputTransform final
    : public OpRewritePattern<linalg::WinogradInputTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradInputTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern to decompose linalg.winograd_output_transform operations.
class DecomposeWinogradOutputTransform final
    : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradOutputTransformHelper(rewriter, op);
  }
};

/// A rewrite pattern for the Winograd Conv2D algorithm.
class WinogradConv2DNhwcFhwc final
    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:
  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, WinogradConv2DFmr fmr)
      : OpRewritePattern(context), fmr(fmr) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(winogradConv2DHelper(rewriter, convOp, fmr)))
      return failure();

    return success();
  }

private:
  WinogradConv2DFmr fmr;
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
                                      linalg::Conv2DNhwcFhwcOp op,
                                      linalg::WinogradConv2DFmr fmr) {
  return winogradConv2DHelper(rewriter, op, fmr);
}

FailureOr<Operation *>
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradFilterTransformOp op) {
  return decomposeWinogradFilterTransformHelper(rewriter, op);
}

FailureOr<Operation *>
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
                                  linalg::WinogradInputTransformOp op) {
  return decomposeWinogradInputTransformHelper(rewriter, op);
}

FailureOr<Operation *>
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradOutputTransformOp op) {
  return decomposeWinogradOutputTransformHelper(rewriter, op);
}

void populateWinogradConv2DPatterns(RewritePatternSet &patterns,
                                    WinogradConv2DFmr fmr) {
  MLIRContext *context = patterns.getContext();
  // TODO: Support more Conv2D data layouts, e.g., conv_2d_nchw_fchw.
  patterns.insert<WinogradConv2DNhwcFhwc>(context, fmr);
}

void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns
      .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform,
              DecomposeWinogradOutputTransform>(context);
}
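
// A minimal usage sketch (hypothetical pass body; the greedy driver call and
// the surrounding pass are assumptions, not part of this file):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateWinogradConv2DPatterns(patterns,
//                                          linalg::WinogradConv2DFmr::F_4_3);
//   linalg::populateDecomposeWinogradOpsPatterns(patterns);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();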

} // end namespace linalg
} // end namespace mlir