MLIR 23.0.0git
ConvertConv2DToImg2Col.cpp
Go to the documentation of this file.
1//===- ConvertConv2DToImg2Col.cpp - im2col implementation -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
18#include "mlir/IR/AffineExpr.h"
19#include "mlir/IR/AffineMap.h"
20#include "mlir/IR/Builders.h"
23#include "llvm/ADT/SmallVectorExtras.h"
24#include <cassert>
25#include <utility>
26
27namespace mlir {
28namespace linalg {
30 return llvm::all_of(
31 attr, [](const APInt &element) { return element.getSExtValue() == 1; });
32}
33
34static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
35 if (isa<IntegerType>(x.getType()))
36 return arith::AddIOp::create(builder, loc, x, y);
37 if (isa<ComplexType>(x.getType()))
38 return complex::AddOp::create(builder, loc, x, y);
39 return arith::AddFOp::create(builder, loc, x, y);
40}
41
42static Value createMul(Location loc, Value x, Value y, Type accType,
43 OpBuilder &builder) {
44 // Linalg named ops specify signed extend for named ops.
45 Value xConvert =
46 convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false);
47 Value yConvert =
48 convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false);
49 if (isa<ComplexType>(accType))
50 return complex::MulOp::create(builder, loc, xConvert, yConvert);
51 if (isa<IntegerType>(accType))
52 return arith::MulIOp::create(builder, loc, xConvert, yConvert);
53 return arith::MulFOp::create(builder, loc, xConvert, yConvert);
54}
55
56// Generate the affine expression to compute the convolved index
57// for the input as `oIndex * stride + fIndex`,
58// where oIndex: output iterator; fIndex: filter iterator.
60 bool useSymbols = true) {
61 AffineExpr oExpr, fExpr;
62 if (useSymbols)
63 bindSymbols(b.getContext(), oExpr, fExpr);
64 else
65 bindDims(b.getContext(), oExpr, fExpr);
66 return AffineExpr(stride * oExpr + fExpr);
67}
68
69// Stores the affine expressions to map the iteration space of the im2col matrix
70// to the corresponding indices of the output and filter matrices
78
79// Stores the affine expressions to map the iteration space of the im2col matrix
80// to the input matrix indices
87
88/// Construct the affine expressions that map the indices of the im2col matrix
89/// to the corresponding input tensor indices for a 2D convolution with the the
90/// provided strides.
91///
92/// @param exprs Affine expressions for output and filter indices.
93/// @param strides [height, width] stride values for the convolution.
94/// @param rewriter Pattern rewriter.
95/// @return Affine expressions mapping im2col matrix indices to input
96/// offsets.
99 ArrayRef<int64_t> strides, RewriterBase &rewriter) {
100 // maps the iteration space of the im2col matrix to (output_y, filter_y)
101 auto hIndicesMap = AffineMap::inferFromExprList(
102 {ArrayRef{exprs.ohIndex, exprs.fhIndex}}, rewriter.getContext())[0];
103 // maps the iteration space of the im2col matrix to (output_x, filter_x)
104 auto wIndicesMap = AffineMap::inferFromExprList(
105 {ArrayRef{exprs.owIndex, exprs.fwIndex}}, rewriter.getContext())[0];
106 // Compute the input indexing map, to map the indices of the im2col matrix to
107 // the original input offsets. Each element of the im2col matrix corresponds
108 // to a pair of (out_element, filter_element). First, we build the expressions
109 // to compute the input (ix, iy) indices from [out_x/y, filter_x/y] pairs;
110 // then we compose them with the maps that map the im2col matrix elements to
111 // the (out_element, filter_element) pairs.
112 auto bIndexExpr = rewriter.getAffineDimExpr(0U);
113 auto hIndexExpr = getConvolvedExpr(rewriter, strides[0],
114 /*useSymbols*/ false);
115 hIndexExpr = hIndexExpr.compose(hIndicesMap);
116 auto wIndexExpr = getConvolvedExpr(rewriter, strides[1],
117 /*useSymbols*/ false);
118 wIndexExpr = wIndexExpr.compose(wIndicesMap);
119 auto cIndexExpr = exprs.icIndex;
120 return {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr};
121}
122
123FailureOr<std::pair<Operation *, Operation *>>
124rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
125 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
126 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
127 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
128
129 if (!convOp.hasPureTensorSemantics())
130 return rewriter.notifyMatchFailure(
131 convOp, "expected op to have pure tensor semantics");
132
133 if (!filterType.hasStaticShape())
134 return rewriter.notifyMatchFailure(
135 convOp, "expected a static shape for the filter");
136
137 if (!inputType.hasStaticShape())
138 return rewriter.notifyMatchFailure(convOp,
139 "expected a static shape for the input");
140
141 // TODO: Support dilation.
142 if (!hasAllOneValues(convOp.getDilations()))
143 return rewriter.notifyMatchFailure(convOp,
144 "expected all ones for dilations");
145
146 MLIRContext *context = rewriter.getContext();
147 Value input = convOp.getInputs()[0];
148 Value filter = convOp.getInputs()[1];
149 Value output = convOp.getOutputs()[0];
150
151 ArrayRef<int64_t> filterShape = filterType.getShape();
152 ArrayRef<int64_t> outputShape = outputType.getShape();
153
154 int64_t n = outputShape[0];
155 int64_t oh = outputShape[1];
156 int64_t ow = outputShape[2];
157 int64_t oc = outputShape[3];
158 int64_t fh = filterShape[0];
159 int64_t fw = filterShape[1];
160 int64_t ic = filterShape[2];
161
162 Location loc = convOp.getLoc();
163
164 assert(isa<RankedTensorType>(filterType) &&
165 "expected filter type to be a ranked tensor");
166 auto tensorFilterType = cast<RankedTensorType>(filterType);
167
168 // Reshape output and filter to the LHS and result of a (B)MNK matmul.
169 SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
170 auto reshapedFilterType =
171 RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType(),
172 tensorFilterType.getEncoding());
173 Value reshapedFilter = tensor::CollapseShapeOp::create(
174 rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
175
176 SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
177 RankedTensorType reshapedOutputType =
178 RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
179 Value reshapedOutput = tensor::CollapseShapeOp::create(
180 rewriter, loc, reshapedOutputType, output, outputReassocIndices);
181
182 SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
183 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
184 inputType.getElementType());
185
186 // Convert the input to a (BMK) column tensor.
187 auto nloops = colTensorShape.size();
188
189 auto parallel = utils::IteratorType::parallel;
190 auto reduction = utils::IteratorType::reduction;
191 SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
192
193 // Given an index of the im2col matrix, retrieve the corresponding indices of
194 // the output and filter matrices
195 auto mIndicesExprs =
196 delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
197 auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
198 ArrayRef<int64_t>{fw * ic, ic, 1});
199 Im2ColToOperandsExprs i2cToOperExprs;
200 i2cToOperExprs.fhIndex = kIndicesExprs[0];
201 i2cToOperExprs.fwIndex = kIndicesExprs[1];
202 i2cToOperExprs.icIndex = kIndicesExprs[2];
203 i2cToOperExprs.ohIndex = mIndicesExprs[0];
204 i2cToOperExprs.owIndex = mIndicesExprs[1];
205
206 // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
208 i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
209 rewriter);
210 auto inMap =
212 inExprs.wIndex, inExprs.cIndex}},
213 rewriter.getContext())[0];
214
215 SmallVector<AffineMap> img2colIndexingMaps = {
216 inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
217
218 auto img2ColTensor = linalg::GenericOp::create(
219 rewriter, loc, colTensor.getType(),
220 /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
221 img2colIterators,
222 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
223 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
224 });
225
226 // Because the filter does not share the same batch dimension,
227 // the batch dimension is only used in indexing the input and output. Thus
228 // we cannot use existing linalg named ops like linalg.batch_matmul.
229 // i.e. (B x) M x K * K x N = (B x) M x N
230 AffineExpr bDim, mDim, nDim, kDim;
231 bindDims(context, bDim, mDim, nDim, kDim);
232 auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
233 auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
234 auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
235 SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
236 parallel, reduction};
237
238 auto genericOp = linalg::GenericOp::create(
239 rewriter, loc, reshapedOutputType,
240 /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
241 /*outputs=*/ValueRange{reshapedOutput},
242 ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
243 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
244 Value mul =
245 createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
246 Value add = createAdd(loc, mul, args[2], nestedBuilder);
247 linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
248 });
249 Value result = genericOp.getResults().front();
250
251 auto reshapedResult = tensor::ExpandShapeOp::create(
252 rewriter, loc, outputType, result, outputReassocIndices);
253
254 rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
255
256 return std::make_pair(img2ColTensor.getOperation(),
257 reshapedResult.getOperation());
258}
259
260FailureOr<std::pair<Operation *, Operation *>>
262 linalg::DepthwiseConv2DNhwcHwcOp convOp) {
263 auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
264 auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
265 auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());
266
267 if (!convOp.hasPureTensorSemantics())
268 return rewriter.notifyMatchFailure(
269 convOp, "expected op to have pure tensor semantics");
270
271 if (!filterType.hasStaticShape())
272 return rewriter.notifyMatchFailure(
273 convOp, "expected a static shape for the filter");
274
275 if (!inputType.hasStaticShape())
276 return rewriter.notifyMatchFailure(convOp,
277 "expected a static shape for the input");
278
279 // TODO: Support dilation.
280 if (!hasAllOneValues(convOp.getDilations()))
281 return rewriter.notifyMatchFailure(convOp,
282 "expected all ones for dilations");
283
284 Location loc = convOp.getLoc();
285
286 auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
287 auto operandTensorType = cast<RankedTensorType>(operand.getType());
288 auto nloops = indices.size();
289 ArrayRef<int64_t> inputShape = operandTensorType.getShape();
290
292 llvm::map_to_vector<4>(indices, [&](int64_t index) -> AffineExpr {
293 return rewriter.getAffineDimExpr(index);
294 });
295
296 SmallVector<int64_t> targetShape = llvm::map_to_vector<4>(
297 indices, [&](int64_t index) -> int64_t { return inputShape[index]; });
298
299 Value outputTensor = tensor::EmptyOp::create(
300 rewriter, loc, targetShape, operandTensorType.getElementType());
301
302 SmallVector<utils::IteratorType> loopAttributeTypes(
303 nloops, utils::IteratorType::parallel);
304
305 SmallVector<AffineMap> indexingMaps = {
307 AffineMap::get(nloops, 0, exprs, rewriter.getContext())),
309
310 auto transposedOp = linalg::GenericOp::create(
311 rewriter, loc, outputTensor.getType(),
312 /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
313 loopAttributeTypes,
314 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
315 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
316 });
317
318 return transposedOp.getResult(0);
319 };
320
321 Value input = convOp.getInputs()[0];
322 Value filter = convOp.getInputs()[1];
323 Value output = convOp.getOutputs()[0];
324
325 // Transpose input, filter so channels are outermost
326 Value inputT = transposeOperand(input, {0, 3, 1, 2});
327 Value filterT = transposeOperand(filter, {2, 0, 1});
328 ArrayRef<int64_t> filterTShape =
329 cast<RankedTensorType>(filterT.getType()).getShape();
330 ArrayRef<int64_t> outputShape = outputType.getShape();
331
332 int n = outputShape[0];
333 int oh = outputShape[1];
334 int ow = outputShape[2];
335 int c = outputShape[3];
336 int fh = filterTShape[1];
337 int fw = filterTShape[2];
338
339 SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw};
340 Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
341
342 AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
343 bindDims(rewriter.getContext(), nDim, cDim, ohDim, owDim, khDim, kwDim);
344
345 AffineExpr shSym = rewriter.getAffineConstantExpr(
346 convOp.getStrides().getValues<int64_t>()[0]);
347 AffineExpr swSym = rewriter.getAffineConstantExpr(
348 convOp.getStrides().getValues<int64_t>()[1]);
349
350 SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim,
351 owDim * swSym + kwDim};
352
353 auto nloops = colTensorShape.size();
354
355 SmallVector<utils::IteratorType> loopAttributeTypes(
356 nloops, utils::IteratorType::parallel);
357
358 SmallVector<AffineMap> indexingMaps = {
359 AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),
361
362 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
363 inputType.getElementType());
364
365 auto img2ColTensor = linalg::GenericOp::create(
366 rewriter, loc, colTensor.getType(),
367 /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps,
368 loopAttributeTypes,
369 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
370 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
371 });
372
373 SmallVector<ReassociationIndices> img2ColTensorReassocIndices = {
374 {0, 1}, {2, 3}, {4, 5}};
375 SmallVector<ReassociationIndices> filterReassociationIndice = {{0}, {1, 2}};
376 SmallVector<ReassociationIndices> outputReassociationIndice = {{0, 1},
377 {2, 3}};
378
379 auto reshapedImg2ColTensorType = RankedTensorType::get(
380 {n * c, oh * ow, fh * fw}, inputType.getElementType());
381 auto reshapedFilterTensorType =
382 RankedTensorType::get({c, fh * fw}, filterType.getElementType());
383 auto reshapedOutputTensorType =
384 RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());
385
386 Value reshapedImg2ColTensor = tensor::CollapseShapeOp::create(
387 rewriter, loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
388 img2ColTensorReassocIndices);
389 Value reshapedFilterTensor =
390 tensor::CollapseShapeOp::create(rewriter, loc, reshapedFilterTensorType,
391 filterT, filterReassociationIndice);
392 Value reshapedoutputTensor = tensor::CollapseShapeOp::create(
393 rewriter, loc, reshapedOutputTensorType, transposedOutputTensor,
394 outputReassociationIndice);
395
396 auto batchMatVecResult = linalg::BatchMatvecOp::create(
397 rewriter, loc, TypeRange{reshapedoutputTensor.getType()},
398 ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
399 ValueRange{reshapedoutputTensor});
400
401 SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
402 {2, 3}};
403
404 auto batchMatVecResultReshaped = tensor::ExpandShapeOp::create(
405 rewriter, loc, transposedOutputTensor.getType(),
406 batchMatVecResult.getResult(0), batchMatVecReassociationIndice);
407
408 Value transposedResult =
409 transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
410
411 rewriter.replaceOp(convOp, ArrayRef<Value>{transposedResult});
412 return std::make_pair(img2ColTensor.getOperation(),
413 transposedResult.getDefiningOp());
414}
415
416FailureOr<std::pair<Operation *, Operation *>>
417rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
418 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
419 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
420 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
421
422 if (!convOp.hasPureTensorSemantics())
423 return rewriter.notifyMatchFailure(
424 convOp, "expected op to have pure tensor semantics");
425
426 if (!filterType.hasStaticShape())
427 return rewriter.notifyMatchFailure(
428 convOp, "expected a static shape for the filter");
429
430 if (!inputType.hasStaticShape())
431 return rewriter.notifyMatchFailure(convOp,
432 "expected a static shape for the input");
433
434 // TODO: Support dilation.
435 if (!hasAllOneValues(convOp.getDilations()))
436 return rewriter.notifyMatchFailure(convOp,
437 "expected all ones for dilations");
438
439 Value input = convOp.getInputs()[0];
440 Value filter = convOp.getInputs()[1];
441 Value output = convOp.getOutputs()[0];
442
443 auto filterShape = filterType.getShape();
444 auto outputShape = outputType.getShape();
445
446 int64_t n = outputShape[0];
447 int64_t oc = outputShape[1];
448 int64_t oh = outputShape[2];
449 int64_t ow = outputShape[3];
450 int64_t ic = filterShape[1];
451 int64_t fh = filterShape[2];
452 int64_t fw = filterShape[3];
453
454 auto loc = convOp.getLoc();
455 MLIRContext *context = rewriter.getContext();
456
457 assert(isa<RankedTensorType>(filterType) &&
458 "expected filter type to be a ranked tensor");
459 auto tensorFilterType = cast<RankedTensorType>(filterType);
460
461 SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
462 auto reshapedFilterType =
463 RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType(),
464 tensorFilterType.getEncoding());
465 Value reshapedFilter = tensor::CollapseShapeOp::create(
466 rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
467
468 SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}};
469 auto reshapedOutputType =
470 RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
471 Value reshapedOutput = tensor::CollapseShapeOp::create(
472 rewriter, loc, reshapedOutputType, output, outputReassocIndices);
473
474 // Convert the input to a (BKN) tensor.
475 SmallVector<int64_t, 4> colTensorShape = {n, ic * fh * fw, oh * ow};
476 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
477 inputType.getElementType());
478
479 auto nloops = colTensorShape.size();
480
481 auto parallel = utils::IteratorType::parallel;
482 auto reduction = utils::IteratorType::reduction;
483 SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);
484
485 // Recover the original iteration indices from the problem/input sizes:
486 // given an index of the im2col matrix, retrieve the corresponding indices of
487 // the output and filter matrices
488 auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(1U),
489 ArrayRef<int64_t>{fh * fw, fw, 1});
490 auto mIndicesExprs =
491 delinearize(rewriter.getAffineDimExpr(2U), ArrayRef<int64_t>{ow, 1});
492 Im2ColToOperandsExprs i2cToOperExprs;
493 i2cToOperExprs.icIndex = kIndicesExprs[0];
494 i2cToOperExprs.fhIndex = kIndicesExprs[1];
495 i2cToOperExprs.fwIndex = kIndicesExprs[2];
496 i2cToOperExprs.ohIndex = mIndicesExprs[0];
497 i2cToOperExprs.owIndex = mIndicesExprs[1];
499 i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
500 rewriter);
501 auto inMap =
503 inExprs.hIndex, inExprs.wIndex}},
504 rewriter.getContext())[0];
505 // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
506 SmallVector<AffineMap> img2colIndexingMaps = {
507 inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
508
509 auto img2ColTensor = linalg::GenericOp::create(
510 rewriter, loc, colTensor.getType(),
511 /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
512 img2colIterators,
513 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
514 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
515 });
516
517 // Because the filter does not share the same batch dimension,
518 // the batch dimension is only used in indexing the input and output. Thus
519 // we cannot use existing linalg named ops like linalg.batch_matmul.
520 // i.e. M x K * (B x) K x N = (B x) M x N
521 AffineExpr bDim, mDim, nDim, kDim;
522 bindDims(context, bDim, mDim, nDim, kDim);
523 auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
524 auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
525 auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
526 SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
527 parallel, reduction};
528 auto genericOp = linalg::GenericOp::create(
529 rewriter, loc, reshapedOutputType,
530 /*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
531 /*outputs=*/ValueRange{reshapedOutput},
532 ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
533 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
534 Value mul =
535 createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
536 Value add = createAdd(loc, mul, args[2], nestedBuilder);
537 linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
538 });
539 Value result = genericOp.getResults().front();
540
541 auto reshapedResult = tensor::ExpandShapeOp::create(
542 rewriter, loc, outputType, result, outputReassocIndices);
543
544 rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
545
546 return std::make_pair(img2ColTensor.getOperation(),
547 reshapedResult.getOperation());
548}
549
550FailureOr<std::pair<Operation *, Operation *>>
551rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
552 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
553 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
554 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
555
556 if (!convOp.hasPureTensorSemantics())
557 return rewriter.notifyMatchFailure(
558 convOp, "expected op to have pure tensor semantics");
559
560 if (!filterType.hasStaticShape())
561 return rewriter.notifyMatchFailure(
562 convOp, "expected a static shape for the filter");
563
564 if (!inputType.hasStaticShape())
565 return rewriter.notifyMatchFailure(convOp,
566 "expected a static shape for the input");
567
568 // TODO: Support dilation.
569 if (!hasAllOneValues(convOp.getDilations()))
570 return rewriter.notifyMatchFailure(convOp,
571 "expected all ones for dilations");
572
573 MLIRContext *context = rewriter.getContext();
574 Value input = convOp.getInputs()[0];
575 Value filter = convOp.getInputs()[1];
576 Value output = convOp.getOutputs()[0];
577
578 ArrayRef<int64_t> filterShape = filterType.getShape();
579 ArrayRef<int64_t> outputShape = outputType.getShape();
580
581 int64_t n = outputShape[0];
582 int64_t oh = outputShape[1];
583 int64_t ow = outputShape[2];
584 int64_t oc = outputShape[3];
585 int64_t fh = filterShape[1];
586 int64_t fw = filterShape[2];
587 int64_t ic = filterShape[3];
588
589 Location loc = convOp.getLoc();
590
591 assert(isa<RankedTensorType>(filterType) &&
592 "expected filter type to be a ranked tensor");
593 auto tensorFilterType = cast<RankedTensorType>(filterType);
594
595 // Reshape output and filter to the LHS and result of a "row-wise" matrix
596 // multiplication.
597 SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
598 auto reshapedFilterType =
599 RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType(),
600 tensorFilterType.getEncoding());
601 Value reshapedFilter = tensor::CollapseShapeOp::create(
602 rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
603
604 SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
605 RankedTensorType reshapedOutputType =
606 RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
607 Value reshapedOutput = tensor::CollapseShapeOp::create(
608 rewriter, loc, reshapedOutputType, output, outputReassocIndices);
609
610 // Shape of the Toeplitz matrix produced by Im2col.
611 SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
612 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
613 inputType.getElementType());
614
615 // Convert the input to a (BMK) column tensor.
616 auto nloops = colTensorShape.size();
617
618 auto parallel = utils::IteratorType::parallel;
619 auto reduction = utils::IteratorType::reduction;
620 SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
621
622 // Given an index of the im2col matrix, retrieve the corresponding indices of
623 // the output and filter matrices
624 auto mIndicesExprs =
625 delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
626 auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
627 ArrayRef<int64_t>{fw * ic, ic, 1});
628 Im2ColToOperandsExprs i2cToOperExprs;
629 i2cToOperExprs.fhIndex = kIndicesExprs[0];
630 i2cToOperExprs.fwIndex = kIndicesExprs[1];
631 i2cToOperExprs.icIndex = kIndicesExprs[2];
632 i2cToOperExprs.ohIndex = mIndicesExprs[0];
633 i2cToOperExprs.owIndex = mIndicesExprs[1];
634
635 // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
637 i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
638 rewriter);
639 auto inMap =
641 inExprs.wIndex, inExprs.cIndex}},
642 rewriter.getContext())[0];
643 SmallVector<AffineMap> img2colIndexingMaps = {
644 inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};
645
646 auto img2ColTensor = linalg::GenericOp::create(
647 rewriter, loc, colTensor.getType(),
648 /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
649 img2colIterators,
650 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
651 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
652 });
653
654 // Because we didn't transpose the filters we don't actually have a batched
655 // matrix multiply. Instead, we have an operation consisting of "row-wise" dot
656 // products.
657 AffineExpr bDim, mDim, nDim, kDim;
658 bindDims(context, bDim, mDim, nDim, kDim);
659 auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
660 auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
661 auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
662 SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
663 parallel, reduction};
664
665 auto genericOp = linalg::GenericOp::create(
666 rewriter, loc, reshapedOutputType,
667 /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
668 /*outputs=*/ValueRange{reshapedOutput},
669 ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
670 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
671 Value mul =
672 createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
673 Value add = createAdd(loc, mul, args[2], nestedBuilder);
674 linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
675 });
676 Value result = genericOp.getResults().front();
677
678 auto reshapedResult = tensor::ExpandShapeOp::create(
679 rewriter, loc, outputType, result, outputReassocIndices);
680
681 rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
682
683 return std::make_pair(img2ColTensor.getOperation(),
684 reshapedResult.getOperation());
685}
686
687namespace {
688
689class ConvertConv2DNhwcHwcf final
690 : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
691public:
693
694 LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
695 PatternRewriter &rewriter) const override {
696 if (failed(rewriteInIm2Col(rewriter, convOp)))
697 return failure();
698 return success();
699 }
700};
701
702class ConvertDepthwiseConv2DNhwcHwc final
703 : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
704public:
705 using OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp>::OpRewritePattern;
706
707 LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
708 PatternRewriter &rewriter) const override {
709 if (failed(rewriteInIm2Col(rewriter, convOp)))
710 return failure();
711 return success();
712 }
713};
714
715class ConvertConv2DNchwFchw final
716 : public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
717public:
719
720 LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
721 PatternRewriter &rewriter) const override {
722 if (failed(rewriteInIm2Col(rewriter, convOp)))
723 return failure();
724 return success();
725 }
726};
727
728class ConvertConv2DNhwcFhwc final
729 : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
730public:
732
733 LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
734 PatternRewriter &rewriter) const override {
735 if (failed(rewriteInIm2Col(rewriter, convOp)))
736 return failure();
737 return success();
738 }
739};
740} // end anonymous namespace
741
743 MLIRContext *context = patterns.getContext();
744 patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
745 ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
746}
747} // end namespace linalg
748} // end namespace mlir
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
#define mul(a, b)
#define add(a, b)
Base type for affine expression.
Definition AffineExpr.h:68
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getAffineConstantExpr(int64_t constant)
Definition Builders.cpp:376
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:368
MLIRContext * getContext() const
Definition Builders.h:56
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns)
Populates patterns to transform linalg.conv_2d_xxx operations into linalg.generic (for img2col packin...
static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder)
static Value createMul(Location loc, Value x, Value y, Type accType, OpBuilder &builder)
static Im2ColToInputDimsExprs getIm2ColInputExpressions(Im2ColToOperandsExprs exprs, ArrayRef< int64_t > strides, RewriterBase &rewriter)
Construct the affine expressions that map the indices of the im2col matrix to the corresponding input...
static bool hasAllOneValues(DenseIntElementsAttr attr)
static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride, bool useSymbols=true)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition Utils.cpp:239
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...