//===- ConvertConv2DToImg2Col.cpp - im2col implementation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include <cassert>
#include <utility>

namespace mlir {
namespace linalg {

static bool hasAllOneValues(DenseIntElementsAttr attr) {
  return llvm::all_of(
      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
}

static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
  if (isa<IntegerType>(x.getType()))
    return arith::AddIOp::create(builder, loc, x, y);
  if (isa<ComplexType>(x.getType()))
    return complex::AddOp::create(builder, loc, x, y);
  return arith::AddFOp::create(builder, loc, x, y);
}

static Value createMul(Location loc, Value x, Value y, Type accType,
                       OpBuilder &builder) {
  // Linalg named ops specify signed extension of the operands to the
  // accumulator type.
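  // (For example, i8 operands accumulated into an i32 accType are
  // sign-extended before the multiply.)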
  Value xConvert =
      convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false);
  Value yConvert =
      convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false);
  if (isa<ComplexType>(accType))
    return complex::MulOp::create(builder, loc, xConvert, yConvert);
  if (isa<IntegerType>(accType))
    return arith::MulIOp::create(builder, loc, xConvert, yConvert);
  return arith::MulFOp::create(builder, loc, xConvert, yConvert);
}

// Generate the affine expression that computes the convolved input index as
// `oIndex * stride + fIndex`, where oIndex is the output iterator and fIndex
// is the filter iterator.
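// For example, with a stride of 2 this yields (s0 * 2 + s1) when `useSymbols`
// is true, and (d0 * 2 + d1) when the operands are bound as dimensions.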
static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride,
                                   bool useSymbols = true) {
  AffineExpr oExpr, fExpr;
  if (useSymbols)
    bindSymbols(b.getContext(), oExpr, fExpr);
  else
    bindDims(b.getContext(), oExpr, fExpr);
  return AffineExpr(stride * oExpr + fExpr);
}

// Stores the affine expressions that map the iteration space of the im2col
// matrix to the corresponding indices of the output and filter matrices.
struct Im2ColToOperandsExprs {
  AffineExpr fhIndex;
  AffineExpr fwIndex;
  AffineExpr icIndex;
  AffineExpr ohIndex;
  AffineExpr owIndex;
};

// Stores the affine expressions that map the iteration space of the im2col
// matrix to the input matrix indices.
struct Im2ColToInputDimsExprs {
  AffineExpr bIndex;
  AffineExpr hIndex;
  AffineExpr wIndex;
  AffineExpr cIndex;
};

/// Construct the affine expressions that map the indices of the im2col matrix
/// to the corresponding input tensor indices for a 2D convolution with the
/// provided strides.
///
/// @param exprs Affine expressions for output and filter indices.
/// @param strides [height, width] stride values for the convolution.
/// @param rewriter Pattern rewriter.
/// @return Affine expressions mapping im2col matrix indices to input
/// offsets.
static Im2ColToInputDimsExprs
getIm2ColInputExpressions(Im2ColToOperandsExprs exprs,
                          ArrayRef<int64_t> strides, RewriterBase &rewriter) {
  // Maps the iteration space of the im2col matrix to (output_y, filter_y).
  auto hIndicesMap = AffineMap::inferFromExprList(
      {ArrayRef{exprs.ohIndex, exprs.fhIndex}}, rewriter.getContext())[0];
  // Maps the iteration space of the im2col matrix to (output_x, filter_x).
  auto wIndicesMap = AffineMap::inferFromExprList(
      {ArrayRef{exprs.owIndex, exprs.fwIndex}}, rewriter.getContext())[0];
  // Compute the input indexing map, to map the indices of the im2col matrix to
  // the original input offsets. Each element of the im2col matrix corresponds
  // to a pair of (out_element, filter_element). First, we build the expressions
  // to compute the input (ix, iy) indices from the (out_x/y, filter_x/y) pairs;
  // then we compose them with the maps that map the im2col matrix elements to
  // the (out_element, filter_element) pairs.
  auto bIndexExpr = rewriter.getAffineDimExpr(0U);
  auto hIndexExpr = getConvolvedExpr(rewriter, strides[0],
                                     /*useSymbols=*/false);
  hIndexExpr = hIndexExpr.compose(hIndicesMap);
  auto wIndexExpr = getConvolvedExpr(rewriter, strides[1],
                                     /*useSymbols=*/false);
  wIndexExpr = wIndexExpr.compose(wIndicesMap);
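  // After composition, hIndexExpr equals exprs.ohIndex * strides[0] +
  // exprs.fhIndex expressed in the im2col iteration space (and likewise for
  // wIndexExpr with strides[1]).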
  auto cIndexExpr = exprs.icIndex;
  return {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr};
}

FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!convOp.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(
        convOp, "expected op to have pure tensor semantics");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  MLIRContext *context = rewriter.getContext();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  ArrayRef<int64_t> filterShape = filterType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t oc = outputShape[3];
  int64_t fh = filterShape[0];
  int64_t fw = filterShape[1];
  int64_t ic = filterShape[2];

  Location loc = convOp.getLoc();

  assert(isa<RankedTensorType>(filterType) &&
         "expected filter type to be a ranked tensor");
  auto tensorFilterType = cast<RankedTensorType>(filterType);

  // Reshape the filter and output to the RHS and result of a (B)MNK matmul.
  SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
  auto reshapedFilterType =
      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType(),
                            tensorFilterType.getEncoding());
  Value reshapedFilter = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedFilterType, filter, filterReassocIndices);

  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
  RankedTensorType reshapedOutputType =
      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
  Value reshapedOutput = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedOutputType, output, outputReassocIndices);

  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
  Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
                                            inputType.getElementType());

  // Convert the input to a (BMK) column tensor.
  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

  // Given an index of the im2col matrix, retrieve the corresponding indices of
  // the output and filter matrices.
  auto mIndicesExprs =
      delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
  auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
                                   ArrayRef<int64_t>{fw * ic, ic, 1});
  Im2ColToOperandsExprs i2cToOperExprs;
  i2cToOperExprs.fhIndex = kIndicesExprs[0];
  i2cToOperExprs.fwIndex = kIndicesExprs[1];
  i2cToOperExprs.icIndex = kIndicesExprs[2];
  i2cToOperExprs.ohIndex = mIndicesExprs[0];
  i2cToOperExprs.owIndex = mIndicesExprs[1];
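  // For example, with fh = fw = 3 and ic = 4 the K basis is {12, 4, 1}, so the
  // im2col column index k = 22 delinearizes to (fh, fw, ic) = (1, 2, 2), while
  // the row index m delinearizes to (oh, ow) = (m floordiv ow, m mod ow).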

  // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
  Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
      i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
      rewriter);
  auto inMap =
      AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
                                             inExprs.wIndex, inExprs.cIndex}},
                                   rewriter.getContext())[0];

  SmallVector<AffineMap> img2colIndexingMaps = {
      inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};

  auto img2ColTensor = linalg::GenericOp::create(
      rewriter, loc, colTensor.getType(),
      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
      });
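  // The region only forwards the loaded element; the gather itself is encoded
  // entirely in `inMap`, which sends each (b, m, k) im2col index to the
  // corresponding input offset.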

  // Because the filter does not share the same batch dimension, the batch
  // dimension is only used to index the input and output. Thus we cannot use
  // an existing linalg named op such as linalg.batch_matmul;
  // i.e. (B x) M x K * K x N = (B x) M x N.
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};

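  // With these maps, each (b, m, n, k) iteration computes
  // result[b, m, n] += img2col[b, m, k] * filter[k, n], i.e. a matmul in which
  // the filter is reused across the batch dimension.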
  auto genericOp = linalg::GenericOp::create(
      rewriter, loc, reshapedOutputType,
      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  auto reshapedResult = tensor::ExpandShapeOp::create(
      rewriter, loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}

FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter,
                linalg::DepthwiseConv2DNhwcHwcOp convOp) {
  auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
  auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
  auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());

  if (!convOp.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(
        convOp, "expected op to have pure tensor semantics");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  Location loc = convOp.getLoc();

  auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
    auto operandTensorType = cast<RankedTensorType>(operand.getType());
    auto nloops = indices.size();
    ArrayRef<int64_t> inputShape = operandTensorType.getShape();

    SmallVector<AffineExpr> exprs = llvm::to_vector<4>(
        llvm::map_range(indices, [&](int64_t index) -> AffineExpr {
          return rewriter.getAffineDimExpr(index);
        }));

    SmallVector<int64_t> targetShape = llvm::to_vector<4>(llvm::map_range(
        indices, [&](int64_t index) -> int64_t { return inputShape[index]; }));

    Value outputTensor = tensor::EmptyOp::create(
        rewriter, loc, targetShape, operandTensorType.getElementType());

    SmallVector<utils::IteratorType> loopAttributeTypes(
        nloops, utils::IteratorType::parallel);

    SmallVector<AffineMap> indexingMaps = {
        inversePermutation(
            AffineMap::get(nloops, 0, exprs, rewriter.getContext())),
        AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};

    auto transposedOp = linalg::GenericOp::create(
        rewriter, loc, outputTensor.getType(),
        /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
        loopAttributeTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
        });

    return transposedOp.getResult(0);
  };

  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  // Transpose the input and filter so that the channel dimension is outermost.
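  // For instance, transposeOperand(input, {0, 3, 1, 2}) permutes the NHWC
  // input to NCHW, and transposeOperand(filter, {2, 0, 1}) permutes the HWC
  // filter to CHW.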
  Value inputT = transposeOperand(input, {0, 3, 1, 2});
  Value filterT = transposeOperand(filter, {2, 0, 1});
  ArrayRef<int64_t> filterTShape =
      cast<RankedTensorType>(filterT.getType()).getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  int n = outputShape[0];
  int oh = outputShape[1];
  int ow = outputShape[2];
  int c = outputShape[3];
  int fh = filterTShape[1];
  int fw = filterTShape[2];

  SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw};
  Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});

  AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
  bindDims(rewriter.getContext(), nDim, cDim, ohDim, owDim, khDim, kwDim);

  AffineExpr shSym = rewriter.getAffineConstantExpr(
      convOp.getStrides().getValues<int64_t>()[0]);
  AffineExpr swSym = rewriter.getAffineConstantExpr(
      convOp.getStrides().getValues<int64_t>()[1]);

  SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim,
                                        owDim * swSym + kwDim};
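  // That is, im2col element (n, c, oh, ow, kh, kw) reads
  // input[n, c, oh * sh + kh, ow * sw + kw] from the transposed NCHW input.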

  auto nloops = colTensorShape.size();

  SmallVector<utils::IteratorType> loopAttributeTypes(
      nloops, utils::IteratorType::parallel);

  SmallVector<AffineMap> indexingMaps = {
      AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),
      AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};

  Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
                                            inputType.getElementType());

  auto img2ColTensor = linalg::GenericOp::create(
      rewriter, loc, colTensor.getType(),
      /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps,
      loopAttributeTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
      });

  SmallVector<ReassociationIndices> img2ColTensorReassocIndices = {
      {0, 1}, {2, 3}, {4, 5}};
  SmallVector<ReassociationIndices> filterReassociationIndices = {{0}, {1, 2}};
  SmallVector<ReassociationIndices> outputReassociationIndices = {{0, 1},
                                                                  {2, 3}};

  auto reshapedImg2ColTensorType = RankedTensorType::get(
      {n * c, oh * ow, fh * fw}, inputType.getElementType());
  auto reshapedFilterTensorType =
      RankedTensorType::get({c, fh * fw}, filterType.getElementType());
  auto reshapedOutputTensorType =
      RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());

  Value reshapedImg2ColTensor = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
      img2ColTensorReassocIndices);
  Value reshapedFilterTensor =
      tensor::CollapseShapeOp::create(rewriter, loc, reshapedFilterTensorType,
                                      filterT, filterReassociationIndices);
  Value reshapedOutputTensor = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedOutputTensorType, transposedOutputTensor,
      outputReassociationIndices);

  auto batchMatVecResult = linalg::BatchMatvecOp::create(
      rewriter, loc, TypeRange{reshapedOutputTensor.getType()},
      ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
      ValueRange{reshapedOutputTensor});

  SmallVector<ReassociationIndices> batchMatVecReassociationIndices = {{0, 1},
                                                                       {2, 3}};

  auto batchMatVecResultReshaped = tensor::ExpandShapeOp::create(
      rewriter, loc, transposedOutputTensor.getType(),
      batchMatVecResult.getResult(0), batchMatVecReassociationIndices);

  Value transposedResult =
      transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});

  rewriter.replaceOp(convOp, ArrayRef<Value>{transposedResult});
  return std::make_pair(img2ColTensor.getOperation(),
                        transposedResult.getDefiningOp());
}

FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!convOp.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(
        convOp, "expected op to have pure tensor semantics");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  auto filterShape = filterType.getShape();
  auto outputShape = outputType.getShape();

  int64_t n = outputShape[0];
  int64_t oc = outputShape[1];
  int64_t oh = outputShape[2];
  int64_t ow = outputShape[3];
  int64_t ic = filterShape[1];
  int64_t fh = filterShape[2];
  int64_t fw = filterShape[3];

  auto loc = convOp.getLoc();
  MLIRContext *context = rewriter.getContext();

  assert(isa<RankedTensorType>(filterType) &&
         "expected filter type to be a ranked tensor");
  auto tensorFilterType = cast<RankedTensorType>(filterType);

  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
  auto reshapedFilterType =
      RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType(),
                            tensorFilterType.getEncoding());
  Value reshapedFilter = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedFilterType, filter, filterReassocIndices);

  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}};
  auto reshapedOutputType =
      RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
  Value reshapedOutput = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedOutputType, output, outputReassocIndices);

  // Convert the input to a (BKN) tensor.
  SmallVector<int64_t, 4> colTensorShape = {n, ic * fh * fw, oh * ow};
  Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
                                            inputType.getElementType());

  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);

  // Recover the original iteration indices from the problem/input sizes: given
  // an index of the im2col matrix, retrieve the corresponding indices of the
  // output and filter matrices.
  auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(1U),
                                   ArrayRef<int64_t>{fh * fw, fw, 1});
  auto mIndicesExprs =
      delinearize(rewriter.getAffineDimExpr(2U), ArrayRef<int64_t>{ow, 1});
  Im2ColToOperandsExprs i2cToOperExprs;
  i2cToOperExprs.icIndex = kIndicesExprs[0];
  i2cToOperExprs.fhIndex = kIndicesExprs[1];
  i2cToOperExprs.fwIndex = kIndicesExprs[2];
  i2cToOperExprs.ohIndex = mIndicesExprs[0];
  i2cToOperExprs.owIndex = mIndicesExprs[1];
  Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
      i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
      rewriter);
  auto inMap =
      AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.cIndex,
                                             inExprs.hIndex, inExprs.wIndex}},
                                   rewriter.getContext())[0];
  // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
  SmallVector<AffineMap> img2colIndexingMaps = {
      inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};

  auto img2ColTensor = linalg::GenericOp::create(
      rewriter, loc, colTensor.getType(),
      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
      });

  // Because the filter does not share the same batch dimension, the batch
  // dimension is only used to index the input and output. Thus we cannot use
  // an existing linalg named op such as linalg.batch_matmul;
  // i.e. M x K * (B x) K x N = (B x) M x N.
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};
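  // Here the reshaped filter (oc x ic*fh*fw) is the LHS and the im2col tensor
  // (n x ic*fh*fw x oh*ow) is the RHS: each (b, m, n, k) iteration computes
  // result[b, m, n] += filter[m, k] * img2col[b, k, n].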
  auto genericOp = linalg::GenericOp::create(
      rewriter, loc, reshapedOutputType,
      /*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  auto reshapedResult = tensor::ExpandShapeOp::create(
      rewriter, loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}

FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!convOp.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(
        convOp, "expected op to have pure tensor semantics");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  MLIRContext *context = rewriter.getContext();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  ArrayRef<int64_t> filterShape = filterType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t oc = outputShape[3];
  int64_t fh = filterShape[1];
  int64_t fw = filterShape[2];
  int64_t ic = filterShape[3];

  Location loc = convOp.getLoc();

  assert(isa<RankedTensorType>(filterType) &&
         "expected filter type to be a ranked tensor");
  auto tensorFilterType = cast<RankedTensorType>(filterType);

  // Reshape the filter and output to the RHS and result of a "row-wise" matrix
  // multiplication.
  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
  auto reshapedFilterType =
      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType(),
                            tensorFilterType.getEncoding());
  Value reshapedFilter = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedFilterType, filter, filterReassocIndices);

  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
  RankedTensorType reshapedOutputType =
      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
  Value reshapedOutput = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedOutputType, output, outputReassocIndices);

  // Shape of the Toeplitz matrix produced by im2col.
  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
  Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
                                            inputType.getElementType());

  // Convert the input to a (BMK) column tensor.
  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

  // Given an index of the im2col matrix, retrieve the corresponding indices of
  // the output and filter matrices.
  auto mIndicesExprs =
      delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
  auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
                                   ArrayRef<int64_t>{fw * ic, ic, 1});
  Im2ColToOperandsExprs i2cToOperExprs;
  i2cToOperExprs.fhIndex = kIndicesExprs[0];
  i2cToOperExprs.fwIndex = kIndicesExprs[1];
  i2cToOperExprs.icIndex = kIndicesExprs[2];
  i2cToOperExprs.ohIndex = mIndicesExprs[0];
  i2cToOperExprs.owIndex = mIndicesExprs[1];

  // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
  Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions(
      i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
      rewriter);
  auto inMap =
      AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex,
                                             inExprs.wIndex, inExprs.cIndex}},
                                   rewriter.getContext())[0];
  SmallVector<AffineMap> img2colIndexingMaps = {
      inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};

  auto img2ColTensor = linalg::GenericOp::create(
      rewriter, loc, colTensor.getType(),
      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
      });

  // Because we did not transpose the filter, this is not a batched matrix
  // multiply. Instead, the op computes "row-wise" dot products.
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};

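  // With these maps, each (b, m, n, k) iteration computes
  // result[b, m, n] += img2col[b, m, k] * filter[n, k]; the filter row is
  // selected by n and reduced over k, i.e. one dot product per output channel.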
  auto genericOp = linalg::GenericOp::create(
      rewriter, loc, reshapedOutputType,
      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  auto reshapedResult = tensor::ExpandShapeOp::create(
      rewriter, loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}

namespace {

class ConvertConv2DNhwcHwcf final
    : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(rewriteInIm2Col(rewriter, convOp)))
      return failure();
    return success();
  }
};

class ConvertDepthwiseConv2DNhwcHwc final
    : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(rewriteInIm2Col(rewriter, convOp)))
      return failure();
    return success();
  }
};

class ConvertConv2DNchwFchw final
    : public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(rewriteInIm2Col(rewriter, convOp)))
      return failure();
    return success();
  }
};

class ConvertConv2DNhwcFhwc final
    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(rewriteInIm2Col(rewriter, convOp)))
      return failure();
    return success();
  }
};
} // end anonymous namespace

void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
                  ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
}
} // end namespace linalg
} // end namespace mlir