MLIR  20.0.0git
Go to the documentation of this file.
1 //===- ConvertConv2DToImg2Col.cpp - im2col implementation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
17 #include "mlir/IR/AffineExpr.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinTypes.h"
23 #include <utility>
25 namespace mlir {
26 namespace linalg {
28  return llvm::all_of(
29  attr, [](const APInt &element) { return element.getSExtValue() == 1; });
30 }
32 static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
33  if (isa<IntegerType>(x.getType()))
34  return builder.create<arith::AddIOp>(loc, x, y);
35  if (isa<ComplexType>(x.getType()))
36  return builder.create<complex::AddOp>(loc, x, y);
37  return builder.create<arith::AddFOp>(loc, x, y);
38 }
40 static Value createMul(Location loc, Value x, Value y, Type accType,
41  OpBuilder &builder) {
42  // Linalg named ops specify signed extend for named ops.
43  Value xConvert =
44  convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false);
45  Value yConvert =
46  convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false);
47  if (isa<ComplexType>(accType))
48  return builder.create<complex::MulOp>(loc, xConvert, yConvert);
49  if (isa<IntegerType>(accType))
50  return builder.create<arith::MulIOp>(loc, xConvert, yConvert);
51  return builder.create<arith::MulFOp>(loc, xConvert, yConvert);
52 }
54 // Delinearizes the given composite `index` by the basis specified in `factors`.
56  ArrayRef<int64_t> factors) {
57  assert(!factors.empty() && "empty factor list");
58  SmallVector<Value> basis;
59  for (int64_t f : factors)
60  basis.push_back(b.create<arith::ConstantOp>(loc, b.getIndexAttr(f)));
61  FailureOr<SmallVector<Value>> multiIndex =
62  affine::delinearizeIndex(b, loc, index, basis);
63  assert(!failed(multiIndex) && "Failed to linearize img2col index");
64  return *multiIndex;
65 }
67 // Given indices corresponding to iterators in the output (oIndex) and filter
68 // (fIndex) for a convolution, compute the convolved index for the
69 // input as `oIndex * stride + fIndex`.
71  Value fIndex, int64_t stride) {
72  AffineExpr oExpr, fExpr;
73  bindSymbols(b.getContext(), oExpr, fExpr);
74  AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
75  return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex});
76 }
78 FailureOr<std::pair<Operation *, Operation *>>
79 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
80  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
81  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
82  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
84  if (!filterType.hasStaticShape())
85  return rewriter.notifyMatchFailure(
86  convOp, "expected a static shape for the filter");
88  if (!inputType.hasStaticShape())
89  return rewriter.notifyMatchFailure(convOp,
90  "expected a static shape for the input");
92  // TODO: Support dilation.
93  if (!hasAllOneValues(convOp.getDilations()))
94  return rewriter.notifyMatchFailure(convOp,
95  "expected all ones for dilations");
97  MLIRContext *context = rewriter.getContext();
98  Value input = convOp.getInputs()[0];
99  Value filter = convOp.getInputs()[1];
100  Value output = convOp.getOutputs()[0];
102  ArrayRef<int64_t> filterShape = filterType.getShape();
103  ArrayRef<int64_t> outputShape = outputType.getShape();
105  int64_t n = outputShape[0];
106  int64_t oh = outputShape[1];
107  int64_t ow = outputShape[2];
108  int64_t oc = outputShape[3];
109  int64_t fh = filterShape[0];
110  int64_t fw = filterShape[1];
111  int64_t ic = filterShape[2];
113  Location loc = convOp.getLoc();
115  // Reshape output and filter to the LHS and result of a (B)MNK matmul.
116  SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
117  auto reshapedFilterType =
118  RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
119  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
120  loc, reshapedFilterType, filter, filterReassocIndices);
122  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
123  RankedTensorType reshapedOutputType =
124  RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
125  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
126  loc, reshapedOutputType, output, outputReassocIndices);
128  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
129  Value colTensor = rewriter.create<tensor::EmptyOp>(
130  loc, colTensorShape, inputType.getElementType());
132  // Convert the input to a (BMK) column tensor.
133  auto nloops = colTensorShape.size();
135  auto parallel = utils::IteratorType::parallel;
136  auto reduction = utils::IteratorType::reduction;
137  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
139  SmallVector<AffineMap> img2colIndexingMaps = {
140  AffineMap::getMultiDimIdentityMap(nloops, context)};
142  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
143  loc, colTensor.getType(),
144  /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
145  img2colIterators,
146  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
147  // Get the iterators named based on the matmul (batch, m, k).
148  Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
149  Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
150  Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
152  // Recover the original iteration indices from the problem/input sizes.
153  SmallVector<Value> mIndices = unrollIndex(
154  nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
155  auto ohIndex = mIndices[0];
156  auto owIndex = mIndices[1];
158  SmallVector<Value> kIndices = unrollIndex(
159  nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
160  auto fhIndex = kIndices[0];
161  auto fwIndex = kIndices[1];
162  auto icIndex = kIndices[2];
164  // Extract the input element corresponding to the expanded indices.
165  Value hIndex =
166  getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
167  convOp.getStrides().getValues<int64_t>()[0]);
168  Value wIndex =
169  getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
170  convOp.getStrides().getValues<int64_t>()[1]);
172  // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
173  SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
174  Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
175  loc, input, extractionIndices);
176  nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
177  });
179  // Because the filter does not share the same batch dimension,
180  // the batch dimension is only used in indexing the input and output. Thus
181  // we cannot use existing linalg named ops like linalg.batch_matmul.
182  // i.e. (B x) M x K * K x N = (B x) M x N
183  AffineExpr bDim, mDim, nDim, kDim;
184  bindDims(context, bDim, mDim, nDim, kDim);
185  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
186  auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
187  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
188  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
189  parallel, reduction};
191  auto genericOp = rewriter.create<linalg::GenericOp>(
192  loc, reshapedOutputType,
193  /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
194  /*outputs=*/ValueRange{reshapedOutput},
195  ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
196  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
197  Value mul =
198  createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
199  Value add = createAdd(loc, mul, args[2], nestedBuilder);
200  nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
201  });
202  Value result = genericOp.getResults().front();
204  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
205  loc, outputType, result, outputReassocIndices);
207  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
209  return std::make_pair(img2ColTensor.getOperation(),
210  reshapedResult.getOperation());
211 }
213 FailureOr<std::pair<Operation *, Operation *>>
215  linalg::DepthwiseConv2DNhwcHwcOp convOp) {
216  auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
217  auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
218  auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());
220  if (!filterType.hasStaticShape())
221  return rewriter.notifyMatchFailure(
222  convOp, "expected a static shape for the filter");
224  if (!inputType.hasStaticShape())
225  return rewriter.notifyMatchFailure(convOp,
226  "expected a static shape for the input");
228  // TODO: Support dilation.
229  if (!hasAllOneValues(convOp.getDilations()))
230  return rewriter.notifyMatchFailure(convOp,
231  "expected all ones for dilations");
233  Location loc = convOp.getLoc();
235  auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
236  auto operandTensorType = cast<RankedTensorType>(operand.getType());
237  auto nloops = indices.size();
238  ArrayRef<int64_t> inputShape = operandTensorType.getShape();
240  SmallVector<AffineExpr> exprs = llvm::to_vector<4>(
241  llvm::map_range(indices, [&](int64_t index) -> AffineExpr {
242  return rewriter.getAffineDimExpr(index);
243  }));
245  SmallVector<int64_t> targetShape = llvm::to_vector<4>(llvm::map_range(
246  indices, [&](int64_t index) -> int64_t { return inputShape[index]; }));
248  Value outputTensor = rewriter.create<tensor::EmptyOp>(
249  loc, targetShape, operandTensorType.getElementType());
251  SmallVector<utils::IteratorType> loopAttributeTypes(
252  nloops, utils::IteratorType::parallel);
254  SmallVector<AffineMap> indexingMaps = {
256  AffineMap::get(nloops, 0, exprs, rewriter.getContext())),
257  AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
259  auto transposedOp = rewriter.create<linalg::GenericOp>(
260  loc, outputTensor.getType(),
261  /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
262  loopAttributeTypes,
263  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
264  nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
265  });
267  return transposedOp.getResult(0);
268  };
270  Value input = convOp.getInputs()[0];
271  Value filter = convOp.getInputs()[1];
272  Value output = convOp.getOutputs()[0];
274  // Transpose input, filter so channels are outermost
275  Value inputT = transposeOperand(input, {0, 3, 1, 2});
276  Value filterT = transposeOperand(filter, {2, 0, 1});
277  ArrayRef<int64_t> filterTShape =
278  cast<RankedTensorType>(filterT.getType()).getShape();
279  ArrayRef<int64_t> outputShape = outputType.getShape();
281  int n = outputShape[0];
282  int oh = outputShape[1];
283  int ow = outputShape[2];
284  int c = outputShape[3];
285  int fh = filterTShape[1];
286  int fw = filterTShape[2];
288  SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw};
289  Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
291  AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
292  bindDims(rewriter.getContext(), nDim, cDim, ohDim, owDim, khDim, kwDim);
294  AffineExpr shSym = rewriter.getAffineConstantExpr(
295  convOp.getStrides().getValues<int64_t>()[0]);
296  AffineExpr swSym = rewriter.getAffineConstantExpr(
297  convOp.getStrides().getValues<int64_t>()[1]);
299  SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim,
300  owDim * swSym + kwDim};
302  auto nloops = colTensorShape.size();
304  SmallVector<utils::IteratorType> loopAttributeTypes(
305  nloops, utils::IteratorType::parallel);
307  SmallVector<AffineMap> indexingMaps = {
308  AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),
309  AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
311  Value colTensor = rewriter.create<tensor::EmptyOp>(
312  loc, colTensorShape, inputType.getElementType());
314  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
315  loc, colTensor.getType(),
316  /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps,
317  loopAttributeTypes,
318  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
319  nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
320  });
322  SmallVector<ReassociationIndices> img2ColTensorReassocIndices = {
323  {0, 1}, {2, 3}, {4, 5}};
324  SmallVector<ReassociationIndices> filterReassociationIndice = {{0}, {1, 2}};
325  SmallVector<ReassociationIndices> outputReassociationIndice = {{0, 1},
326  {2, 3}};
328  auto reshapedImg2ColTensorType = RankedTensorType::get(
329  {n * c, oh * ow, fh * fw}, inputType.getElementType());
330  auto reshapedFilterTensorType =
331  RankedTensorType::get({c, fh * fw}, filterType.getElementType());
332  auto reshapedOutputTensorType =
333  RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());
335  Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
336  loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
337  img2ColTensorReassocIndices);
338  Value reshapedFilterTensor = rewriter.create<tensor::CollapseShapeOp>(
339  loc, reshapedFilterTensorType, filterT, filterReassociationIndice);
340  Value reshapedoutputTensor = rewriter.create<tensor::CollapseShapeOp>(
341  loc, reshapedOutputTensorType, transposedOutputTensor,
342  outputReassociationIndice);
344  auto batchMatVecResult = rewriter.create<linalg::BatchMatvecOp>(
345  loc, TypeRange{reshapedoutputTensor.getType()},
346  ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
347  ValueRange{reshapedoutputTensor});
349  SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
350  {2, 3}};
352  auto batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
353  loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),
354  batchMatVecReassociationIndice);
356  Value transposedResult =
357  transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
359  rewriter.replaceOp(convOp, ArrayRef<Value>{transposedResult});
360  return std::make_pair(img2ColTensor.getOperation(),
361  transposedResult.getDefiningOp());
362 }
364 FailureOr<std::pair<Operation *, Operation *>>
365 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
366  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
367  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
368  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
370  if (!filterType.hasStaticShape())
371  return rewriter.notifyMatchFailure(
372  convOp, "expected a static shape for the filter");
374  if (!inputType.hasStaticShape())
375  return rewriter.notifyMatchFailure(convOp,
376  "expected a static shape for the input");
378  // TODO: Support dilation.
379  if (!hasAllOneValues(convOp.getDilations()))
380  return rewriter.notifyMatchFailure(convOp,
381  "expected all ones for dilations");
383  Value input = convOp.getInputs()[0];
384  Value filter = convOp.getInputs()[1];
385  Value output = convOp.getOutputs()[0];
387  auto filterShape = filterType.getShape();
388  auto outputShape = outputType.getShape();
390  int64_t n = outputShape[0];
391  int64_t oc = outputShape[1];
392  int64_t oh = outputShape[2];
393  int64_t ow = outputShape[3];
394  int64_t ic = filterShape[1];
395  int64_t fh = filterShape[2];
396  int64_t fw = filterShape[3];
398  auto loc = convOp.getLoc();
399  MLIRContext *context = rewriter.getContext();
401  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
402  auto reshapedFilterType =
403  RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
404  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
405  loc, reshapedFilterType, filter, filterReassocIndices);
407  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}};
408  auto reshapedOutputType =
409  RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
410  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
411  loc, reshapedOutputType, output, outputReassocIndices);
413  // Convert the input to a (BKN) tensor.
414  SmallVector<int64_t, 4> colTensorShape = {n, ic * fh * fw, oh * ow};
415  Value colTensor = rewriter.create<tensor::EmptyOp>(
416  loc, colTensorShape, inputType.getElementType());
418  auto nloops = colTensorShape.size();
420  auto parallel = utils::IteratorType::parallel;
421  auto reduction = utils::IteratorType::reduction;
422  SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);
424  SmallVector<AffineMap, 4> img2colIndexingMaps = {
425  AffineMap::getMultiDimIdentityMap(nloops, context)};
427  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
428  loc, colTensor.getType(),
429  /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
430  img2colIterators,
431  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
432  // Get the iterators named based on the matmul (batch, m, k).
433  Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
434  Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
435  Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
437  // Recover the original iteration indices from the problem/input sizes.
438  SmallVector<Value> kIndices = unrollIndex(
439  nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
440  auto icIndex = kIndices[0];
441  auto fhIndex = kIndices[1];
442  auto fwIndex = kIndices[2];
444  SmallVector<Value> nIndices = unrollIndex(
445  nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
446  auto ohIndex = nIndices[0];
447  auto owIndex = nIndices[1];
449  // Extract the input element corresponding to the expanded indices.
450  Value hIndex =
451  getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
452  convOp.getStrides().getValues<int64_t>()[0]);
453  Value wIndex =
454  getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
455  convOp.getStrides().getValues<int64_t>()[1]);
457  // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
458  SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
459  Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
460  loc, input, extractionIndices);
461  nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
462  });
464  // Because the filter does not share the same batch dimension,
465  // the batch dimension is only used in indexing the input and output. Thus
466  // we cannot use existing linalg named ops like linalg.batch_matmul.
467  // i.e. M x K * (B x) K x N = (B x) M x N
468  AffineExpr bDim, mDim, nDim, kDim;
469  bindDims(context, bDim, mDim, nDim, kDim);
470  auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
471  auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
472  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
473  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
474  parallel, reduction};
475  auto genericOp = rewriter.create<linalg::GenericOp>(
476  loc, reshapedOutputType,
477  /*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
478  /*outputs=*/ValueRange{reshapedOutput},
479  ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
480  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
481  Value mul =
482  createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
483  Value add = createAdd(loc, mul, args[2], nestedBuilder);
484  nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
485  });
486  Value result = genericOp.getResults().front();
488  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
489  loc, outputType, result, outputReassocIndices);
491  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
493  return std::make_pair(img2ColTensor.getOperation(),
494  reshapedResult.getOperation());
495 }
497 FailureOr<std::pair<Operation *, Operation *>>
498 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
499  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
500  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
501  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
503  if (!filterType.hasStaticShape())
504  return rewriter.notifyMatchFailure(
505  convOp, "expected a static shape for the filter");
507  if (!inputType.hasStaticShape())
508  return rewriter.notifyMatchFailure(convOp,
509  "expected a static shape for the input");
511  // TODO: Support dilation.
512  if (!hasAllOneValues(convOp.getDilations()))
513  return rewriter.notifyMatchFailure(convOp,
514  "expected all ones for dilations");
516  MLIRContext *context = rewriter.getContext();
517  Value input = convOp.getInputs()[0];
518  Value filter = convOp.getInputs()[1];
519  Value output = convOp.getOutputs()[0];
521  ArrayRef<int64_t> filterShape = filterType.getShape();
522  ArrayRef<int64_t> outputShape = outputType.getShape();
524  int64_t n = outputShape[0];
525  int64_t oh = outputShape[1];
526  int64_t ow = outputShape[2];
527  int64_t oc = outputShape[3];
528  int64_t fh = filterShape[1];
529  int64_t fw = filterShape[2];
530  int64_t ic = filterShape[3];
532  Location loc = convOp.getLoc();
534  // Reshape output and filter to the LHS and result of a "row-wise" matrix
535  // multiplication.
536  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
537  auto reshapedFilterType =
538  RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
539  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
540  loc, reshapedFilterType, filter, filterReassocIndices);
542  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
543  RankedTensorType reshapedOutputType =
544  RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
545  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
546  loc, reshapedOutputType, output, outputReassocIndices);
548  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
549  Value colTensor = rewriter.create<tensor::EmptyOp>(
550  loc, colTensorShape, inputType.getElementType());
552  // Convert the input to a (BMK) column tensor.
553  auto nloops = colTensorShape.size();
555  auto parallel = utils::IteratorType::parallel;
556  auto reduction = utils::IteratorType::reduction;
557  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
559  SmallVector<AffineMap> img2colIndexingMaps = {
560  AffineMap::getMultiDimIdentityMap(nloops, context)};
562  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
563  loc, colTensor.getType(),
564  /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
565  img2colIterators,
566  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
567  // Get the iterators named based on the matmul (batch, m, k).
568  Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
569  Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
570  Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
572  // Recover the original iteration indices from the problem/input sizes.
573  SmallVector<Value> mIndices = unrollIndex(
574  nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
575  auto ohIndex = mIndices[0];
576  auto owIndex = mIndices[1];
578  SmallVector<Value> kIndices = unrollIndex(
579  nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
580  auto fhIndex = kIndices[0];
581  auto fwIndex = kIndices[1];
582  auto icIndex = kIndices[2];
584  // Extract the input element corresponding to the expanded indices.
585  Value hIndex =
586  getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
587  convOp.getStrides().getValues<int64_t>()[0]);
588  Value wIndex =
589  getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
590  convOp.getStrides().getValues<int64_t>()[1]);
592  // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
593  SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
594  Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
595  loc, input, extractionIndices);
596  nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
597  });
599  // Because we didn't transpose the filters we don't actually have a batched
600  // matrix multiply. Instead, we have an operation consisting of "row-wise" dot
601  // products.
602  AffineExpr bDim, mDim, nDim, kDim;
603  bindDims(context, bDim, mDim, nDim, kDim);
604  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
605  auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
606  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
607  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
608  parallel, reduction};
610  auto genericOp = rewriter.create<linalg::GenericOp>(
611  loc, reshapedOutputType,
612  /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
613  /*outputs=*/ValueRange{reshapedOutput},
614  ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
615  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
616  Value mul =
617  createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
618  Value add = createAdd(loc, mul, args[2], nestedBuilder);
619  nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
620  });
621  Value result = genericOp.getResults().front();
623  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
624  loc, outputType, result, outputReassocIndices);
626  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
628  return std::make_pair(img2ColTensor.getOperation(),
629  reshapedResult.getOperation());
630 }
632 namespace {
634 class ConvertConv2DNhwcHwcf final
635  : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
636 public:
639  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
640  PatternRewriter &rewriter) const override {
641  if (failed(rewriteInIm2Col(rewriter, convOp)))
642  return failure();
643  return success();
644  }
645 };
647 class ConvertDepthwiseConv2DNhwcHwc final
648  : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
649 public:
652  LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
653  PatternRewriter &rewriter) const override {
654  if (failed(rewriteInIm2Col(rewriter, convOp)))
655  return failure();
656  return success();
657  }
658 };
660 class ConvertConv2DNchwFchw final
661  : public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
662 public:
665  LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
666  PatternRewriter &rewriter) const override {
667  if (failed(rewriteInIm2Col(rewriter, convOp)))
668  return failure();
669  return success();
670  }
671 };
673 class ConvertConv2DNhwcFhwc final
674  : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
675 public:
678  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
679  PatternRewriter &rewriter) const override {
680  if (failed(rewriteInIm2Col(rewriter, convOp)))
681  return failure();
682  return success();
683  }
684 };
685 } // end anonymous namespace
688  MLIRContext *context = patterns.getContext();
689  patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
690  ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
691 }
692 } // end namespace linalg
693 } // end namespace mlir
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:142
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:406
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:398
MLIRContext * getContext() const
Definition: Builders.h:56
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:216
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:491
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1144
FailureOr< SmallVector< Value > > delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef< Value > basis, bool hasOuterBound=true)
Generate the IR to delinearize linearIndex given the basis and return the multi-index.
Definition: Utils.cpp:1949
static SmallVector< Value > unrollIndex(OpBuilder &b, Location loc, Value index, ArrayRef< int64_t > factors)
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns)
Populates patterns to transform linalg.conv_2d_xxx operations into linalg.generic (for img2col packin...
static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder)
static Value createMul(Location loc, Value x, Value y, Type accType, OpBuilder &builder)
static bool hasAllOneValues(DenseIntElementsAttr attr)
static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex, Value fIndex, int64_t stride)
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:239
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:348
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:362
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362