MLIR 22.0.0git
ConvertConv2DToImg2Col.cpp
Go to the documentation of this file.
1//===- ConvertConv2DToImg2Col.cpp - im2col implementation -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
18#include "mlir/IR/AffineExpr.h"
19#include "mlir/IR/AffineMap.h"
20#include "mlir/IR/Builders.h"
23#include <cassert>
24#include <utility>
25
26namespace mlir {
27namespace linalg {
  // True iff every element of the attribute equals 1 (signed comparison);
  // used to reject convolutions with non-unit dilations.
  return llvm::all_of(
      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
}
32
33static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
34 if (isa<IntegerType>(x.getType()))
35 return arith::AddIOp::create(builder, loc, x, y);
36 if (isa<ComplexType>(x.getType()))
37 return complex::AddOp::create(builder, loc, x, y);
38 return arith::AddFOp::create(builder, loc, x, y);
39}
40
41static Value createMul(Location loc, Value x, Value y, Type accType,
42 OpBuilder &builder) {
43 // Linalg named ops specify signed extend for named ops.
44 Value xConvert =
45 convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false);
46 Value yConvert =
47 convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false);
48 if (isa<ComplexType>(accType))
49 return complex::MulOp::create(builder, loc, xConvert, yConvert);
50 if (isa<IntegerType>(accType))
51 return arith::MulIOp::create(builder, loc, xConvert, yConvert);
52 return arith::MulFOp::create(builder, loc, xConvert, yConvert);
53}
54
55// Generate the affine expression to compute the convolved index
56// for the input as `oIndex * stride + fIndex`,
57// where oIndex: output iterator; fIndex: filter iterator.
                                    bool useSymbols = true) {
  // Bind the (output, filter) iterators either as symbols (s0, s1) or as
  // dims (d0, d1), so callers can choose how the result composes with
  // other affine maps.
  AffineExpr oExpr, fExpr;
  if (useSymbols)
    bindSymbols(b.getContext(), oExpr, fExpr);
  else
    bindDims(b.getContext(), oExpr, fExpr);
  // Convolved input index: oIndex * stride + fIndex (unit dilation assumed).
  return AffineExpr(stride * oExpr + fExpr);
}
67
68// Stores the affine expressions to map the iteration space of the im2col matrix
69// to the corresponding indices of the output and filter matrices
77
78// Stores the affine expressions to map the iteration space of the im2col matrix
79// to the input matrix indices
86
87/// Construct the affine expressions that map the indices of the im2col matrix
88/// to the corresponding input tensor indices for a 2D convolution with the the
89/// provided strides.
90///
91/// @param exprs Affine expressions for output and filter indices.
92/// @param strides [height, width] stride values for the convolution.
93/// @param rewriter Pattern rewriter.
94/// @return Affine expressions mapping im2col matrix indices to input
95/// offsets.
                          ArrayRef<int64_t> strides, RewriterBase &rewriter) {
  // maps the iteration space of the im2col matrix to (output_y, filter_y)
  auto hIndicesMap = AffineMap::inferFromExprList(
      {ArrayRef{exprs.ohIndex, exprs.fhIndex}}, rewriter.getContext())[0];
  // maps the iteration space of the im2col matrix to (output_x, filter_x)
  auto wIndicesMap = AffineMap::inferFromExprList(
      {ArrayRef{exprs.owIndex, exprs.fwIndex}}, rewriter.getContext())[0];
  // Compute the input indexing map, to map the indices of the im2col matrix to
  // the original input offsets. Each element of the im2col matrix corresponds
  // to a pair of (out_element, filter_element). First, we build the expressions
  // to compute the input (ix, iy) indices from [out_x/y, filter_x/y] pairs;
  // then we compose them with the maps that map the im2col matrix elements to
  // the (out_element, filter_element) pairs.
  // The batch index passes through unchanged (dim 0 of the im2col matrix).
  auto bIndexExpr = rewriter.getAffineDimExpr(0U);
  // input_y = oh * strides[0] + fh, expressed over the im2col iteration dims.
  auto hIndexExpr = getConvolvedExpr(rewriter, strides[0],
                                     /*useSymbols*/ false);
  hIndexExpr = hIndexExpr.compose(hIndicesMap);
  // input_x = ow * strides[1] + fw, expressed over the im2col iteration dims.
  auto wIndexExpr = getConvolvedExpr(rewriter, strides[1],
                                     /*useSymbols*/ false);
  wIndexExpr = wIndexExpr.compose(wIndicesMap);
  // The input-channel index also passes through unchanged.
  auto cIndexExpr = exprs.icIndex;
  return {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr};
}
121
/// Rewrite a linalg.conv_2d_nhwc_hwcf as an im2col packing (linalg.generic)
/// followed by a batched-matmul-shaped linalg.generic, with reshapes around
/// them. Returns the (im2col op, final reshape op) pair on success.
/// Requires pure tensor semantics, static input/filter shapes, and unit
/// dilations.
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!convOp.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(
        convOp, "expected op to have pure tensor semantics");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  MLIRContext *context = rewriter.getContext();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  ArrayRef<int64_t> filterShape = filterType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  // Output is NHWC, filter is HWCF.
  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t oc = outputShape[3];
  int64_t fh = filterShape[0];
  int64_t fw = filterShape[1];
  int64_t ic = filterShape[2];

  Location loc = convOp.getLoc();

  assert(isa<RankedTensorType>(filterType) &&
         "expected filter type to be a ranked tensor");
  auto tensorFilterType = cast<RankedTensorType>(filterType);

  // Reshape output and filter to the LHS and result of a (B)MNK matmul.
  // Filter: [fh, fw, ic, oc] -> [fh*fw*ic, oc] (K x N).
  SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
  auto reshapedFilterType =
      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType(),
                            tensorFilterType.getEncoding());
  Value reshapedFilter = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedFilterType, filter, filterReassocIndices);

  // Output: [n, oh, ow, oc] -> [n, oh*ow, oc] (B x M x N).
  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
  RankedTensorType reshapedOutputType =
      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
  Value reshapedOutput = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedOutputType, output, outputReassocIndices);

  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
  Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
                                            inputType.getElementType());

  // Convert the input to a (BMK) column tensor.
  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

  // Given an index of the im2col matrix, retrieve the corresponding indices of
  // the output and filter matrices: M = oh*ow delinearizes to (oh, ow),
  // K = fh*fw*ic delinearizes to (fh, fw, ic).
  auto mIndicesExprs =
      delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
  auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
                                   ArrayRef<int64_t>{fw * ic, ic, 1});
  Im2ColToOperandsExprs i2cToOperExprs;
  i2cToOperExprs.fhIndex = kIndicesExprs[0];
  i2cToOperExprs.fwIndex = kIndicesExprs[1];
  i2cToOperExprs.icIndex = kIndicesExprs[2];
  i2cToOperExprs.ohIndex = mIndicesExprs[0];
  i2cToOperExprs.owIndex = mIndicesExprs[1];

  // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
      i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
      rewriter);
  auto inMap =
                                     inExprs.wIndex, inExprs.cIndex}},
                                    rewriter.getContext())[0];

  // Input is read through `inMap`; the im2col result is written with the
  // identity map over the (B, M, K) iteration space.
  SmallVector<AffineMap> img2colIndexingMaps = {
      inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};

  // The packing op is a pure gather: it just yields the loaded input element.
  auto img2ColTensor = linalg::GenericOp::create(
      rewriter, loc, colTensor.getType(),
      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
      });

  // Because the filter does not share the same batch dimension,
  // the batch dimension is only used in indexing the input and output. Thus
  // we cannot use existing linalg named ops like linalg.batch_matmul.
  // i.e. (B x) M x K * K x N = (B x) M x N
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};

  // Multiply-accumulate body; operands are converted to the accumulator
  // (output element) type before multiplying.
  auto genericOp = linalg::GenericOp::create(
      rewriter, loc, reshapedOutputType,
      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  // Expand [n, oh*ow, oc] back to the original NHWC output shape.
  auto reshapedResult = tensor::ExpandShapeOp::create(
      rewriter, loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}
258
/// Rewrite a linalg.depthwise_conv_2d_nhwc_hwc as transposes + im2col
/// packing + linalg.batch_matvec (batched over the collapsed N*C dimension),
/// then transpose the result back to NHWC. Returns the (im2col op, final
/// transpose op) pair on success. Requires pure tensor semantics, static
/// shapes, and unit dilations.
FailureOr<std::pair<Operation *, Operation *>>
                linalg::DepthwiseConv2DNhwcHwcOp convOp) {
  auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
  auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
  auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());

  if (!convOp.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(
        convOp, "expected op to have pure tensor semantics");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  Location loc = convOp.getLoc();

  // Helper: materialize a transpose of `operand` by the given dimension
  // permutation as a linalg.generic writing into a fresh tensor.empty.
  auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
    auto operandTensorType = cast<RankedTensorType>(operand.getType());
    auto nloops = indices.size();
    ArrayRef<int64_t> inputShape = operandTensorType.getShape();

    SmallVector<AffineExpr> exprs = llvm::to_vector<4>(
        llvm::map_range(indices, [&](int64_t index) -> AffineExpr {
          return rewriter.getAffineDimExpr(index);
        }));

    // Permuted target shape.
    SmallVector<int64_t> targetShape = llvm::to_vector<4>(llvm::map_range(
        indices, [&](int64_t index) -> int64_t { return inputShape[index]; }));

    Value outputTensor = tensor::EmptyOp::create(
        rewriter, loc, targetShape, operandTensorType.getElementType());

    SmallVector<utils::IteratorType> loopAttributeTypes(
        nloops, utils::IteratorType::parallel);

    SmallVector<AffineMap> indexingMaps = {
            AffineMap::get(nloops, 0, exprs, rewriter.getContext())),

    // The body is a pure copy; the indexing maps perform the permutation.
    auto transposedOp = linalg::GenericOp::create(
        rewriter, loc, outputTensor.getType(),
        /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
        loopAttributeTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
        });

    return transposedOp.getResult(0);
  };

  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  // Transpose input, filter so channels are outermost
  Value inputT = transposeOperand(input, {0, 3, 1, 2});   // NHWC -> NCHW
  Value filterT = transposeOperand(filter, {2, 0, 1});    // HWC  -> CHW
  ArrayRef<int64_t> filterTShape =
      cast<RankedTensorType>(filterT.getType()).getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  int n = outputShape[0];
  int oh = outputShape[1];
  int ow = outputShape[2];
  int c = outputShape[3];
  int fh = filterTShape[1];
  int fw = filterTShape[2];

  SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw};
  Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});

  AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
  bindDims(rewriter.getContext(), nDim, cDim, ohDim, owDim, khDim, kwDim);

  // Constant stride factors for the convolved H/W index expressions.
  AffineExpr shSym = rewriter.getAffineConstantExpr(
      convOp.getStrides().getValues<int64_t>()[0]);
  AffineExpr swSym = rewriter.getAffineConstantExpr(
      convOp.getStrides().getValues<int64_t>()[1]);

  // input[n, c, oh*sh + kh, ow*sw + kw]
  SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim,
                                        owDim * swSym + kwDim};

  auto nloops = colTensorShape.size();

  SmallVector<utils::IteratorType> loopAttributeTypes(
      nloops, utils::IteratorType::parallel);

  SmallVector<AffineMap> indexingMaps = {
      AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),

  Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
                                            inputType.getElementType());

  // Im2col gather: pure copy through the convolved-index input map.
  auto img2ColTensor = linalg::GenericOp::create(
      rewriter, loc, colTensor.getType(),
      /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps,
      loopAttributeTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
      });

  // Collapse to batch-matvec operands: batch = n*c, M = oh*ow, K = fh*fw.
  SmallVector<ReassociationIndices> img2ColTensorReassocIndices = {
      {0, 1}, {2, 3}, {4, 5}};
  SmallVector<ReassociationIndices> filterReassociationIndice = {{0}, {1, 2}};
  SmallVector<ReassociationIndices> outputReassociationIndice = {{0, 1},
                                                                 {2, 3}};

  auto reshapedImg2ColTensorType = RankedTensorType::get(
      {n * c, oh * ow, fh * fw}, inputType.getElementType());
  auto reshapedFilterTensorType =
      RankedTensorType::get({c, fh * fw}, filterType.getElementType());
  auto reshapedOutputTensorType =
      RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());

  Value reshapedImg2ColTensor = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
      img2ColTensorReassocIndices);
  Value reshapedFilterTensor =
      tensor::CollapseShapeOp::create(rewriter, loc, reshapedFilterTensorType,
                                      filterT, filterReassociationIndice);
  Value reshapedoutputTensor = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedOutputTensorType, transposedOutputTensor,
      outputReassociationIndice);

  // The depthwise conv is a mat-vec per (n, c) pair once packed.
  auto batchMatVecResult = linalg::BatchMatvecOp::create(
      rewriter, loc, TypeRange{reshapedoutputTensor.getType()},
      ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
      ValueRange{reshapedoutputTensor});

  SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
                                                                      {2, 3}};

  // Expand back to NCHW, ...
  auto batchMatVecResultReshaped = tensor::ExpandShapeOp::create(
      rewriter, loc, transposedOutputTensor.getType(),
      batchMatVecResult.getResult(0), batchMatVecReassociationIndice);

  // ... then transpose back to NHWC.
  Value transposedResult =
      transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});

  rewriter.replaceOp(convOp, ArrayRef<Value>{transposedResult});
  return std::make_pair(img2ColTensor.getOperation(),
                        transposedResult.getDefiningOp());
}
414
/// Rewrite a linalg.conv_2d_nchw_fchw as an im2col packing (linalg.generic)
/// followed by an M x K * (B x) K x N matmul-shaped linalg.generic, with
/// reshapes around them. Returns the (im2col op, final reshape op) pair on
/// success. Requires pure tensor semantics, static shapes, and unit
/// dilations.
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!convOp.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(
        convOp, "expected op to have pure tensor semantics");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  auto filterShape = filterType.getShape();
  auto outputShape = outputType.getShape();

  // Output is NCHW, filter is FCHW.
  int64_t n = outputShape[0];
  int64_t oc = outputShape[1];
  int64_t oh = outputShape[2];
  int64_t ow = outputShape[3];
  int64_t ic = filterShape[1];
  int64_t fh = filterShape[2];
  int64_t fw = filterShape[3];

  auto loc = convOp.getLoc();
  MLIRContext *context = rewriter.getContext();

  assert(isa<RankedTensorType>(filterType) &&
         "expected filter type to be a ranked tensor");
  auto tensorFilterType = cast<RankedTensorType>(filterType);

  // Filter: [oc, ic, fh, fw] -> [oc, ic*fh*fw] (M x K).
  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
  auto reshapedFilterType =
      RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType(),
                            tensorFilterType.getEncoding());
  Value reshapedFilter = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedFilterType, filter, filterReassocIndices);

  // Output: [n, oc, oh, ow] -> [n, oc, oh*ow] (B x M x N).
  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}};
  auto reshapedOutputType =
      RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
  Value reshapedOutput = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedOutputType, output, outputReassocIndices);

  // Convert the input to a (BKN) tensor.
  SmallVector<int64_t, 4> colTensorShape = {n, ic * fh * fw, oh * ow};
  Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
                                            inputType.getElementType());

  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);

  // Recover the original iteration indices from the problem/input sizes:
  // given an index of the im2col matrix, retrieve the corresponding indices of
  // the output and filter matrices. K = ic*fh*fw -> (ic, fh, fw);
  // N = oh*ow -> (oh, ow).
  auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(1U),
                                   ArrayRef<int64_t>{fh * fw, fw, 1});
  auto mIndicesExprs =
      delinearize(rewriter.getAffineDimExpr(2U), ArrayRef<int64_t>{ow, 1});
  Im2ColToOperandsExprs i2cToOperExprs;
  i2cToOperExprs.icIndex = kIndicesExprs[0];
  i2cToOperExprs.fhIndex = kIndicesExprs[1];
  i2cToOperExprs.fwIndex = kIndicesExprs[2];
  i2cToOperExprs.ohIndex = mIndicesExprs[0];
  i2cToOperExprs.owIndex = mIndicesExprs[1];
      i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
      rewriter);
  auto inMap =
                                     inExprs.hIndex, inExprs.wIndex}},
                                    rewriter.getContext())[0];
  // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
  SmallVector<AffineMap> img2colIndexingMaps = {
      inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};

  // The packing op is a pure gather: it just yields the loaded input element.
  auto img2ColTensor = linalg::GenericOp::create(
      rewriter, loc, colTensor.getType(),
      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
      });

  // Because the filter does not share the same batch dimension,
  // the batch dimension is only used in indexing the input and output. Thus
  // we cannot use existing linalg named ops like linalg.batch_matmul.
  // i.e. M x K * (B x) K x N = (B x) M x N
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};
  // Multiply-accumulate body; operands are converted to the accumulator
  // (output element) type before multiplying.
  auto genericOp = linalg::GenericOp::create(
      rewriter, loc, reshapedOutputType,
      /*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  // Expand [n, oc, oh*ow] back to the original NCHW output shape.
  auto reshapedResult = tensor::ExpandShapeOp::create(
      rewriter, loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}
548
/// Rewrite a linalg.conv_2d_nhwc_fhwc as an im2col packing (linalg.generic)
/// followed by a "row-wise" dot-product linalg.generic (the filter is not
/// transposed, so the contraction is not a plain batched matmul). Returns
/// the (im2col op, final reshape op) pair on success. Requires pure tensor
/// semantics, static shapes, and unit dilations.
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!convOp.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(
        convOp, "expected op to have pure tensor semantics");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  MLIRContext *context = rewriter.getContext();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  ArrayRef<int64_t> filterShape = filterType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  // Output is NHWC, filter is FHWC.
  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t oc = outputShape[3];
  int64_t fh = filterShape[1];
  int64_t fw = filterShape[2];
  int64_t ic = filterShape[3];

  Location loc = convOp.getLoc();

  assert(isa<RankedTensorType>(filterType) &&
         "expected filter type to be a ranked tensor");
  auto tensorFilterType = cast<RankedTensorType>(filterType);

  // Reshape output and filter to the LHS and result of a "row-wise" matrix
  // multiplication. Filter: [oc, fh, fw, ic] -> [oc, fh*fw*ic] (N x K).
  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
  auto reshapedFilterType =
      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType(),
                            tensorFilterType.getEncoding());
  Value reshapedFilter = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedFilterType, filter, filterReassocIndices);

  // Output: [n, oh, ow, oc] -> [n, oh*ow, oc] (B x M x N).
  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
  RankedTensorType reshapedOutputType =
      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
  Value reshapedOutput = tensor::CollapseShapeOp::create(
      rewriter, loc, reshapedOutputType, output, outputReassocIndices);

  // Shape of the Toeplitz matrix produced by Im2col.
  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
  Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
                                            inputType.getElementType());

  // Convert the input to a (BMK) column tensor.
  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

  // Given an index of the im2col matrix, retrieve the corresponding indices of
  // the output and filter matrices: M = oh*ow -> (oh, ow),
  // K = fh*fw*ic -> (fh, fw, ic).
  auto mIndicesExprs =
      delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1});
  auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U),
                                   ArrayRef<int64_t>{fw * ic, ic, 1});
  Im2ColToOperandsExprs i2cToOperExprs;
  i2cToOperExprs.fhIndex = kIndicesExprs[0];
  i2cToOperExprs.fwIndex = kIndicesExprs[1];
  i2cToOperExprs.icIndex = kIndicesExprs[2];
  i2cToOperExprs.ohIndex = mIndicesExprs[0];
  i2cToOperExprs.owIndex = mIndicesExprs[1];

  // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
      i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
      rewriter);
  auto inMap =
                                     inExprs.wIndex, inExprs.cIndex}},
                                    rewriter.getContext())[0];
  SmallVector<AffineMap> img2colIndexingMaps = {
      inMap, AffineMap::getMultiDimIdentityMap(nloops, context)};

  // The packing op is a pure gather: it just yields the loaded input element.
  auto img2ColTensor = linalg::GenericOp::create(
      rewriter, loc, colTensor.getType(),
      /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
      });

  // Because we didn't transpose the filters we don't actually have a batched
  // matrix multiply. Instead, we have an operation consisting of "row-wise" dot
  // products.
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};

  // Multiply-accumulate body; operands are converted to the accumulator
  // (output element) type before multiplying.
  auto genericOp = linalg::GenericOp::create(
      rewriter, loc, reshapedOutputType,
      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  // Expand [n, oh*ow, oc] back to the original NHWC output shape.
  auto reshapedResult = tensor::ExpandShapeOp::create(
      rewriter, loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}
685
686namespace {
687
/// Pattern adaptor: applies the NHWC/HWCF im2col rewrite, succeeding iff
/// the rewrite applied.
class ConvertConv2DNhwcHwcf final
    : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
public:

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(rewriteInIm2Col(rewriter, convOp)))
      return failure();
    return success();
  }
};
700
701class ConvertDepthwiseConv2DNhwcHwc final
702 : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
703public:
704 using OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp>::OpRewritePattern;
705
706 LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
707 PatternRewriter &rewriter) const override {
708 if (failed(rewriteInIm2Col(rewriter, convOp)))
709 return failure();
710 return success();
711 }
712};
713
/// Pattern adaptor: applies the NCHW/FCHW im2col rewrite, succeeding iff
/// the rewrite applied.
class ConvertConv2DNchwFchw final
    : public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
public:

  LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(rewriteInIm2Col(rewriter, convOp)))
      return failure();
    return success();
  }
};
726
/// Pattern adaptor: applies the NHWC/FHWC im2col rewrite, succeeding iff
/// the rewrite applied.
class ConvertConv2DNhwcFhwc final
    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(rewriteInIm2Col(rewriter, convOp)))
      return failure();
    return success();
  }
};
739} // end anonymous namespace
740
  // Register the four conv-to-img2col conversion patterns defined above.
  MLIRContext *context = patterns.getContext();
  patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
                  ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
}
746} // end namespace linalg
747} // end namespace mlir
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
#define mul(a, b)
#define add(a, b)
Base type for affine expression.
Definition AffineExpr.h:68
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getAffineConstantExpr(int64_t constant)
Definition Builders.cpp:372
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:364
MLIRContext * getContext() const
Definition Builders.h:56
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns)
Populates patterns to transform linalg.conv_2d_xxx operations into linalg.generic (for img2col packin...
static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder)
static Value createMul(Location loc, Value x, Value y, Type accType, OpBuilder &builder)
static Im2ColToInputDimsExprs getIm2ColInputExpressions(Im2ColToOperandsExprs exprs, ArrayRef< int64_t > strides, RewriterBase &rewriter)
Construct the affine expressions that map the indices of the im2col matrix to the corresponding input...
static bool hasAllOneValues(DenseIntElementsAttr attr)
static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride, bool useSymbols=true)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition Utils.cpp:238
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...