MLIR  20.0.0git
ConvertConv2DToImg2Col.cpp
Go to the documentation of this file.
1 //===- ConvertConv2DToImg2Col.cpp - im2col implementation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
17 #include "mlir/IR/AffineExpr.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinTypes.h"
23 #include <utility>
24 
25 namespace mlir {
26 namespace linalg {
28  return llvm::all_of(
29  attr, [](const APInt &element) { return element.getSExtValue() == 1; });
30 }
31 
32 static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
33  if (isa<IntegerType>(x.getType()))
34  return builder.create<arith::AddIOp>(loc, x, y);
35  if (isa<ComplexType>(x.getType()))
36  return builder.create<complex::AddOp>(loc, x, y);
37  return builder.create<arith::AddFOp>(loc, x, y);
38 }
39 
40 static Value createMul(Location loc, Value x, Value y, Type accType,
41  OpBuilder &builder) {
42  // Linalg named ops specify signed extend for named ops.
43  Value xConvert =
44  convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false);
45  Value yConvert =
46  convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false);
47  if (isa<ComplexType>(accType))
48  return builder.create<complex::MulOp>(loc, xConvert, yConvert);
49  if (isa<IntegerType>(accType))
50  return builder.create<arith::MulIOp>(loc, xConvert, yConvert);
51  return builder.create<arith::MulFOp>(loc, xConvert, yConvert);
52 }
53 
54 // Delinearizes the given composite `index` by the basis specified in `factors`.
56  ArrayRef<int64_t> factors) {
57  assert(!factors.empty() && "empty factor list");
58  SmallVector<Value> basis;
59  for (int64_t f : factors)
60  basis.push_back(b.create<arith::ConstantOp>(loc, b.getIndexAttr(f)));
61  FailureOr<SmallVector<Value>> multiIndex =
62  affine::delinearizeIndex(b, loc, index, basis);
63  assert(!failed(multiIndex) && "Failed to linearize img2col index");
64  return *multiIndex;
65 }
66 
67 // Given indices corresponding to iterators in the output (oIndex) and filter
68 // (fIndex) for a convolution, compute the convolved index for the
69 // input as `oIndex * stride + fIndex`.
71  Value fIndex, int64_t stride) {
72  AffineExpr oExpr, fExpr;
73  bindSymbols(b.getContext(), oExpr, fExpr);
74  AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
75  return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex});
76 }
77 
78 FailureOr<std::pair<Operation *, Operation *>>
79 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
80  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
81  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
82  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
83 
84  if (!filterType.hasStaticShape())
85  return rewriter.notifyMatchFailure(
86  convOp, "expected a static shape for the filter");
87 
88  if (!inputType.hasStaticShape())
89  return rewriter.notifyMatchFailure(convOp,
90  "expected a static shape for the input");
91 
92  // TODO: Support dilation.
93  if (!hasAllOneValues(convOp.getDilations()))
94  return rewriter.notifyMatchFailure(convOp,
95  "expected all ones for dilations");
96 
97  MLIRContext *context = rewriter.getContext();
98  Value input = convOp.getInputs()[0];
99  Value filter = convOp.getInputs()[1];
100  Value output = convOp.getOutputs()[0];
101 
102  ArrayRef<int64_t> filterShape = filterType.getShape();
103  ArrayRef<int64_t> outputShape = outputType.getShape();
104 
105  int64_t n = outputShape[0];
106  int64_t oh = outputShape[1];
107  int64_t ow = outputShape[2];
108  int64_t oc = outputShape[3];
109  int64_t fh = filterShape[0];
110  int64_t fw = filterShape[1];
111  int64_t ic = filterShape[2];
112 
113  Location loc = convOp.getLoc();
114 
115  // Reshape output and filter to the LHS and result of a (B)MNK matmul.
116  SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
117  auto reshapedFilterType =
118  RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
119  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
120  loc, reshapedFilterType, filter, filterReassocIndices);
121 
122  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
123  RankedTensorType reshapedOutputType =
124  RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
125  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
126  loc, reshapedOutputType, output, outputReassocIndices);
127 
128  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
129  Value colTensor = rewriter.create<tensor::EmptyOp>(
130  loc, colTensorShape, inputType.getElementType());
131 
132  // Convert the input to a (BMK) column tensor.
133  auto nloops = colTensorShape.size();
134 
135  auto parallel = utils::IteratorType::parallel;
136  auto reduction = utils::IteratorType::reduction;
137  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
138 
139  SmallVector<AffineMap> img2colIndexingMaps = {
140  AffineMap::getMultiDimIdentityMap(nloops, context)};
141 
142  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
143  loc, colTensor.getType(),
144  /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
145  img2colIterators,
146  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
147  // Get the iterators named based on the matmul (batch, m, k).
148  Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
149  Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
150  Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
151 
152  // Recover the original iteration indices from the problem/input sizes.
153  SmallVector<Value> mIndices = unrollIndex(
154  nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
155  auto ohIndex = mIndices[0];
156  auto owIndex = mIndices[1];
157 
158  SmallVector<Value> kIndices = unrollIndex(
159  nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
160  auto fhIndex = kIndices[0];
161  auto fwIndex = kIndices[1];
162  auto icIndex = kIndices[2];
163 
164  // Extract the input element corresponding to the expanded indices.
165  Value hIndex =
166  getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
167  convOp.getStrides().getValues<int64_t>()[0]);
168  Value wIndex =
169  getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
170  convOp.getStrides().getValues<int64_t>()[1]);
171 
172  // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
173  SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
174  Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
175  loc, input, extractionIndices);
176  nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
177  });
178 
179  // Because the filter does not share the same batch dimension,
180  // the batch dimension is only used in indexing the input and output. Thus
181  // we cannot use existing linalg named ops like linalg.batch_matmul.
182  // i.e. (B x) M x K * K x N = (B x) M x N
183  AffineExpr bDim, mDim, nDim, kDim;
184  bindDims(context, bDim, mDim, nDim, kDim);
185  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
186  auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
187  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
188  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
189  parallel, reduction};
190 
191  auto genericOp = rewriter.create<linalg::GenericOp>(
192  loc, reshapedOutputType,
193  /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
194  /*outputs=*/ValueRange{reshapedOutput},
195  ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
196  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
197  Value mul =
198  createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
199  Value add = createAdd(loc, mul, args[2], nestedBuilder);
200  nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
201  });
202  Value result = genericOp.getResults().front();
203 
204  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
205  loc, outputType, result, outputReassocIndices);
206 
207  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
208 
209  return std::make_pair(img2ColTensor.getOperation(),
210  reshapedResult.getOperation());
211 }
212 
213 FailureOr<std::pair<Operation *, Operation *>>
215  linalg::DepthwiseConv2DNhwcHwcOp convOp) {
216  auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
217  auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
218  auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());
219 
220  if (!filterType.hasStaticShape())
221  return rewriter.notifyMatchFailure(
222  convOp, "expected a static shape for the filter");
223 
224  if (!inputType.hasStaticShape())
225  return rewriter.notifyMatchFailure(convOp,
226  "expected a static shape for the input");
227 
228  // TODO: Support dilation.
229  if (!hasAllOneValues(convOp.getDilations()))
230  return rewriter.notifyMatchFailure(convOp,
231  "expected all ones for dilations");
232 
233  Location loc = convOp.getLoc();
234 
235  auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
236  auto operandTensorType = cast<RankedTensorType>(operand.getType());
237  auto nloops = indices.size();
238  ArrayRef<int64_t> inputShape = operandTensorType.getShape();
239 
240  SmallVector<AffineExpr> exprs = llvm::to_vector<4>(
241  llvm::map_range(indices, [&](int64_t index) -> AffineExpr {
242  return rewriter.getAffineDimExpr(index);
243  }));
244 
245  SmallVector<int64_t> targetShape = llvm::to_vector<4>(llvm::map_range(
246  indices, [&](int64_t index) -> int64_t { return inputShape[index]; }));
247 
248  Value outputTensor = rewriter.create<tensor::EmptyOp>(
249  loc, targetShape, operandTensorType.getElementType());
250 
251  SmallVector<utils::IteratorType> loopAttributeTypes(
252  nloops, utils::IteratorType::parallel);
253 
254  SmallVector<AffineMap> indexingMaps = {
256  AffineMap::get(nloops, 0, exprs, rewriter.getContext())),
257  AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
258 
259  auto transposedOp = rewriter.create<linalg::GenericOp>(
260  loc, outputTensor.getType(),
261  /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
262  loopAttributeTypes,
263  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
264  nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
265  });
266 
267  return transposedOp.getResult(0);
268  };
269 
270  Value input = convOp.getInputs()[0];
271  Value filter = convOp.getInputs()[1];
272  Value output = convOp.getOutputs()[0];
273 
274  // Transpose input, filter so channels are outermost
275  Value inputT = transposeOperand(input, {0, 3, 1, 2});
276  Value filterT = transposeOperand(filter, {2, 0, 1});
277  ArrayRef<int64_t> filterTShape =
278  cast<RankedTensorType>(filterT.getType()).getShape();
279  ArrayRef<int64_t> outputShape = outputType.getShape();
280 
281  int n = outputShape[0];
282  int oh = outputShape[1];
283  int ow = outputShape[2];
284  int c = outputShape[3];
285  int fh = filterTShape[1];
286  int fw = filterTShape[2];
287 
288  SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw};
289  Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
290 
291  AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
292  bindDims(rewriter.getContext(), nDim, cDim, ohDim, owDim, khDim, kwDim);
293 
294  AffineExpr shSym = rewriter.getAffineConstantExpr(
295  convOp.getStrides().getValues<int64_t>()[0]);
296  AffineExpr swSym = rewriter.getAffineConstantExpr(
297  convOp.getStrides().getValues<int64_t>()[1]);
298 
299  SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim,
300  owDim * swSym + kwDim};
301 
302  auto nloops = colTensorShape.size();
303 
304  SmallVector<utils::IteratorType> loopAttributeTypes(
305  nloops, utils::IteratorType::parallel);
306 
307  SmallVector<AffineMap> indexingMaps = {
308  AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),
309  AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
310 
311  Value colTensor = rewriter.create<tensor::EmptyOp>(
312  loc, colTensorShape, inputType.getElementType());
313 
314  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
315  loc, colTensor.getType(),
316  /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps,
317  loopAttributeTypes,
318  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
319  nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
320  });
321 
322  SmallVector<ReassociationIndices> img2ColTensorReassocIndices = {
323  {0, 1}, {2, 3}, {4, 5}};
324  SmallVector<ReassociationIndices> filterReassociationIndice = {{0}, {1, 2}};
325  SmallVector<ReassociationIndices> outputReassociationIndice = {{0, 1},
326  {2, 3}};
327 
328  auto reshapedImg2ColTensorType = RankedTensorType::get(
329  {n * c, oh * ow, fh * fw}, inputType.getElementType());
330  auto reshapedFilterTensorType =
331  RankedTensorType::get({c, fh * fw}, filterType.getElementType());
332  auto reshapedOutputTensorType =
333  RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());
334 
335  Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
336  loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
337  img2ColTensorReassocIndices);
338  Value reshapedFilterTensor = rewriter.create<tensor::CollapseShapeOp>(
339  loc, reshapedFilterTensorType, filterT, filterReassociationIndice);
340  Value reshapedoutputTensor = rewriter.create<tensor::CollapseShapeOp>(
341  loc, reshapedOutputTensorType, transposedOutputTensor,
342  outputReassociationIndice);
343 
344  auto batchMatVecResult = rewriter.create<linalg::BatchMatvecOp>(
345  loc, TypeRange{reshapedoutputTensor.getType()},
346  ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
347  ValueRange{reshapedoutputTensor});
348 
349  SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
350  {2, 3}};
351 
352  auto batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
353  loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),
354  batchMatVecReassociationIndice);
355 
356  Value transposedResult =
357  transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
358 
359  rewriter.replaceOp(convOp, ArrayRef<Value>{transposedResult});
360  return std::make_pair(img2ColTensor.getOperation(),
361  transposedResult.getDefiningOp());
362 }
363 
364 FailureOr<std::pair<Operation *, Operation *>>
365 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
366  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
367  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
368  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
369 
370  if (!filterType.hasStaticShape())
371  return rewriter.notifyMatchFailure(
372  convOp, "expected a static shape for the filter");
373 
374  if (!inputType.hasStaticShape())
375  return rewriter.notifyMatchFailure(convOp,
376  "expected a static shape for the input");
377 
378  // TODO: Support dilation.
379  if (!hasAllOneValues(convOp.getDilations()))
380  return rewriter.notifyMatchFailure(convOp,
381  "expected all ones for dilations");
382 
383  Value input = convOp.getInputs()[0];
384  Value filter = convOp.getInputs()[1];
385  Value output = convOp.getOutputs()[0];
386 
387  auto filterShape = filterType.getShape();
388  auto outputShape = outputType.getShape();
389 
390  int64_t n = outputShape[0];
391  int64_t oc = outputShape[1];
392  int64_t oh = outputShape[2];
393  int64_t ow = outputShape[3];
394  int64_t ic = filterShape[1];
395  int64_t fh = filterShape[2];
396  int64_t fw = filterShape[3];
397 
398  auto loc = convOp.getLoc();
399  MLIRContext *context = rewriter.getContext();
400 
401  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
402  auto reshapedFilterType =
403  RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType());
404  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
405  loc, reshapedFilterType, filter, filterReassocIndices);
406 
407  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}};
408  auto reshapedOutputType =
409  RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
410  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
411  loc, reshapedOutputType, output, outputReassocIndices);
412 
413  // Convert the input to a (BKN) tensor.
414  SmallVector<int64_t, 4> colTensorShape = {n, ic * fh * fw, oh * ow};
415  Value colTensor = rewriter.create<tensor::EmptyOp>(
416  loc, colTensorShape, inputType.getElementType());
417 
418  auto nloops = colTensorShape.size();
419 
420  auto parallel = utils::IteratorType::parallel;
421  auto reduction = utils::IteratorType::reduction;
422  SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);
423 
424  SmallVector<AffineMap, 4> img2colIndexingMaps = {
425  AffineMap::getMultiDimIdentityMap(nloops, context)};
426 
427  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
428  loc, colTensor.getType(),
429  /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
430  img2colIterators,
431  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
432  // Get the iterators named based on the matmul (batch, m, k).
433  Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
434  Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
435  Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
436 
437  // Recover the original iteration indices from the problem/input sizes.
438  SmallVector<Value> kIndices = unrollIndex(
439  nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
440  auto icIndex = kIndices[0];
441  auto fhIndex = kIndices[1];
442  auto fwIndex = kIndices[2];
443 
444  SmallVector<Value> nIndices = unrollIndex(
445  nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
446  auto ohIndex = nIndices[0];
447  auto owIndex = nIndices[1];
448 
449  // Extract the input element corresponding to the expanded indices.
450  Value hIndex =
451  getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
452  convOp.getStrides().getValues<int64_t>()[0]);
453  Value wIndex =
454  getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
455  convOp.getStrides().getValues<int64_t>()[1]);
456 
457  // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
458  SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
459  Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
460  loc, input, extractionIndices);
461  nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
462  });
463 
464  // Because the filter does not share the same batch dimension,
465  // the batch dimension is only used in indexing the input and output. Thus
466  // we cannot use existing linalg named ops like linalg.batch_matmul.
467  // i.e. M x K * (B x) K x N = (B x) M x N
468  AffineExpr bDim, mDim, nDim, kDim;
469  bindDims(context, bDim, mDim, nDim, kDim);
470  auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
471  auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
472  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
473  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
474  parallel, reduction};
475  auto genericOp = rewriter.create<linalg::GenericOp>(
476  loc, reshapedOutputType,
477  /*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
478  /*outputs=*/ValueRange{reshapedOutput},
479  ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
480  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
481  Value mul =
482  createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
483  Value add = createAdd(loc, mul, args[2], nestedBuilder);
484  nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
485  });
486  Value result = genericOp.getResults().front();
487 
488  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
489  loc, outputType, result, outputReassocIndices);
490 
491  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
492 
493  return std::make_pair(img2ColTensor.getOperation(),
494  reshapedResult.getOperation());
495 }
496 
497 FailureOr<std::pair<Operation *, Operation *>>
498 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
499  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
500  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
501  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
502 
503  if (!filterType.hasStaticShape())
504  return rewriter.notifyMatchFailure(
505  convOp, "expected a static shape for the filter");
506 
507  if (!inputType.hasStaticShape())
508  return rewriter.notifyMatchFailure(convOp,
509  "expected a static shape for the input");
510 
511  // TODO: Support dilation.
512  if (!hasAllOneValues(convOp.getDilations()))
513  return rewriter.notifyMatchFailure(convOp,
514  "expected all ones for dilations");
515 
516  MLIRContext *context = rewriter.getContext();
517  Value input = convOp.getInputs()[0];
518  Value filter = convOp.getInputs()[1];
519  Value output = convOp.getOutputs()[0];
520 
521  ArrayRef<int64_t> filterShape = filterType.getShape();
522  ArrayRef<int64_t> outputShape = outputType.getShape();
523 
524  int64_t n = outputShape[0];
525  int64_t oh = outputShape[1];
526  int64_t ow = outputShape[2];
527  int64_t oc = outputShape[3];
528  int64_t fh = filterShape[1];
529  int64_t fw = filterShape[2];
530  int64_t ic = filterShape[3];
531 
532  Location loc = convOp.getLoc();
533 
534  // Reshape output and filter to the LHS and result of a "row-wise" matrix
535  // multiplication.
536  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
537  auto reshapedFilterType =
538  RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
539  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
540  loc, reshapedFilterType, filter, filterReassocIndices);
541 
542  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
543  RankedTensorType reshapedOutputType =
544  RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
545  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
546  loc, reshapedOutputType, output, outputReassocIndices);
547 
548  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
549  Value colTensor = rewriter.create<tensor::EmptyOp>(
550  loc, colTensorShape, inputType.getElementType());
551 
552  // Convert the input to a (BMK) column tensor.
553  auto nloops = colTensorShape.size();
554 
555  auto parallel = utils::IteratorType::parallel;
556  auto reduction = utils::IteratorType::reduction;
557  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
558 
559  SmallVector<AffineMap> img2colIndexingMaps = {
560  AffineMap::getMultiDimIdentityMap(nloops, context)};
561 
562  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
563  loc, colTensor.getType(),
564  /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
565  img2colIterators,
566  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
567  // Get the iterators named based on the matmul (batch, m, k).
568  Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
569  Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
570  Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
571 
572  // Recover the original iteration indices from the problem/input sizes.
573  SmallVector<Value> mIndices = unrollIndex(
574  nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
575  auto ohIndex = mIndices[0];
576  auto owIndex = mIndices[1];
577 
578  SmallVector<Value> kIndices = unrollIndex(
579  nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
580  auto fhIndex = kIndices[0];
581  auto fwIndex = kIndices[1];
582  auto icIndex = kIndices[2];
583 
584  // Extract the input element corresponding to the expanded indices.
585  Value hIndex =
586  getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
587  convOp.getStrides().getValues<int64_t>()[0]);
588  Value wIndex =
589  getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
590  convOp.getStrides().getValues<int64_t>()[1]);
591 
592  // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
593  SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
594  Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
595  loc, input, extractionIndices);
596  nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
597  });
598 
599  // Because we didn't transpose the filters we don't actually have a batched
600  // matrix multiply. Instead, we have an operation consisting of "row-wise" dot
601  // products.
602  AffineExpr bDim, mDim, nDim, kDim;
603  bindDims(context, bDim, mDim, nDim, kDim);
604  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
605  auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
606  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
607  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
608  parallel, reduction};
609 
610  auto genericOp = rewriter.create<linalg::GenericOp>(
611  loc, reshapedOutputType,
612  /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
613  /*outputs=*/ValueRange{reshapedOutput},
614  ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
615  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
616  Value mul =
617  createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
618  Value add = createAdd(loc, mul, args[2], nestedBuilder);
619  nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
620  });
621  Value result = genericOp.getResults().front();
622 
623  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
624  loc, outputType, result, outputReassocIndices);
625 
626  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
627 
628  return std::make_pair(img2ColTensor.getOperation(),
629  reshapedResult.getOperation());
630 }
631 
632 namespace {
633 
634 class ConvertConv2DNhwcHwcf final
635  : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
636 public:
638 
639  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
640  PatternRewriter &rewriter) const override {
641  if (failed(rewriteInIm2Col(rewriter, convOp)))
642  return failure();
643  return success();
644  }
645 };
646 
647 class ConvertDepthwiseConv2DNhwcHwc final
648  : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
649 public:
651 
652  LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
653  PatternRewriter &rewriter) const override {
654  if (failed(rewriteInIm2Col(rewriter, convOp)))
655  return failure();
656  return success();
657  }
658 };
659 
660 class ConvertConv2DNchwFchw final
661  : public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
662 public:
664 
665  LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
666  PatternRewriter &rewriter) const override {
667  if (failed(rewriteInIm2Col(rewriter, convOp)))
668  return failure();
669  return success();
670  }
671 };
672 
673 class ConvertConv2DNhwcFhwc final
674  : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
675 public:
677 
678  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
679  PatternRewriter &rewriter) const override {
680  if (failed(rewriteInIm2Col(rewriter, convOp)))
681  return failure();
682  return success();
683  }
684 };
685 } // end anonymous namespace
686 
688  MLIRContext *context = patterns.getContext();
689  patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
690  ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
691 }
692 } // end namespace linalg
693 } // end namespace mlir
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:412
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:404
MLIRContext * getContext() const
Definition: Builders.h:55
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:215
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:937
MLIRContext * getContext() const
Definition: PatternMatch.h:829
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1144
FailureOr< SmallVector< Value > > delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef< Value > basis, bool hasOuterBound=true)
Generate the IR to delinearize linearIndex given the basis and return the multi-index.
Definition: Utils.cpp:1946
static SmallVector< Value > unrollIndex(OpBuilder &b, Location loc, Value index, ArrayRef< int64_t > factors)
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns)
Populates patterns to transform linalg.conv_2d_xxx operations into linalg.generic (for img2col packin...
static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder)
static Value createMul(Location loc, Value x, Value y, Type accType, OpBuilder &builder)
static bool hasAllOneValues(DenseIntElementsAttr attr)
static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex, Value fIndex, int64_t stride)
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:239
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:348
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:362
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362