MLIR  21.0.0git
ConvertConv2DToImg2Col.cpp
Go to the documentation of this file.
1 //===- ConvertConv2DToImg2Col.cpp - im2col implementation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
17 #include "mlir/IR/AffineExpr.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include <utility>
23 
24 namespace mlir {
25 namespace linalg {
27  return llvm::all_of(
28  attr, [](const APInt &element) { return element.getSExtValue() == 1; });
29 }
30 
31 static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
32  if (isa<IntegerType>(x.getType()))
33  return builder.create<arith::AddIOp>(loc, x, y);
34  if (isa<ComplexType>(x.getType()))
35  return builder.create<complex::AddOp>(loc, x, y);
36  return builder.create<arith::AddFOp>(loc, x, y);
37 }
38 
39 static Value createMul(Location loc, Value x, Value y, Type accType,
40  OpBuilder &builder) {
41  // Linalg named ops specify signed extend for named ops.
42  Value xConvert =
43  convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false);
44  Value yConvert =
45  convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false);
46  if (isa<ComplexType>(accType))
47  return builder.create<complex::MulOp>(loc, xConvert, yConvert);
48  if (isa<IntegerType>(accType))
49  return builder.create<arith::MulIOp>(loc, xConvert, yConvert);
50  return builder.create<arith::MulFOp>(loc, xConvert, yConvert);
51 }
52 
53 // Delinearizes the given composite `index` by the basis specified in `factors`.
55  ArrayRef<int64_t> factors) {
56  assert(!factors.empty() && "empty factor list");
57  SmallVector<Value> basis;
58  for (int64_t f : factors)
59  basis.push_back(b.create<arith::ConstantOp>(loc, b.getIndexAttr(f)));
60  FailureOr<SmallVector<Value>> multiIndex =
61  affine::delinearizeIndex(b, loc, index, basis);
62  assert(!failed(multiIndex) && "Failed to linearize img2col index");
63  return *multiIndex;
64 }
65 
66 // Given indices corresponding to iterators in the output (oIndex) and filter
67 // (fIndex) for a convolution, compute the convolved index for the
68 // input as `oIndex * stride + fIndex`.
70  Value fIndex, int64_t stride) {
71  AffineExpr oExpr, fExpr;
72  bindSymbols(b.getContext(), oExpr, fExpr);
73  AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
74  return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex});
75 }
76 
77 FailureOr<std::pair<Operation *, Operation *>>
78 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
79  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
80  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
81  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
82 
83  if (!filterType.hasStaticShape())
84  return rewriter.notifyMatchFailure(
85  convOp, "expected a static shape for the filter");
86 
87  if (!inputType.hasStaticShape())
88  return rewriter.notifyMatchFailure(convOp,
89  "expected a static shape for the input");
90 
91  // TODO: Support dilation.
92  if (!hasAllOneValues(convOp.getDilations()))
93  return rewriter.notifyMatchFailure(convOp,
94  "expected all ones for dilations");
95 
96  MLIRContext *context = rewriter.getContext();
97  Value input = convOp.getInputs()[0];
98  Value filter = convOp.getInputs()[1];
99  Value output = convOp.getOutputs()[0];
100 
101  ArrayRef<int64_t> filterShape = filterType.getShape();
102  ArrayRef<int64_t> outputShape = outputType.getShape();
103 
104  int64_t n = outputShape[0];
105  int64_t oh = outputShape[1];
106  int64_t ow = outputShape[2];
107  int64_t oc = outputShape[3];
108  int64_t fh = filterShape[0];
109  int64_t fw = filterShape[1];
110  int64_t ic = filterShape[2];
111 
112  Location loc = convOp.getLoc();
113 
114  // Reshape output and filter to the LHS and result of a (B)MNK matmul.
115  SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
116  auto reshapedFilterType =
117  RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
118  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
119  loc, reshapedFilterType, filter, filterReassocIndices);
120 
121  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
122  RankedTensorType reshapedOutputType =
123  RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
124  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
125  loc, reshapedOutputType, output, outputReassocIndices);
126 
127  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
128  Value colTensor = rewriter.create<tensor::EmptyOp>(
129  loc, colTensorShape, inputType.getElementType());
130 
131  // Convert the input to a (BMK) column tensor.
132  auto nloops = colTensorShape.size();
133 
134  auto parallel = utils::IteratorType::parallel;
135  auto reduction = utils::IteratorType::reduction;
136  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
137 
138  SmallVector<AffineMap> img2colIndexingMaps = {
139  AffineMap::getMultiDimIdentityMap(nloops, context)};
140 
141  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
142  loc, colTensor.getType(),
143  /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
144  img2colIterators,
145  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
146  // Get the iterators named based on the matmul (batch, m, k).
147  Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
148  Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
149  Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
150 
151  // Recover the original iteration indices from the problem/input sizes.
152  SmallVector<Value> mIndices = unrollIndex(
153  nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
154  auto ohIndex = mIndices[0];
155  auto owIndex = mIndices[1];
156 
157  SmallVector<Value> kIndices = unrollIndex(
158  nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
159  auto fhIndex = kIndices[0];
160  auto fwIndex = kIndices[1];
161  auto icIndex = kIndices[2];
162 
163  // Extract the input element corresponding to the expanded indices.
164  Value hIndex =
165  getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
166  convOp.getStrides().getValues<int64_t>()[0]);
167  Value wIndex =
168  getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
169  convOp.getStrides().getValues<int64_t>()[1]);
170 
171  // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
172  SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
173  Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
174  loc, input, extractionIndices);
175  nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
176  });
177 
178  // Because the filter does not share the same batch dimension,
179  // the batch dimension is only used in indexing the input and output. Thus
180  // we cannot use existing linalg named ops like linalg.batch_matmul.
181  // i.e. (B x) M x K * K x N = (B x) M x N
182  AffineExpr bDim, mDim, nDim, kDim;
183  bindDims(context, bDim, mDim, nDim, kDim);
184  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
185  auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
186  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
187  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
188  parallel, reduction};
189 
190  auto genericOp = rewriter.create<linalg::GenericOp>(
191  loc, reshapedOutputType,
192  /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
193  /*outputs=*/ValueRange{reshapedOutput},
194  ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
195  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
196  Value mul =
197  createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
198  Value add = createAdd(loc, mul, args[2], nestedBuilder);
199  nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
200  });
201  Value result = genericOp.getResults().front();
202 
203  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
204  loc, outputType, result, outputReassocIndices);
205 
206  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
207 
208  return std::make_pair(img2ColTensor.getOperation(),
209  reshapedResult.getOperation());
210 }
211 
212 FailureOr<std::pair<Operation *, Operation *>>
214  linalg::DepthwiseConv2DNhwcHwcOp convOp) {
215  auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
216  auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
217  auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());
218 
219  if (!filterType.hasStaticShape())
220  return rewriter.notifyMatchFailure(
221  convOp, "expected a static shape for the filter");
222 
223  if (!inputType.hasStaticShape())
224  return rewriter.notifyMatchFailure(convOp,
225  "expected a static shape for the input");
226 
227  // TODO: Support dilation.
228  if (!hasAllOneValues(convOp.getDilations()))
229  return rewriter.notifyMatchFailure(convOp,
230  "expected all ones for dilations");
231 
232  Location loc = convOp.getLoc();
233 
234  auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
235  auto operandTensorType = cast<RankedTensorType>(operand.getType());
236  auto nloops = indices.size();
237  ArrayRef<int64_t> inputShape = operandTensorType.getShape();
238 
239  SmallVector<AffineExpr> exprs = llvm::to_vector<4>(
240  llvm::map_range(indices, [&](int64_t index) -> AffineExpr {
241  return rewriter.getAffineDimExpr(index);
242  }));
243 
244  SmallVector<int64_t> targetShape = llvm::to_vector<4>(llvm::map_range(
245  indices, [&](int64_t index) -> int64_t { return inputShape[index]; }));
246 
247  Value outputTensor = rewriter.create<tensor::EmptyOp>(
248  loc, targetShape, operandTensorType.getElementType());
249 
250  SmallVector<utils::IteratorType> loopAttributeTypes(
251  nloops, utils::IteratorType::parallel);
252 
253  SmallVector<AffineMap> indexingMaps = {
255  AffineMap::get(nloops, 0, exprs, rewriter.getContext())),
256  AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
257 
258  auto transposedOp = rewriter.create<linalg::GenericOp>(
259  loc, outputTensor.getType(),
260  /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
261  loopAttributeTypes,
262  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
263  nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
264  });
265 
266  return transposedOp.getResult(0);
267  };
268 
269  Value input = convOp.getInputs()[0];
270  Value filter = convOp.getInputs()[1];
271  Value output = convOp.getOutputs()[0];
272 
273  // Transpose input, filter so channels are outermost
274  Value inputT = transposeOperand(input, {0, 3, 1, 2});
275  Value filterT = transposeOperand(filter, {2, 0, 1});
276  ArrayRef<int64_t> filterTShape =
277  cast<RankedTensorType>(filterT.getType()).getShape();
278  ArrayRef<int64_t> outputShape = outputType.getShape();
279 
280  int n = outputShape[0];
281  int oh = outputShape[1];
282  int ow = outputShape[2];
283  int c = outputShape[3];
284  int fh = filterTShape[1];
285  int fw = filterTShape[2];
286 
287  SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw};
288  Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
289 
290  AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
291  bindDims(rewriter.getContext(), nDim, cDim, ohDim, owDim, khDim, kwDim);
292 
293  AffineExpr shSym = rewriter.getAffineConstantExpr(
294  convOp.getStrides().getValues<int64_t>()[0]);
295  AffineExpr swSym = rewriter.getAffineConstantExpr(
296  convOp.getStrides().getValues<int64_t>()[1]);
297 
298  SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim,
299  owDim * swSym + kwDim};
300 
301  auto nloops = colTensorShape.size();
302 
303  SmallVector<utils::IteratorType> loopAttributeTypes(
304  nloops, utils::IteratorType::parallel);
305 
306  SmallVector<AffineMap> indexingMaps = {
307  AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),
308  AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
309 
310  Value colTensor = rewriter.create<tensor::EmptyOp>(
311  loc, colTensorShape, inputType.getElementType());
312 
313  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
314  loc, colTensor.getType(),
315  /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps,
316  loopAttributeTypes,
317  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
318  nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
319  });
320 
321  SmallVector<ReassociationIndices> img2ColTensorReassocIndices = {
322  {0, 1}, {2, 3}, {4, 5}};
323  SmallVector<ReassociationIndices> filterReassociationIndice = {{0}, {1, 2}};
324  SmallVector<ReassociationIndices> outputReassociationIndice = {{0, 1},
325  {2, 3}};
326 
327  auto reshapedImg2ColTensorType = RankedTensorType::get(
328  {n * c, oh * ow, fh * fw}, inputType.getElementType());
329  auto reshapedFilterTensorType =
330  RankedTensorType::get({c, fh * fw}, filterType.getElementType());
331  auto reshapedOutputTensorType =
332  RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());
333 
334  Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
335  loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
336  img2ColTensorReassocIndices);
337  Value reshapedFilterTensor = rewriter.create<tensor::CollapseShapeOp>(
338  loc, reshapedFilterTensorType, filterT, filterReassociationIndice);
339  Value reshapedoutputTensor = rewriter.create<tensor::CollapseShapeOp>(
340  loc, reshapedOutputTensorType, transposedOutputTensor,
341  outputReassociationIndice);
342 
343  auto batchMatVecResult = rewriter.create<linalg::BatchMatvecOp>(
344  loc, TypeRange{reshapedoutputTensor.getType()},
345  ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
346  ValueRange{reshapedoutputTensor});
347 
348  SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
349  {2, 3}};
350 
351  auto batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
352  loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),
353  batchMatVecReassociationIndice);
354 
355  Value transposedResult =
356  transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
357 
358  rewriter.replaceOp(convOp, ArrayRef<Value>{transposedResult});
359  return std::make_pair(img2ColTensor.getOperation(),
360  transposedResult.getDefiningOp());
361 }
362 
363 FailureOr<std::pair<Operation *, Operation *>>
364 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
365  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
366  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
367  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
368 
369  if (!filterType.hasStaticShape())
370  return rewriter.notifyMatchFailure(
371  convOp, "expected a static shape for the filter");
372 
373  if (!inputType.hasStaticShape())
374  return rewriter.notifyMatchFailure(convOp,
375  "expected a static shape for the input");
376 
377  // TODO: Support dilation.
378  if (!hasAllOneValues(convOp.getDilations()))
379  return rewriter.notifyMatchFailure(convOp,
380  "expected all ones for dilations");
381 
382  Value input = convOp.getInputs()[0];
383  Value filter = convOp.getInputs()[1];
384  Value output = convOp.getOutputs()[0];
385 
386  auto filterShape = filterType.getShape();
387  auto outputShape = outputType.getShape();
388 
389  int64_t n = outputShape[0];
390  int64_t oc = outputShape[1];
391  int64_t oh = outputShape[2];
392  int64_t ow = outputShape[3];
393  int64_t ic = filterShape[1];
394  int64_t fh = filterShape[2];
395  int64_t fw = filterShape[3];
396 
397  auto loc = convOp.getLoc();
398  MLIRContext *context = rewriter.getContext();
399 
400  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
401  auto reshapedFilterType =
402  RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
403  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
404  loc, reshapedFilterType, filter, filterReassocIndices);
405 
406  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}};
407  auto reshapedOutputType =
408  RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
409  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
410  loc, reshapedOutputType, output, outputReassocIndices);
411 
412  // Convert the input to a (BKN) tensor.
413  SmallVector<int64_t, 4> colTensorShape = {n, ic * fh * fw, oh * ow};
414  Value colTensor = rewriter.create<tensor::EmptyOp>(
415  loc, colTensorShape, inputType.getElementType());
416 
417  auto nloops = colTensorShape.size();
418 
419  auto parallel = utils::IteratorType::parallel;
420  auto reduction = utils::IteratorType::reduction;
421  SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);
422 
423  SmallVector<AffineMap, 4> img2colIndexingMaps = {
424  AffineMap::getMultiDimIdentityMap(nloops, context)};
425 
426  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
427  loc, colTensor.getType(),
428  /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
429  img2colIterators,
430  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
431  // Get the iterators named based on the matmul (batch, m, k).
432  Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
433  Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
434  Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
435 
436  // Recover the original iteration indices from the problem/input sizes.
437  SmallVector<Value> kIndices = unrollIndex(
438  nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
439  auto icIndex = kIndices[0];
440  auto fhIndex = kIndices[1];
441  auto fwIndex = kIndices[2];
442 
443  SmallVector<Value> nIndices = unrollIndex(
444  nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
445  auto ohIndex = nIndices[0];
446  auto owIndex = nIndices[1];
447 
448  // Extract the input element corresponding to the expanded indices.
449  Value hIndex =
450  getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
451  convOp.getStrides().getValues<int64_t>()[0]);
452  Value wIndex =
453  getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
454  convOp.getStrides().getValues<int64_t>()[1]);
455 
456  // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
457  SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
458  Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
459  loc, input, extractionIndices);
460  nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
461  });
462 
463  // Because the filter does not share the same batch dimension,
464  // the batch dimension is only used in indexing the input and output. Thus
465  // we cannot use existing linalg named ops like linalg.batch_matmul.
466  // i.e. M x K * (B x) K x N = (B x) M x N
467  AffineExpr bDim, mDim, nDim, kDim;
468  bindDims(context, bDim, mDim, nDim, kDim);
469  auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
470  auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
471  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
472  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
473  parallel, reduction};
474  auto genericOp = rewriter.create<linalg::GenericOp>(
475  loc, reshapedOutputType,
476  /*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
477  /*outputs=*/ValueRange{reshapedOutput},
478  ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
479  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
480  Value mul =
481  createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
482  Value add = createAdd(loc, mul, args[2], nestedBuilder);
483  nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
484  });
485  Value result = genericOp.getResults().front();
486 
487  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
488  loc, outputType, result, outputReassocIndices);
489 
490  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
491 
492  return std::make_pair(img2ColTensor.getOperation(),
493  reshapedResult.getOperation());
494 }
495 
496 FailureOr<std::pair<Operation *, Operation *>>
497 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
498  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
499  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
500  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
501 
502  if (!filterType.hasStaticShape())
503  return rewriter.notifyMatchFailure(
504  convOp, "expected a static shape for the filter");
505 
506  if (!inputType.hasStaticShape())
507  return rewriter.notifyMatchFailure(convOp,
508  "expected a static shape for the input");
509 
510  // TODO: Support dilation.
511  if (!hasAllOneValues(convOp.getDilations()))
512  return rewriter.notifyMatchFailure(convOp,
513  "expected all ones for dilations");
514 
515  MLIRContext *context = rewriter.getContext();
516  Value input = convOp.getInputs()[0];
517  Value filter = convOp.getInputs()[1];
518  Value output = convOp.getOutputs()[0];
519 
520  ArrayRef<int64_t> filterShape = filterType.getShape();
521  ArrayRef<int64_t> outputShape = outputType.getShape();
522 
523  int64_t n = outputShape[0];
524  int64_t oh = outputShape[1];
525  int64_t ow = outputShape[2];
526  int64_t oc = outputShape[3];
527  int64_t fh = filterShape[1];
528  int64_t fw = filterShape[2];
529  int64_t ic = filterShape[3];
530 
531  Location loc = convOp.getLoc();
532 
533  // Reshape output and filter to the LHS and result of a "row-wise" matrix
534  // multiplication.
535  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
536  auto reshapedFilterType =
537  RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
538  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
539  loc, reshapedFilterType, filter, filterReassocIndices);
540 
541  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
542  RankedTensorType reshapedOutputType =
543  RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
544  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
545  loc, reshapedOutputType, output, outputReassocIndices);
546 
547  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
548  Value colTensor = rewriter.create<tensor::EmptyOp>(
549  loc, colTensorShape, inputType.getElementType());
550 
551  // Convert the input to a (BMK) column tensor.
552  auto nloops = colTensorShape.size();
553 
554  auto parallel = utils::IteratorType::parallel;
555  auto reduction = utils::IteratorType::reduction;
556  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
557 
558  SmallVector<AffineMap> img2colIndexingMaps = {
559  AffineMap::getMultiDimIdentityMap(nloops, context)};
560 
561  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
562  loc, colTensor.getType(),
563  /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
564  img2colIterators,
565  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
566  // Get the iterators named based on the matmul (batch, m, k).
567  Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
568  Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
569  Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
570 
571  // Recover the original iteration indices from the problem/input sizes.
572  SmallVector<Value> mIndices = unrollIndex(
573  nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
574  auto ohIndex = mIndices[0];
575  auto owIndex = mIndices[1];
576 
577  SmallVector<Value> kIndices = unrollIndex(
578  nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
579  auto fhIndex = kIndices[0];
580  auto fwIndex = kIndices[1];
581  auto icIndex = kIndices[2];
582 
583  // Extract the input element corresponding to the expanded indices.
584  Value hIndex =
585  getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
586  convOp.getStrides().getValues<int64_t>()[0]);
587  Value wIndex =
588  getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
589  convOp.getStrides().getValues<int64_t>()[1]);
590 
591  // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
592  SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
593  Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
594  loc, input, extractionIndices);
595  nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
596  });
597 
598  // Because we didn't transpose the filters we don't actually have a batched
599  // matrix multiply. Instead, we have an operation consisting of "row-wise" dot
600  // products.
601  AffineExpr bDim, mDim, nDim, kDim;
602  bindDims(context, bDim, mDim, nDim, kDim);
603  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
604  auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
605  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
606  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
607  parallel, reduction};
608 
609  auto genericOp = rewriter.create<linalg::GenericOp>(
610  loc, reshapedOutputType,
611  /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
612  /*outputs=*/ValueRange{reshapedOutput},
613  ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
614  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
615  Value mul =
616  createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
617  Value add = createAdd(loc, mul, args[2], nestedBuilder);
618  nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
619  });
620  Value result = genericOp.getResults().front();
621 
622  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
623  loc, outputType, result, outputReassocIndices);
624 
625  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
626 
627  return std::make_pair(img2ColTensor.getOperation(),
628  reshapedResult.getOperation());
629 }
630 
631 namespace {
632 
633 class ConvertConv2DNhwcHwcf final
634  : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
635 public:
637 
638  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
639  PatternRewriter &rewriter) const override {
640  if (failed(rewriteInIm2Col(rewriter, convOp)))
641  return failure();
642  return success();
643  }
644 };
645 
646 class ConvertDepthwiseConv2DNhwcHwc final
647  : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
648 public:
650 
651  LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
652  PatternRewriter &rewriter) const override {
653  if (failed(rewriteInIm2Col(rewriter, convOp)))
654  return failure();
655  return success();
656  }
657 };
658 
659 class ConvertConv2DNchwFchw final
660  : public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
661 public:
663 
664  LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
665  PatternRewriter &rewriter) const override {
666  if (failed(rewriteInIm2Col(rewriter, convOp)))
667  return failure();
668  return success();
669  }
670 };
671 
672 class ConvertConv2DNhwcFhwc final
673  : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
674 public:
676 
677  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
678  PatternRewriter &rewriter) const override {
679  if (failed(rewriteInIm2Col(rewriter, convOp)))
680  return failure();
681  return success();
682  }
683 };
684 } // end anonymous namespace
685 
687  MLIRContext *context = patterns.getContext();
688  patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
689  ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
690 }
691 } // end namespace linalg
692 } // end namespace mlir
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:330
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:367
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:359
MLIRContext * getContext() const
Definition: Builders.h:55
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:681
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1278
FailureOr< SmallVector< Value > > delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef< Value > basis, bool hasOuterBound=true)
Generate the IR to delinearize linearIndex given the basis and return the multi-index.
Definition: Utils.cpp:1966
static SmallVector< Value > unrollIndex(OpBuilder &b, Location loc, Value index, ArrayRef< int64_t > factors)
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns)
Populates patterns to transform linalg.conv_2d_xxx operations into linalg.generic (for img2col packin...
static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder)
static Value createMul(Location loc, Value x, Value y, Type accType, OpBuilder &builder)
static bool hasAllOneValues(DenseIntElementsAttr attr)
static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex, Value fIndex, int64_t stride)
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:239
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:311
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:784
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:325
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319