//===- ConvertConv2DToImg2Col.cpp - im2col implementation -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include <utility>

namespace mlir {
namespace linalg {
static bool hasAllOneValues(DenseIntElementsAttr attr) {
  return llvm::all_of(
      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
}

static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
  if (isa<IntegerType>(x.getType()))
    return arith::AddIOp::create(builder, loc, x, y);
  if (isa<ComplexType>(x.getType()))
    return complex::AddOp::create(builder, loc, x, y);
  return arith::AddFOp::create(builder, loc, x, y);
}

static Value createMul(Location loc, Value x, Value y, Type accType,
                       OpBuilder &builder) {
  // Linalg named ops specify signed extension when casting operands to the
  // accumulator type.
  Value xConvert =
      convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false);
  Value yConvert =
      convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false);
  if (isa<ComplexType>(accType))
    return complex::MulOp::create(builder, loc, xConvert, yConvert);
  if (isa<IntegerType>(accType))
    return arith::MulIOp::create(builder, loc, xConvert, yConvert);
  return arith::MulFOp::create(builder, loc, xConvert, yConvert);
}

// Generate the affine expression to compute the convolved index
// for the input as `oIndex * stride + fIndex`,
// where oIndex: output iterator; fIndex: filter iterator.
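// For example (illustrative only), with stride == 2 this returns
// `2 * s0 + s1` when `useSymbols` is true and `2 * d0 + d1` otherwise.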
static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride,
                                   bool useSymbols = true) {
  AffineExpr oExpr, fExpr;
  if (useSymbols)
    bindSymbols(b.getContext(), oExpr, fExpr);
  else
    bindDims(b.getContext(), oExpr, fExpr);
  return AffineExpr(stride * oExpr + fExpr);
}

// Stores the affine expressions that map the iteration space of the im2col
// matrix to the corresponding indices of the output and filter matrices.
struct Im2ColToOperandsExprs {
  AffineExpr fhIndex;
  AffineExpr fwIndex;
  AffineExpr icIndex;
  AffineExpr ohIndex;
  AffineExpr owIndex;
};

// Stores the affine expressions that map the iteration space of the im2col
// matrix to the input matrix indices.
struct Im2ColToInputDimsExprs {
  AffineExpr bIndex;
  AffineExpr hIndex;
  AffineExpr wIndex;
  AffineExpr cIndex;
};

/// Construct the affine expressions that map the indices of the im2col matrix
/// to the corresponding input tensor indices for a 2D convolution with the
/// provided strides.
///
/// @param exprs Affine expressions for output and filter indices.
/// @param strides [height, width] stride values for the convolution.
/// @param rewriter Pattern rewriter.
/// @return Affine expressions mapping im2col matrix indices to input
///         offsets.
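///
/// Illustrative example (not exhaustive): if `exprs.ohIndex = d1 floordiv ow`
/// and `exprs.fhIndex = d2 floordiv (fw * ic)`, as built by the NHWC lowerings
/// below, then with stride 1 the returned h-expression is
/// `d1 floordiv ow + d2 floordiv (fw * ic)`, i.e. output index plus filter
/// index along the height dimension.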
static Im2ColToInputDimsExprs
getIm2ColInputExpressions(Im2ColToOperandsExprs exprs,
                          ArrayRef<int64_t> strides, RewriterBase &rewriter) {
  // Maps the iteration space of the im2col matrix to (output_y, filter_y).
  auto hIndicesMap = AffineMap::inferFromExprList(
      {ArrayRef{exprs.ohIndex, exprs.fhIndex}}, rewriter.getContext())[0];
  // Maps the iteration space of the im2col matrix to (output_x, filter_x).
  auto wIndicesMap = AffineMap::inferFromExprList(
      {ArrayRef{exprs.owIndex, exprs.fwIndex}}, rewriter.getContext())[0];
  // Compute the input indexing map, to map the indices of the im2col matrix to
  // the original input offsets. Each element of the im2col matrix corresponds
  // to a pair of (out_element, filter_element). First, we build the expressions
  // to compute the input (ix, iy) indices from [out_x/y, filter_x/y] pairs;
  // then we compose them with the maps that map the im2col matrix elements to
  // the (out_element, filter_element) pairs.
  auto bIndexExpr = rewriter.getAffineDimExpr(0U);
  auto hIndexExpr = getConvolvedExpr(rewriter, strides[0],
                                     /*useSymbols=*/false);
  hIndexExpr = hIndexExpr.compose(hIndicesMap);
  auto wIndexExpr = getConvolvedExpr(rewriter, strides[1],
                                     /*useSymbols=*/false);
  wIndexExpr = wIndexExpr.compose(wIndicesMap);
  auto cIndexExpr = exprs.icIndex;
  return {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr};
}

FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  MLIRContext *context = rewriter.getContext();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  ArrayRef<int64_t> filterShape = filterType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t oc = outputShape[3];
  int64_t fh = filterShape[0];
  int64_t fw = filterShape[1];
  int64_t ic = filterShape[2];

  Location loc = convOp.getLoc();

  // Reshape filter and output to the RHS and result of a (B)MNK matmul.
  SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
  auto reshapedFilterType =
      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
  Value reshapedFilter = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedFilterType, filter, filterReassocIndices);

  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
  RankedTensorType reshapedOutputType =
      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
  Value reshapedOutput = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedOutputType, output, outputReassocIndices);

  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
  Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
                                            inputType.getElementType());

  // Convert the input to a (BMK) column tensor.
  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

  // Given an index of the im2col matrix, retrieve the corresponding indices of
  // the output and filter matrices.
  auto mIndicesExprs =
      delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
  auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
                                   ArrayRef<int64_t>{fw * ic, ic, 1});
  Im2ColToOperandsExprs i2cToOperExprs;
  i2cToOperExprs.fhIndex = kIndicesExprs[0];
  i2cToOperExprs.fwIndex = kIndicesExprs[1];
  i2cToOperExprs.icIndex = kIndicesExprs[2];
  i2cToOperExprs.ohIndex = mIndicesExprs[0];
  i2cToOperExprs.owIndex = mIndicesExprs[1];

  // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
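  // E.g. (illustrative sizes only): for input tensor<1x16x16x4xf32>, filter
  // tensor<3x3x4x16xf32> and unit strides, the output is tensor<1x14x14x16xf32>
  // and the im2col tensor has type tensor<1x196x36xf32>
  // (oh*ow = 14*14, fh*fw*ic = 3*3*4).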
  Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
      i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
      rewriter);
  auto inMap =
      AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
                                             inExprs.wIndex, inExprs.cIndex}},
                                   rewriter.getContext())[0];

  SmallVector<AffineMap> img2colIndexingMaps = {
      inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};

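  // The im2col packing itself is a copy-only linalg.generic: the input map
  // gathers the convolved input elements and the identity map scatters them
  // into the column tensor.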
  auto img2ColTensor = linalg::GenericOp::create(
      rewriter, loc, colTensor.getType(),
      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
      });

  // Because the filter does not share the same batch dimension,
  // the batch dimension is only used in indexing the input and output. Thus
  // we cannot use existing linalg named ops like linalg.batch_matmul.
  // i.e. (B x) M x K * K x N = (B x) M x N
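  // Schematically, with iteration dims (b, m, n, k):
  //   lhs (im2col): (b, m, n, k) -> (b, m, k)
  //   rhs (filter): (b, m, n, k) -> (k, n)
  //   result:       (b, m, n, k) -> (b, m, n)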
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};

  auto genericOp = linalg::GenericOp::create(
      rewriter, loc, reshapedOutputType,
      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  auto reshapedResult = tensor::ExpandShapeOp::create(
      rewriter, loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}

FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter,
                linalg::DepthwiseConv2DNhwcHwcOp convOp) {
  auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
  auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
  auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  Location loc = convOp.getLoc();

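  // Helper that materializes a transpose of `operand` according to the
  // permutation `indices`, expressed as a copy-only linalg.generic.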
  auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
    auto operandTensorType = cast<RankedTensorType>(operand.getType());
    auto nloops = indices.size();
    ArrayRef<int64_t> inputShape = operandTensorType.getShape();

    SmallVector<AffineExpr> exprs = llvm::to_vector<4>(
        llvm::map_range(indices, [&](int64_t index) -> AffineExpr {
          return rewriter.getAffineDimExpr(index);
        }));

    SmallVector<int64_t> targetShape = llvm::to_vector<4>(llvm::map_range(
        indices, [&](int64_t index) -> int64_t { return inputShape[index]; }));

    Value outputTensor = tensor::EmptyOp::create(
        rewriter, loc, targetShape, operandTensorType.getElementType());

    SmallVector<utils::IteratorType> loopAttributeTypes(
        nloops, utils::IteratorType::parallel);

    SmallVector<AffineMap> indexingMaps = {
        inversePermutation(
            AffineMap::get(nloops, 0, exprs, rewriter.getContext())),
        AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};

    auto transposedOp = linalg::GenericOp::create(
        rewriter, loc, outputTensor.getType(),
        /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
        loopAttributeTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
        });

    return transposedOp.getResult(0);
  };

  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  // Transpose input and filter so channels are outermost.
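  // E.g. (illustrative shapes): an NHWC input tensor<1x114x114x16xf32> becomes
  // tensor<1x16x114x114xf32> and an HWC filter tensor<3x3x16xf32> becomes
  // tensor<16x3x3xf32>.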
  Value inputT = transposeOperand(input, {0, 3, 1, 2});
  Value filterT = transposeOperand(filter, {2, 0, 1});
  ArrayRef<int64_t> filterTShape =
      cast<RankedTensorType>(filterT.getType()).getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  int n = outputShape[0];
  int oh = outputShape[1];
  int ow = outputShape[2];
  int c = outputShape[3];
  int fh = filterTShape[1];
  int fw = filterTShape[2];

  SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw};
  Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});

  AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
  bindDims(rewriter.getContext(), nDim, cDim, ohDim, owDim, khDim, kwDim);

  AffineExpr shSym = rewriter.getAffineConstantExpr(
      convOp.getStrides().getValues<int64_t>()[0]);
  AffineExpr swSym = rewriter.getAffineConstantExpr(
      convOp.getStrides().getValues<int64_t>()[1]);

  SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim,
                                        owDim * swSym + kwDim};

  auto nloops = colTensorShape.size();

  SmallVector<utils::IteratorType> loopAttributeTypes(
      nloops, utils::IteratorType::parallel);

  SmallVector<AffineMap> indexingMaps = {
      AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),
      AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};

  Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
                                            inputType.getElementType());

  auto img2ColTensor = linalg::GenericOp::create(
      rewriter, loc, colTensor.getType(),
      /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps,
      loopAttributeTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
      });

  SmallVector<ReassociationIndices> img2ColTensorReassocIndices = {
      {0, 1}, {2, 3}, {4, 5}};
  SmallVector<ReassociationIndices> filterReassociationIndices = {{0}, {1, 2}};
  SmallVector<ReassociationIndices> outputReassociationIndices = {{0, 1},
                                                                  {2, 3}};

  auto reshapedImg2ColTensorType = RankedTensorType::get(
      {n * c, oh * ow, fh * fw}, inputType.getElementType());
  auto reshapedFilterTensorType =
      RankedTensorType::get({c, fh * fw}, filterType.getElementType());
  auto reshapedOutputTensorType =
      RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());

  Value reshapedImg2ColTensor = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
      img2ColTensorReassocIndices);
  Value reshapedFilterTensor =
      tensor::CollapseShapeOp::create(rewriter, loc, reshapedFilterTensorType,
                                      filterT, filterReassociationIndices);
  Value reshapedOutputTensor = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedOutputTensorType, transposedOutputTensor,
      outputReassociationIndices);

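  // The depthwise convolution now reduces to a batched matrix-vector product:
  // the (n*c, oh*ow, fh*fw) im2col tensor times the (c, fh*fw) filter,
  // accumulating into the (n*c, oh*ow) output.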
  auto batchMatVecResult = linalg::BatchMatvecOp::create(
      rewriter, loc, TypeRange{reshapedOutputTensor.getType()},
      ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
      ValueRange{reshapedOutputTensor});

  SmallVector<ReassociationIndices> batchMatVecReassociationIndices = {{0, 1},
                                                                       {2, 3}};

  auto batchMatVecResultReshaped = tensor::ExpandShapeOp::create(
      rewriter, loc, transposedOutputTensor.getType(),
      batchMatVecResult.getResult(0), batchMatVecReassociationIndices);

  Value transposedResult =
      transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});

  rewriter.replaceOp(convOp, ArrayRef<Value>{transposedResult});
  return std::make_pair(img2ColTensor.getOperation(),
                        transposedResult.getDefiningOp());
}

FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  auto filterShape = filterType.getShape();
  auto outputShape = outputType.getShape();

  int64_t n = outputShape[0];
  int64_t oc = outputShape[1];
  int64_t oh = outputShape[2];
  int64_t ow = outputShape[3];
  int64_t ic = filterShape[1];
  int64_t fh = filterShape[2];
  int64_t fw = filterShape[3];

  auto loc = convOp.getLoc();
  MLIRContext *context = rewriter.getContext();

  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
  auto reshapedFilterType =
      RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
  Value reshapedFilter = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedFilterType, filter, filterReassocIndices);

  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}};
  auto reshapedOutputType =
      RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
  Value reshapedOutput = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedOutputType, output, outputReassocIndices);

  // Convert the input to a (BKN) tensor.
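  // E.g. (illustrative sizes only): for input tensor<1x4x16x16xf32> and filter
  // tensor<16x4x3x3xf32> with unit strides, the column tensor has type
  // tensor<1x36x196xf32> (ic*fh*fw = 4*3*3, oh*ow = 14*14).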
  SmallVector<int64_t, 4> colTensorShape = {n, ic * fh * fw, oh * ow};
  Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
                                            inputType.getElementType());

  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);

  // Recover the original iteration indices from the problem/input sizes:
  // given an index of the im2col matrix, retrieve the corresponding indices of
  // the output and filter matrices.
  auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(1U),
                                   ArrayRef<int64_t>{fh * fw, fw, 1});
  auto mIndicesExprs =
      delinearize(rewriter.getAffineDimExpr(2U), ArrayRef<int64_t>{ow, 1});
  Im2ColToOperandsExprs i2cToOperExprs;
  i2cToOperExprs.icIndex = kIndicesExprs[0];
  i2cToOperExprs.fhIndex = kIndicesExprs[1];
  i2cToOperExprs.fwIndex = kIndicesExprs[2];
  i2cToOperExprs.ohIndex = mIndicesExprs[0];
  i2cToOperExprs.owIndex = mIndicesExprs[1];
  Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
      i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
      rewriter);
  auto inMap =
      AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.cIndex,
                                             inExprs.hIndex, inExprs.wIndex}},
                                   rewriter.getContext())[0];
  // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
  SmallVector<AffineMap> img2colIndexingMaps = {
      inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};

  auto img2ColTensor = linalg::GenericOp::create(
      rewriter, loc, colTensor.getType(),
      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
      });

  // Because the filter does not share the same batch dimension,
  // the batch dimension is only used in indexing the input and output. Thus
  // we cannot use existing linalg named ops like linalg.batch_matmul.
  // i.e. M x K * (B x) K x N = (B x) M x N
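  // Schematically, with iteration dims (b, m, n, k):
  //   lhs (filter): (b, m, n, k) -> (m, k)
  //   rhs (im2col): (b, m, n, k) -> (b, k, n)
  //   result:       (b, m, n, k) -> (b, m, n)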
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};
  auto genericOp = linalg::GenericOp::create(
      rewriter, loc, reshapedOutputType,
      /*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  auto reshapedResult = tensor::ExpandShapeOp::create(
      rewriter, loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}

FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  MLIRContext *context = rewriter.getContext();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  ArrayRef<int64_t> filterShape = filterType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t oc = outputShape[3];
  int64_t fh = filterShape[1];
  int64_t fw = filterShape[2];
  int64_t ic = filterShape[3];

  Location loc = convOp.getLoc();

  // Reshape filter and output to the RHS and result of a "row-wise" matrix
  // multiplication.
  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
  auto reshapedFilterType =
      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
  Value reshapedFilter = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedFilterType, filter, filterReassocIndices);

  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
  RankedTensorType reshapedOutputType =
      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
  Value reshapedOutput = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedOutputType, output, outputReassocIndices);

  // Shape of the Toeplitz matrix produced by im2col.
  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
  Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
                                            inputType.getElementType());

  // Convert the input to a (BMK) column tensor.
  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

  // Given an index of the im2col matrix, retrieve the corresponding indices of
  // the output and filter matrices.
  auto mIndicesExprs =
      delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
  auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
                                   ArrayRef<int64_t>{fw * ic, ic, 1});
  Im2ColToOperandsExprs i2cToOperExprs;
  i2cToOperExprs.fhIndex = kIndicesExprs[0];
  i2cToOperExprs.fwIndex = kIndicesExprs[1];
  i2cToOperExprs.icIndex = kIndicesExprs[2];
  i2cToOperExprs.ohIndex = mIndicesExprs[0];
  i2cToOperExprs.owIndex = mIndicesExprs[1];

  // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
  Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
      i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
      rewriter);
  auto inMap =
      AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
                                             inExprs.wIndex, inExprs.cIndex}},
                                   rewriter.getContext())[0];
  SmallVector<AffineMap> img2colIndexingMaps = {
      inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};

  auto img2ColTensor = linalg::GenericOp::create(
      rewriter, loc, colTensor.getType(),
      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
      });

  // Because we didn't transpose the filters we don't actually have a batched
  // matrix multiply. Instead, we have an operation consisting of "row-wise"
  // dot products.
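  // Schematically, with iteration dims (b, m, n, k):
  //   lhs (im2col): (b, m, n, k) -> (b, m, k)
  //   rhs (filter): (b, m, n, k) -> (n, k)
  //   result:       (b, m, n, k) -> (b, m, n)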
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};

  auto genericOp = linalg::GenericOp::create(
      rewriter, loc, reshapedOutputType,
      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  auto reshapedResult = tensor::ExpandShapeOp::create(
      rewriter, loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}

namespace {

class ConvertConv2DNhwcHwcf final
    : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(rewriteInIm2Col(rewriter, convOp)))
      return failure();
    return success();
  }
};

class ConvertDepthwiseConv2DNhwcHwc final
    : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(rewriteInIm2Col(rewriter, convOp)))
      return failure();
    return success();
  }
};

class ConvertConv2DNchwFchw final
    : public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(rewriteInIm2Col(rewriter, convOp)))
      return failure();
    return success();
  }
};

class ConvertConv2DNhwcFhwc final
    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(rewriteInIm2Col(rewriter, convOp)))
      return failure();
    return success();
  }
};
} // end anonymous namespace

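// A minimal usage sketch (not part of this file): a pass would typically
// populate these patterns and apply them greedily, e.g.
//   RewritePatternSet patterns(&getContext());
//   linalg::populateConvertConv2DToImg2ColPatterns(patterns);
//   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();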
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
                  ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
}
} // end namespace linalg
} // end namespace mlir