//===- ConvertConv2DToImg2Col.cpp - im2col implementation -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include <utility>

namespace mlir {
namespace linalg {
static bool hasAllOneValues(DenseIntElementsAttr attr) {
  return llvm::all_of(
      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
}

static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
  if (isa<IntegerType>(x.getType()))
    return arith::AddIOp::create(builder, loc, x, y);
  if (isa<ComplexType>(x.getType()))
    return complex::AddOp::create(builder, loc, x, y);
  return arith::AddFOp::create(builder, loc, x, y);
}

static Value createMul(Location loc, Value x, Value y, Type accType,
                       OpBuilder &builder) {
  // Linalg named ops specify signed extension when converting operands to the
  // accumulator type.
  Value xConvert =
      convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false);
  Value yConvert =
      convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false);
  if (isa<ComplexType>(accType))
    return complex::MulOp::create(builder, loc, xConvert, yConvert);
  if (isa<IntegerType>(accType))
    return arith::MulIOp::create(builder, loc, xConvert, yConvert);
  return arith::MulFOp::create(builder, loc, xConvert, yConvert);
}
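
// For example (illustrative types): with i8 operands and an i32 accumulator,
// createMul sign-extends both operands to i32 via convertScalarToDtype and
// emits arith.muli; createAdd then emits arith.addi on the widened values.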

// Delinearizes the given composite `index` by the basis specified in `factors`.
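// For example (illustrative factors): with factors {3, 3, 4}, the composite
// index 22 delinearizes to (1, 2, 2), since 22 = 1 * (3 * 4) + 2 * 4 + 2.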
static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
                                      ArrayRef<int64_t> factors) {
  assert(!factors.empty() && "empty factor list");
  SmallVector<Value> basis;
  for (int64_t f : factors)
    basis.push_back(arith::ConstantOp::create(b, loc, b.getIndexAttr(f)));
  FailureOr<SmallVector<Value>> multiIndex =
      affine::delinearizeIndex(b, loc, index, basis);
  assert(!failed(multiIndex) && "Failed to delinearize img2col index");
  return *multiIndex;
}

// Given indices corresponding to iterators in the output (oIndex) and filter
// (fIndex) for a convolution, compute the convolved index for the
// input as `oIndex * stride + fIndex`.
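// For example (illustrative values): with stride 2, output index 3 and filter
// index 1 map to input index 3 * 2 + 1 = 7.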
static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
                               Value fIndex, int64_t stride) {
  AffineExpr oExpr, fExpr;
  bindSymbols(b.getContext(), oExpr, fExpr);
  AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
  return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex});
}

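// Sketch of the rewrite below (illustrative shapes):
//
//   %0 = linalg.conv_2d_nhwc_hwcf
//          ins(%input, %filter : tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
//          outs(%out : tensor<1x14x14x16xf32>)
//
// becomes a linalg.generic that gathers the input into an im2col tensor of
// type tensor<1x196x36xf32> (196 = 14 * 14, 36 = 3 * 3 * 4), a matmul-like
// linalg.generic against the collapsed filter tensor<36x16xf32>, and a
// tensor.expand_shape back to tensor<1x14x14x16xf32>.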
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  MLIRContext *context = rewriter.getContext();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  ArrayRef<int64_t> filterShape = filterType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t oc = outputShape[3];
  int64_t fh = filterShape[0];
  int64_t fw = filterShape[1];
  int64_t ic = filterShape[2];

  Location loc = convOp.getLoc();

  // Reshape the filter and output into the RHS and result of a (B)MNK matmul.
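  // With the reassociation {{0, 1, 2}, {3}}, the filter dims (fh, fw, ic, oc)
  // collapse to (fh * fw * ic, oc); with {{0}, {1, 2}, {3}}, the output dims
  // (n, oh, ow, oc) collapse to (n, oh * ow, oc).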
  SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
  auto reshapedFilterType =
      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
  Value reshapedFilter = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedFilterType, filter, filterReassocIndices);

  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
  RankedTensorType reshapedOutputType =
      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
  Value reshapedOutput = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedOutputType, output, outputReassocIndices);

  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
  Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
                                            inputType.getElementType());

  // Convert the input to a (BMK) column tensor.
  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

  SmallVector<AffineMap> img2colIndexingMaps = {
      AffineMap::getMultiDimIdentityMap(nloops, context)};

  auto img2ColTensor = linalg::GenericOp::create(
      rewriter, loc, colTensor.getType(),
      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Get the iterators named based on the matmul (batch, m, k).
        Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
        Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
        Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);

        // Recover the original iteration indices from the problem/input sizes.
        SmallVector<Value> mIndices = unrollIndex(
            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
        auto ohIndex = mIndices[0];
        auto owIndex = mIndices[1];

        SmallVector<Value> kIndices = unrollIndex(
            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
        auto fhIndex = kIndices[0];
        auto fwIndex = kIndices[1];
        auto icIndex = kIndices[2];

        // Extract the input element corresponding to the expanded indices.
        Value hIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
                              convOp.getStrides().getValues<int64_t>()[0]);
        Value wIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
                              convOp.getStrides().getValues<int64_t>()[1]);

        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
        Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
                                                   extractionIndices);
        linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
      });

  // Because the filter does not share the same batch dimension,
  // the batch dimension is only used in indexing the input and output. Thus
  // we cannot use existing linalg named ops like linalg.batch_matmul.
  // i.e. (B x) M x K * K x N = (B x) M x N
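  // Written over the iteration space (b, m, n, k), the maps below are:
  //   lhs (im2col):  (b, m, n, k) -> (b, m, k)
  //   rhs (filter):  (b, m, n, k) -> (k, n)
  //   result:        (b, m, n, k) -> (b, m, n)
  // so the generic reduces over k, and b indexes only the im2col tensor and
  // the output.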
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};

  auto genericOp = linalg::GenericOp::create(
      rewriter, loc, reshapedOutputType,
      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  auto reshapedResult = tensor::ExpandShapeOp::create(
      rewriter, loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}

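// For the depthwise case the channel dimension is shared by the input, filter,
// and output, so the rewrite below first transposes all three operands to
// channels-outer layouts, gathers the input into an (n, c, oh, ow, fh, fw)
// column tensor, and then collapses the operands so that the per-channel
// convolutions become a single linalg.batch_matvec:
//   (n*c) x (oh*ow) x (fh*fw)  *  c x (fh*fw)  ->  (n*c) x (oh*ow)
// before expanding and transposing the result back to NHWC.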
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter,
                linalg::DepthwiseConv2DNhwcHwcOp convOp) {
  auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
  auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
  auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  Location loc = convOp.getLoc();

  auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
    auto operandTensorType = cast<RankedTensorType>(operand.getType());
    auto nloops = indices.size();
    ArrayRef<int64_t> inputShape = operandTensorType.getShape();

    SmallVector<AffineExpr> exprs = llvm::to_vector<4>(
        llvm::map_range(indices, [&](int64_t index) -> AffineExpr {
          return rewriter.getAffineDimExpr(index);
        }));

    SmallVector<int64_t> targetShape = llvm::to_vector<4>(llvm::map_range(
        indices, [&](int64_t index) -> int64_t { return inputShape[index]; }));

    Value outputTensor = tensor::EmptyOp::create(
        rewriter, loc, targetShape, operandTensorType.getElementType());

    SmallVector<utils::IteratorType> loopAttributeTypes(
        nloops, utils::IteratorType::parallel);

    SmallVector<AffineMap> indexingMaps = {
        inversePermutation(
            AffineMap::get(nloops, 0, exprs, rewriter.getContext())),
        AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};

    auto transposedOp = linalg::GenericOp::create(
        rewriter, loc, outputTensor.getType(),
        /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
        loopAttributeTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
        });

    return transposedOp.getResult(0);
  };
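
  // For example (illustrative shapes), transposeOperand(input, {0, 3, 1, 2})
  // rewrites a tensor<1x14x14x4xf32> NHWC value into a tensor<1x4x14x14xf32>
  // NCHW value with a transposing linalg.generic copy.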

  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  // Transpose input, filter so channels are outermost
  Value inputT = transposeOperand(input, {0, 3, 1, 2});
  Value filterT = transposeOperand(filter, {2, 0, 1});
  ArrayRef<int64_t> filterTShape =
      cast<RankedTensorType>(filterT.getType()).getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  int n = outputShape[0];
  int oh = outputShape[1];
  int ow = outputShape[2];
  int c = outputShape[3];
  int fh = filterTShape[1];
  int fw = filterTShape[2];

  SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw};
  Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});

  AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
  bindDims(rewriter.getContext(), nDim, cDim, ohDim, owDim, khDim, kwDim);

  AffineExpr shSym = rewriter.getAffineConstantExpr(
      convOp.getStrides().getValues<int64_t>()[0]);
  AffineExpr swSym = rewriter.getAffineConstantExpr(
      convOp.getStrides().getValues<int64_t>()[1]);

  SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim,
                                        owDim * swSym + kwDim};

  auto nloops = colTensorShape.size();

  SmallVector<utils::IteratorType> loopAttributeTypes(
      nloops, utils::IteratorType::parallel);

  SmallVector<AffineMap> indexingMaps = {
      AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),
      AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};

  Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
                                            inputType.getElementType());

  auto img2ColTensor = linalg::GenericOp::create(
      rewriter, loc, colTensor.getType(),
      /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps,
      loopAttributeTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
      });

  SmallVector<ReassociationIndices> img2ColTensorReassocIndices = {
      {0, 1}, {2, 3}, {4, 5}};
  SmallVector<ReassociationIndices> filterReassociationIndice = {{0}, {1, 2}};
  SmallVector<ReassociationIndices> outputReassociationIndice = {{0, 1},
                                                                 {2, 3}};

  auto reshapedImg2ColTensorType = RankedTensorType::get(
      {n * c, oh * ow, fh * fw}, inputType.getElementType());
  auto reshapedFilterTensorType =
      RankedTensorType::get({c, fh * fw}, filterType.getElementType());
  auto reshapedOutputTensorType =
      RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());

  Value reshapedImg2ColTensor = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
      img2ColTensorReassocIndices);
  Value reshapedFilterTensor =
      tensor::CollapseShapeOp::create(rewriter, loc, reshapedFilterTensorType,
                                      filterT, filterReassociationIndice);
  Value reshapedoutputTensor = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedOutputTensorType, transposedOutputTensor,
      outputReassociationIndice);

  auto batchMatVecResult = linalg::BatchMatvecOp::create(
      rewriter, loc, TypeRange{reshapedoutputTensor.getType()},
      ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
      ValueRange{reshapedoutputTensor});

  SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
                                                                      {2, 3}};

  auto batchMatVecResultReshaped = tensor::ExpandShapeOp::create(
      rewriter, loc, transposedOutputTensor.getType(),
      batchMatVecResult.getResult(0), batchMatVecReassociationIndice);

  Value transposedResult =
      transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});

  rewriter.replaceOp(convOp, ArrayRef<Value>{transposedResult});
  return std::make_pair(img2ColTensor.getOperation(),
                        transposedResult.getDefiningOp());
}

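// The NCHW/FCHW variant mirrors the NHWC/HWCF rewrite above, but it builds the
// column tensor as (B, K, N) = (n, ic * fh * fw, oh * ow) and uses the
// collapsed filter (oc, ic * fh * fw) as the matmul LHS, i.e.
// M x K * (B x) K x N = (B x) M x N.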
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  auto filterShape = filterType.getShape();
  auto outputShape = outputType.getShape();

  int64_t n = outputShape[0];
  int64_t oc = outputShape[1];
  int64_t oh = outputShape[2];
  int64_t ow = outputShape[3];
  int64_t ic = filterShape[1];
  int64_t fh = filterShape[2];
  int64_t fw = filterShape[3];

  auto loc = convOp.getLoc();
  MLIRContext *context = rewriter.getContext();

  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
  auto reshapedFilterType =
      RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
  Value reshapedFilter = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedFilterType, filter, filterReassocIndices);

  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}};
  auto reshapedOutputType =
      RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
  Value reshapedOutput = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedOutputType, output, outputReassocIndices);

  // Convert the input to a (BKN) tensor.
  SmallVector<int64_t, 4> colTensorShape = {n, ic * fh * fw, oh * ow};
  Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
                                            inputType.getElementType());

  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);

  SmallVector<AffineMap, 4> img2colIndexingMaps = {
      AffineMap::getMultiDimIdentityMap(nloops, context)};

  auto img2ColTensor = linalg::GenericOp::create(
      rewriter, loc, colTensor.getType(),
      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Get the iterators named based on the matmul (batch, m, k).
        Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
        Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
        Value nIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);

        // Recover the original iteration indices from the problem/input sizes.
        SmallVector<Value> kIndices = unrollIndex(
            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
        auto icIndex = kIndices[0];
        auto fhIndex = kIndices[1];
        auto fwIndex = kIndices[2];

        SmallVector<Value> nIndices = unrollIndex(
            nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
        auto ohIndex = nIndices[0];
        auto owIndex = nIndices[1];

        // Extract the input element corresponding to the expanded indices.
        Value hIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
                              convOp.getStrides().getValues<int64_t>()[0]);
        Value wIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
                              convOp.getStrides().getValues<int64_t>()[1]);

        // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
        SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
        Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
                                                   extractionIndices);
        linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
      });

  // Because the filter does not share the same batch dimension,
  // the batch dimension is only used in indexing the input and output. Thus
  // we cannot use existing linalg named ops like linalg.batch_matmul.
  // i.e. M x K * (B x) K x N = (B x) M x N
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};
  auto genericOp = linalg::GenericOp::create(
      rewriter, loc, reshapedOutputType,
      /*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  auto reshapedResult = tensor::ExpandShapeOp::create(
      rewriter, loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}

FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  MLIRContext *context = rewriter.getContext();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  ArrayRef<int64_t> filterShape = filterType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t oc = outputShape[3];
  int64_t fh = filterShape[1];
  int64_t fw = filterShape[2];
  int64_t ic = filterShape[3];

  Location loc = convOp.getLoc();

  // Reshape the filter and output into the RHS and result of a "row-wise"
  // matrix multiplication.
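  // With the reassociation {{0}, {1, 2, 3}}, the FHWC filter dims
  // (oc, fh, fw, ic) collapse to (oc, fh * fw * ic); the output collapses to
  // (n, oh * ow, oc) exactly as in the HWCF case above.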
  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
  auto reshapedFilterType =
      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
  Value reshapedFilter = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedFilterType, filter, filterReassocIndices);

  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
  RankedTensorType reshapedOutputType =
      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
  Value reshapedOutput = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedOutputType, output, outputReassocIndices);

  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
  Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
                                            inputType.getElementType());

  // Convert the input to a (BMK) column tensor.
  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

  SmallVector<AffineMap> img2colIndexingMaps = {
      AffineMap::getMultiDimIdentityMap(nloops, context)};

  auto img2ColTensor = linalg::GenericOp::create(
      rewriter, loc, colTensor.getType(),
      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Get the iterators named based on the matmul (batch, m, k).
        Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
        Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
        Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);

        // Recover the original iteration indices from the problem/input sizes.
        SmallVector<Value> mIndices = unrollIndex(
            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
        auto ohIndex = mIndices[0];
        auto owIndex = mIndices[1];

        SmallVector<Value> kIndices = unrollIndex(
            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
        auto fhIndex = kIndices[0];
        auto fwIndex = kIndices[1];
        auto icIndex = kIndices[2];

        // Extract the input element corresponding to the expanded indices.
        Value hIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
                              convOp.getStrides().getValues<int64_t>()[0]);
        Value wIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
                              convOp.getStrides().getValues<int64_t>()[1]);

        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
        Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
                                                   extractionIndices);
        linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
      });

  // Because we didn't transpose the filters we don't actually have a batched
  // matrix multiply. Instead, we have an operation consisting of "row-wise" dot
  // products.
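  // Concretely, with rhsMap = (b, m, n, k) -> (n, k) the generic below computes
  //   out[b, m, n] += im2col[b, m, k] * filter[n, k]
  // i.e. a matmul that consumes the collapsed filter in transposed form.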
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};

  auto genericOp = linalg::GenericOp::create(
      rewriter, loc, reshapedOutputType,
      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  auto reshapedResult = tensor::ExpandShapeOp::create(
      rewriter, loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}

namespace {

class ConvertConv2DNhwcHwcf final
    : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(rewriteInIm2Col(rewriter, convOp)))
      return failure();
    return success();
  }
};

class ConvertDepthwiseConv2DNhwcHwc final
    : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
public:
  using OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(rewriteInIm2Col(rewriter, convOp)))
      return failure();
    return success();
  }
};

class ConvertConv2DNchwFchw final
    : public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(rewriteInIm2Col(rewriter, convOp)))
      return failure();
    return success();
  }
};

class ConvertConv2DNhwcFhwc final
    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(rewriteInIm2Col(rewriter, convOp)))
      return failure();
    return success();
  }
};
} // end anonymous namespace

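// Example usage (a minimal sketch, assuming a pass anchored on a function op):
//   RewritePatternSet patterns(&getContext());
//   linalg::populateConvertConv2DToImg2ColPatterns(patterns);
//   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();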
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
                  ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
}
} // end namespace linalg
} // end namespace mlir