//===- ConvertConv2DToImg2Col.cpp - im2col implementation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include <utility>

namespace mlir {
namespace linalg {
static bool hasAllOneValues(DenseIntElementsAttr attr) {
  return llvm::all_of(
      attr, [](APInt element) { return element.getSExtValue() == 1; });
}

static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
  if (isa<IntegerType>(x.getType()))
    return builder.create<arith::AddIOp>(loc, x, y);
  if (isa<ComplexType>(x.getType()))
    return builder.create<complex::AddOp>(loc, x, y);
  return builder.create<arith::AddFOp>(loc, x, y);
}

static Value createMul(Location loc, Value x, Value y, Type accType,
                       OpBuilder &builder) {
  // Linalg named ops use signed extension when converting operands to the
  // accumulator type.
  Value xConvert =
      convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false);
  Value yConvert =
      convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false);
  if (isa<ComplexType>(accType))
    return builder.create<complex::MulOp>(loc, xConvert, yConvert);
  if (isa<IntegerType>(accType))
    return builder.create<arith::MulIOp>(loc, xConvert, yConvert);
  return builder.create<arith::MulFOp>(loc, xConvert, yConvert);
}

// Delinearizes the given composite `index` by the basis specified in
// `factors`.
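// For example (illustrative values, not taken from a specific test):
// delinearizing index 7 by factors {3, 4} yields the multi-index (1, 3),
// since 7 = 1 * 4 + 3.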
static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
                                      ArrayRef<int64_t> factors) {
  assert(!factors.empty() && "empty factor list");
  SmallVector<Value> basis;
  for (int64_t f : factors)
    basis.push_back(b.create<arith::ConstantOp>(loc, b.getIndexAttr(f)));
  FailureOr<SmallVector<Value>> multiIndex =
      affine::delinearizeIndex(b, loc, index, basis);
  assert(!failed(multiIndex) && "Failed to delinearize img2col index");
  return *multiIndex;
}

// Given indices corresponding to iterators in the output (oIndex) and filter
// (fIndex) for a convolution, compute the convolved index for the input as
// `oIndex * stride + fIndex`.
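// For example (illustrative values): with stride 2, output index 3 and filter
// index 1 map to input index 2 * 3 + 1 = 7.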
static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
                               Value fIndex, int64_t stride) {
  AffineExpr oExpr, fExpr;
  bindSymbols(b.getContext(), oExpr, fExpr);
  AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
  return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex});
}

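// Convert linalg.conv_2d_nhwc_hwcf into an im2col packing linalg.generic plus
// a linalg.generic contraction. Shape sketch (illustrative only):
//   filter: [fh, fw, ic, oc]  --collapse-->  [fh * fw * ic, oc]
//   output: [n, oh, ow, oc]   --collapse-->  [n, oh * ow, oc]
//   input:  [n, ih, iw, ic]   --im2col--->   [n, oh * ow, fh * fw * ic]
// The contraction uses maps (b, m, k), (k, n) -> (b, m, n), and the result is
// expanded back to [n, oh, ow, oc].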
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  MLIRContext *context = rewriter.getContext();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  ArrayRef<int64_t> filterShape = filterType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t oc = outputShape[3];
  int64_t fh = filterShape[0];
  int64_t fw = filterShape[1];
  int64_t ic = filterShape[2];

  Location loc = convOp.getLoc();

  // Reshape the filter and output to the RHS and result of a (B)MNK matmul.
  SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
  auto reshapedFilterType =
      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedFilterType, filter, filterReassocIndices);

  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
  RankedTensorType reshapedOutputType =
      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedOutputType, output, outputReassocIndices);

  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
  Value colTensor = rewriter.create<tensor::EmptyOp>(
      loc, colTensorShape, inputType.getElementType());

  // Convert the input to a (BMK) column tensor.
  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

  SmallVector<AffineMap> img2colIndexingMaps = {
      AffineMap::getMultiDimIdentityMap(nloops, context)};

  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
      loc, colTensor.getType(),
      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Get the iterators named based on the matmul (batch, m, k).
        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
        Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);

        // Recover the original iteration indices from the problem/input sizes.
        SmallVector<Value> mIndices = unrollIndex(
            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
        auto ohIndex = mIndices[0];
        auto owIndex = mIndices[1];

        SmallVector<Value> kIndices = unrollIndex(
            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
        auto fhIndex = kIndices[0];
        auto fwIndex = kIndices[1];
        auto icIndex = kIndices[2];

        // Extract the input element corresponding to the expanded indices.
        Value hIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
                              convOp.getStrides().getValues<int64_t>()[0]);
        Value wIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
                              convOp.getStrides().getValues<int64_t>()[1]);

        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
            loc, input, extractionIndices);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
      });

  // Because the filter does not share the same batch dimension,
  // the batch dimension is only used in indexing the input and output. Thus
  // we cannot use existing linalg named ops like linalg.batch_matmul.
  // i.e. (B x) M x K * K x N = (B x) M x N
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};

  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, reshapedOutputType,
      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}

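// Convert linalg.depthwise_conv_2d_nhwc_hwc into an im2col packing plus a
// linalg.batch_matvec. Shape sketch (illustrative only):
//   input:  [n, ih, iw, c] --transpose--> [n, c, ih, iw]
//   filter: [fh, fw, c]    --transpose--> [c, fh, fw] --collapse--> [c, fh*fw]
//   im2col: [n, c, oh, ow, fh, fw] --collapse--> [n*c, oh*ow, fh*fw]
// The batch_matvec result is expanded and transposed back to [n, oh, ow, c].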
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter,
                linalg::DepthwiseConv2DNhwcHwcOp convOp) {
  auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
  auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
  auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  Location loc = convOp.getLoc();

  auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
    auto operandTensorType = cast<RankedTensorType>(operand.getType());
    auto nloops = indices.size();
    ArrayRef<int64_t> inputShape = operandTensorType.getShape();

    SmallVector<AffineExpr> exprs = llvm::to_vector<4>(
        llvm::map_range(indices, [&](int64_t index) -> AffineExpr {
          return rewriter.getAffineDimExpr(index);
        }));

    SmallVector<int64_t> targetShape = llvm::to_vector<4>(llvm::map_range(
        indices, [&](int64_t index) -> int64_t { return inputShape[index]; }));

    Value outputTensor = rewriter.create<tensor::EmptyOp>(
        loc, targetShape, operandTensorType.getElementType());

    SmallVector<utils::IteratorType> loopAttributeTypes(
        nloops, utils::IteratorType::parallel);

    SmallVector<AffineMap> indexingMaps = {
        inversePermutation(
            AffineMap::get(nloops, 0, exprs, rewriter.getContext())),
        AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};

    auto transposedOp = rewriter.create<linalg::GenericOp>(
        loc, outputTensor.getType(),
        /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
        loopAttributeTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
        });

    return transposedOp.getResult(0);
  };

  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  // Transpose the input and filter so channels are outermost.
  Value inputT = transposeOperand(input, {0, 3, 1, 2});
  Value filterT = transposeOperand(filter, {2, 0, 1});
  ArrayRef<int64_t> filterTShape =
      cast<RankedTensorType>(filterT.getType()).getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  int n = outputShape[0];
  int oh = outputShape[1];
  int ow = outputShape[2];
  int c = outputShape[3];
  int fh = filterTShape[1];
  int fw = filterTShape[2];

  SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw};
  Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});

  AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
  bindDims(rewriter.getContext(), nDim, cDim, ohDim, owDim, khDim, kwDim);

  AffineExpr shSym = rewriter.getAffineConstantExpr(
      convOp.getStrides().getValues<int64_t>()[0]);
  AffineExpr swSym = rewriter.getAffineConstantExpr(
      convOp.getStrides().getValues<int64_t>()[1]);

  SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim,
                                        owDim * swSym + kwDim};

  auto nloops = colTensorShape.size();

  SmallVector<utils::IteratorType> loopAttributeTypes(
      nloops, utils::IteratorType::parallel);

  SmallVector<AffineMap> indexingMaps = {
      AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),
      AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};

  Value colTensor = rewriter.create<tensor::EmptyOp>(
      loc, colTensorShape, inputType.getElementType());

  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
      loc, colTensor.getType(),
      /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps,
      loopAttributeTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  SmallVector<ReassociationIndices> img2ColTensorReassocIndices = {
      {0, 1}, {2, 3}, {4, 5}};
  SmallVector<ReassociationIndices> filterReassociationIndice = {{0}, {1, 2}};
  SmallVector<ReassociationIndices> outputReassociationIndice = {{0, 1},
                                                                 {2, 3}};

  auto reshapedImg2ColTensorType = RankedTensorType::get(
      {n * c, oh * ow, fh * fw}, inputType.getElementType());
  auto reshapedFilterTensorType =
      RankedTensorType::get({c, fh * fw}, filterType.getElementType());
  auto reshapedOutputTensorType =
      RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());

  Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
      img2ColTensorReassocIndices);
  Value reshapedFilterTensor = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedFilterTensorType, filterT, filterReassociationIndice);
  Value reshapedoutputTensor = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedOutputTensorType, transposedOutputTensor,
      outputReassociationIndice);

  auto batchMatVecResult = rewriter.create<linalg::BatchMatvecOp>(
      loc, TypeRange{reshapedoutputTensor.getType()},
      ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
      ValueRange{reshapedoutputTensor});

  SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
                                                                      {2, 3}};

  Value batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
      loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),
      batchMatVecReassociationIndice);

  Value transposedResult =
      transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});

  rewriter.replaceOp(convOp, ArrayRef<Value>{transposedResult});
  return std::make_pair(img2ColTensor.getOperation(),
                        transposedResult.getDefiningOp());
}

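// Convert linalg.conv_2d_nchw_fchw into an im2col packing linalg.generic plus
// a linalg.generic contraction, i.e. M x K * (B x) K x N = (B x) M x N.
// Shape sketch (illustrative only):
//   filter: [oc, ic, fh, fw]  --collapse-->  [oc, ic * fh * fw]
//   output: [n, oc, oh, ow]   --collapse-->  [n, oc, oh * ow]
//   input:  [n, ic, ih, iw]   --im2col--->   [n, ic * fh * fw, oh * ow]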
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  auto filterShape = filterType.getShape();
  auto outputShape = outputType.getShape();

  int64_t n = outputShape[0];
  int64_t oc = outputShape[1];
  int64_t oh = outputShape[2];
  int64_t ow = outputShape[3];
  int64_t ic = filterShape[1];
  int64_t fh = filterShape[2];
  int64_t fw = filterShape[3];

  auto loc = convOp.getLoc();
  MLIRContext *context = rewriter.getContext();

  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
  auto reshapedFilterType =
      RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedFilterType, filter, filterReassocIndices);

  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}};
  auto reshapedOutputType =
      RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedOutputType, output, outputReassocIndices);

  // Convert the input to a (BKN) tensor.
  SmallVector<int64_t, 4> colTensorShape = {n, ic * fh * fw, oh * ow};
  Value colTensor = rewriter.create<tensor::EmptyOp>(
      loc, colTensorShape, inputType.getElementType());

  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);

  SmallVector<AffineMap, 4> img2colIndexingMaps = {
      AffineMap::getMultiDimIdentityMap(nloops, context)};

  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
      loc, colTensor.getType(),
      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Get the iterators named based on the matmul (batch, m, k).
        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
        Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);

        // Recover the original iteration indices from the problem/input sizes.
        SmallVector<Value> kIndices = unrollIndex(
            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
        auto icIndex = kIndices[0];
        auto fhIndex = kIndices[1];
        auto fwIndex = kIndices[2];

        SmallVector<Value> nIndices = unrollIndex(
            nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
        auto ohIndex = nIndices[0];
        auto owIndex = nIndices[1];

        // Extract the input element corresponding to the expanded indices.
        Value hIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
                              convOp.getStrides().getValues<int64_t>()[0]);
        Value wIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
                              convOp.getStrides().getValues<int64_t>()[1]);

        // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
        SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
            loc, input, extractionIndices);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
      });

  // Because the filter does not share the same batch dimension,
  // the batch dimension is only used in indexing the input and output. Thus
  // we cannot use existing linalg named ops like linalg.batch_matmul.
  // i.e. M x K * (B x) K x N = (B x) M x N
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};
  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, reshapedOutputType,
      /*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}

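// Convert linalg.conv_2d_nhwc_fhwc into an im2col packing linalg.generic plus
// a linalg.generic of "row-wise" dot products (the filter is not transposed,
// so the contraction maps are (b, m, k), (n, k) -> (b, m, n)).
// Shape sketch (illustrative only):
//   filter: [oc, fh, fw, ic]  --collapse-->  [oc, fh * fw * ic]
//   output: [n, oh, ow, oc]   --collapse-->  [n, oh * ow, oc]
//   input:  [n, ih, iw, ic]   --im2col--->   [n, oh * ow, fh * fw * ic]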
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  MLIRContext *context = rewriter.getContext();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  ArrayRef<int64_t> filterShape = filterType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t oc = outputShape[3];
  int64_t fh = filterShape[1];
  int64_t fw = filterShape[2];
  int64_t ic = filterShape[3];

  Location loc = convOp.getLoc();

  // Reshape output and filter to the LHS and result of a "row-wise" matrix
  // multiplication.
  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
  auto reshapedFilterType =
      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedFilterType, filter, filterReassocIndices);

  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
  RankedTensorType reshapedOutputType =
      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedOutputType, output, outputReassocIndices);

  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
  Value colTensor = rewriter.create<tensor::EmptyOp>(
      loc, colTensorShape, inputType.getElementType());

  // Convert the input to a (BMK) column tensor.
  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

  SmallVector<AffineMap> img2colIndexingMaps = {
      AffineMap::getMultiDimIdentityMap(nloops, context)};

  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
      loc, colTensor.getType(),
      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Get the iterators named based on the matmul (batch, m, k).
        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
        Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);

        // Recover the original iteration indices from the problem/input sizes.
        SmallVector<Value> mIndices = unrollIndex(
            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
        auto ohIndex = mIndices[0];
        auto owIndex = mIndices[1];

        SmallVector<Value> kIndices = unrollIndex(
            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
        auto fhIndex = kIndices[0];
        auto fwIndex = kIndices[1];
        auto icIndex = kIndices[2];

        // Extract the input element corresponding to the expanded indices.
        Value hIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
                              convOp.getStrides().getValues<int64_t>()[0]);
        Value wIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
                              convOp.getStrides().getValues<int64_t>()[1]);

        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
            loc, input, extractionIndices);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
      });

  // Because we didn't transpose the filters we don't actually have a batched
  // matrix multiply. Instead, we have an operation consisting of "row-wise"
  // dot products.
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};

  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, reshapedOutputType,
      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}

namespace {

class ConvertConv2DNhwcHwcf final
    : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(rewriteInIm2Col(rewriter, convOp)))
      return failure();
    return success();
  }
};

class ConvertDepthwiseConv2DNhwcHwc final
    : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(rewriteInIm2Col(rewriter, convOp)))
      return failure();
    return success();
  }
};

class ConvertConv2DNchwFchw final
    : public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(rewriteInIm2Col(rewriter, convOp)))
      return failure();
    return success();
  }
};

class ConvertConv2DNhwcFhwc final
    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(rewriteInIm2Col(rewriter, convOp)))
      return failure();
    return success();
  }
};
} // end anonymous namespace

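// A minimal usage sketch for the entry point below (hypothetical driver code,
// not part of this file):
//   RewritePatternSet patterns(op->getContext());
//   linalg::populateConvertConv2DToImg2ColPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));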
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
                  ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
}
} // end namespace linalg
} // end namespace mlir