28 attr, [](
const APInt &element) {
return element.getSExtValue() == 1; });
32 if (isa<IntegerType>(x.
getType()))
33 return arith::AddIOp::create(builder, loc, x, y);
34 if (isa<ComplexType>(x.
getType()))
35 return complex::AddOp::create(builder, loc, x, y);
36 return arith::AddFOp::create(builder, loc, x, y);
46 if (isa<ComplexType>(accType))
47 return complex::MulOp::create(builder, loc, xConvert, yConvert);
48 if (isa<IntegerType>(accType))
49 return arith::MulIOp::create(builder, loc, xConvert, yConvert);
50 return arith::MulFOp::create(builder, loc, xConvert, yConvert);
56 assert(!factors.empty() &&
"empty factor list");
58 for (int64_t f : factors)
59 basis.push_back(arith::ConstantOp::create(b, loc, b.
getIndexAttr(f)));
60 FailureOr<SmallVector<Value>> multiIndex =
62 assert(!failed(multiIndex) &&
"Failed to linearize img2col index");
70 Value fIndex, int64_t stride) {
77 FailureOr<std::pair<Operation *, Operation *>>
79 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
80 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
81 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
83 if (!filterType.hasStaticShape())
85 convOp,
"expected a static shape for the filter");
87 if (!inputType.hasStaticShape())
89 "expected a static shape for the input");
94 "expected all ones for dilations");
97 Value input = convOp.getInputs()[0];
98 Value filter = convOp.getInputs()[1];
99 Value output = convOp.getOutputs()[0];
104 int64_t n = outputShape[0];
105 int64_t oh = outputShape[1];
106 int64_t ow = outputShape[2];
107 int64_t oc = outputShape[3];
108 int64_t fh = filterShape[0];
109 int64_t fw = filterShape[1];
110 int64_t ic = filterShape[2];
116 auto reshapedFilterType =
118 Value reshapedFilter = tensor::CollapseShapeOp::create(
119 rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
122 RankedTensorType reshapedOutputType =
124 Value reshapedOutput = tensor::CollapseShapeOp::create(
125 rewriter, loc, reshapedOutputType, output, outputReassocIndices);
128 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
129 inputType.getElementType());
132 auto nloops = colTensorShape.size();
134 auto parallel = utils::IteratorType::parallel;
135 auto reduction = utils::IteratorType::reduction;
141 auto img2ColTensor = linalg::GenericOp::create(
142 rewriter, loc, colTensor.
getType(),
147 Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
148 Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
149 Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
152 SmallVector<Value> mIndices = unrollIndex(
153 nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
154 auto ohIndex = mIndices[0];
155 auto owIndex = mIndices[1];
159 auto fhIndex = kIndices[0];
160 auto fwIndex = kIndices[1];
161 auto icIndex = kIndices[2];
166 convOp.getStrides().getValues<int64_t>()[0]);
169 convOp.getStrides().getValues<int64_t>()[1]);
173 Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
175 linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
183 bindDims(context, bDim, mDim, nDim, kDim);
186 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
187 SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
188 parallel, reduction};
190 auto genericOp = linalg::GenericOp::create(
191 rewriter, loc, reshapedOutputType,
192 ValueRange{img2ColTensor.getResult(0), reshapedFilter},
193 ValueRange{reshapedOutput},
194 ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
195 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
199 linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
201 Value result = genericOp.getResults().front();
203 auto reshapedResult = tensor::ExpandShapeOp::create(
204 rewriter, loc, outputType, result, outputReassocIndices);
206 rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
208 return std::make_pair(img2ColTensor.getOperation(),
209 reshapedResult.getOperation());
212 FailureOr<std::pair<Operation *, Operation *>>
214 linalg::DepthwiseConv2DNhwcHwcOp convOp) {
215 auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
216 auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
217 auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());
219 if (!filterType.hasStaticShape())
221 convOp,
"expected a static shape for the filter");
223 if (!inputType.hasStaticShape())
225 "expected a static shape for the input");
230 "expected all ones for dilations");
235 auto operandTensorType = cast<RankedTensorType>(operand.
getType());
236 auto nloops = indices.size();
240 llvm::map_range(indices, [&](int64_t index) ->
AffineExpr {
245 indices, [&](int64_t index) -> int64_t {
return inputShape[index]; }));
247 Value outputTensor = tensor::EmptyOp::create(
248 rewriter, loc, targetShape, operandTensorType.getElementType());
251 nloops, utils::IteratorType::parallel);
258 auto transposedOp = linalg::GenericOp::create(
259 rewriter, loc, outputTensor.
getType(),
260 operand, outputTensor, indexingMaps,
263 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
266 return transposedOp.getResult(0);
269 Value input = convOp.getInputs()[0];
270 Value filter = convOp.getInputs()[1];
271 Value output = convOp.getOutputs()[0];
274 Value inputT = transposeOperand(input, {0, 3, 1, 2});
275 Value filterT = transposeOperand(filter, {2, 0, 1});
277 cast<RankedTensorType>(filterT.getType()).getShape();
280 int n = outputShape[0];
281 int oh = outputShape[1];
282 int ow = outputShape[2];
283 int c = outputShape[3];
284 int fh = filterTShape[1];
285 int fw = filterTShape[2];
288 Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
290 AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
294 convOp.getStrides().getValues<int64_t>()[0]);
296 convOp.getStrides().getValues<int64_t>()[1]);
299 owDim * swSym + kwDim};
301 auto nloops = colTensorShape.size();
304 nloops, utils::IteratorType::parallel);
310 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
311 inputType.getElementType());
313 auto img2ColTensor = linalg::GenericOp::create(
314 rewriter, loc, colTensor.
getType(),
315 inputT, colTensor, indexingMaps,
318 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
322 {0, 1}, {2, 3}, {4, 5}};
328 {n * c, oh * ow, fh * fw}, inputType.getElementType());
329 auto reshapedFilterTensorType =
331 auto reshapedOutputTensorType =
334 Value reshapedImg2ColTensor = tensor::CollapseShapeOp::create(
335 rewriter, loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
336 img2ColTensorReassocIndices);
337 Value reshapedFilterTensor =
338 tensor::CollapseShapeOp::create(rewriter, loc, reshapedFilterTensorType,
339 filterT, filterReassociationIndice);
340 Value reshapedoutputTensor = tensor::CollapseShapeOp::create(
341 rewriter, loc, reshapedOutputTensorType, transposedOutputTensor,
342 outputReassociationIndice);
344 auto batchMatVecResult = linalg::BatchMatvecOp::create(
346 ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
352 auto batchMatVecResultReshaped = tensor::ExpandShapeOp::create(
353 rewriter, loc, transposedOutputTensor.
getType(),
354 batchMatVecResult.getResult(0), batchMatVecReassociationIndice);
356 Value transposedResult =
357 transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
360 return std::make_pair(img2ColTensor.getOperation(),
364 FailureOr<std::pair<Operation *, Operation *>>
366 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
367 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
368 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
370 if (!filterType.hasStaticShape())
372 convOp,
"expected a static shape for the filter");
374 if (!inputType.hasStaticShape())
376 "expected a static shape for the input");
381 "expected all ones for dilations");
383 Value input = convOp.getInputs()[0];
384 Value filter = convOp.getInputs()[1];
385 Value output = convOp.getOutputs()[0];
387 auto filterShape = filterType.getShape();
388 auto outputShape = outputType.getShape();
390 int64_t n = outputShape[0];
391 int64_t oc = outputShape[1];
392 int64_t oh = outputShape[2];
393 int64_t ow = outputShape[3];
394 int64_t ic = filterShape[1];
395 int64_t fh = filterShape[2];
396 int64_t fw = filterShape[3];
398 auto loc = convOp.
getLoc();
402 auto reshapedFilterType =
404 Value reshapedFilter = tensor::CollapseShapeOp::create(
405 rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
408 auto reshapedOutputType =
410 Value reshapedOutput = tensor::CollapseShapeOp::create(
411 rewriter, loc, reshapedOutputType, output, outputReassocIndices);
415 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
416 inputType.getElementType());
418 auto nloops = colTensorShape.size();
420 auto parallel = utils::IteratorType::parallel;
421 auto reduction = utils::IteratorType::reduction;
427 auto img2ColTensor = linalg::GenericOp::create(
428 rewriter, loc, colTensor.
getType(),
433 Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
434 Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
435 Value nIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
438 SmallVector<Value> kIndices = unrollIndex(
439 nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
440 auto icIndex = kIndices[0];
441 auto fhIndex = kIndices[1];
442 auto fwIndex = kIndices[2];
446 auto ohIndex = nIndices[0];
447 auto owIndex = nIndices[1];
452 convOp.getStrides().getValues<int64_t>()[0]);
455 convOp.getStrides().getValues<int64_t>()[1]);
459 Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
461 linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
469 bindDims(context, bDim, mDim, nDim, kDim);
472 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
473 SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
474 parallel, reduction};
475 auto genericOp = linalg::GenericOp::create(
476 rewriter, loc, reshapedOutputType,
477 ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
478 ValueRange{reshapedOutput},
479 ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
480 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
484 linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
486 Value result = genericOp.getResults().front();
488 auto reshapedResult = tensor::ExpandShapeOp::create(
489 rewriter, loc, outputType, result, outputReassocIndices);
491 rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
493 return std::make_pair(img2ColTensor.getOperation(),
494 reshapedResult.getOperation());
497 FailureOr<std::pair<Operation *, Operation *>>
499 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
500 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
501 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
503 if (!filterType.hasStaticShape())
505 convOp,
"expected a static shape for the filter");
507 if (!inputType.hasStaticShape())
509 "expected a static shape for the input");
514 "expected all ones for dilations");
517 Value input = convOp.getInputs()[0];
518 Value filter = convOp.getInputs()[1];
519 Value output = convOp.getOutputs()[0];
524 int64_t n = outputShape[0];
525 int64_t oh = outputShape[1];
526 int64_t ow = outputShape[2];
527 int64_t oc = outputShape[3];
528 int64_t fh = filterShape[1];
529 int64_t fw = filterShape[2];
530 int64_t ic = filterShape[3];
537 auto reshapedFilterType =
539 Value reshapedFilter = tensor::CollapseShapeOp::create(
540 rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
543 RankedTensorType reshapedOutputType =
545 Value reshapedOutput = tensor::CollapseShapeOp::create(
546 rewriter, loc, reshapedOutputType, output, outputReassocIndices);
549 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
550 inputType.getElementType());
553 auto nloops = colTensorShape.size();
555 auto parallel = utils::IteratorType::parallel;
556 auto reduction = utils::IteratorType::reduction;
560 AffineMap::getMultiDimIdentityMap(nloops, context)};
562 auto img2ColTensor = linalg::GenericOp::create(
563 rewriter, loc, colTensor.
getType(),
568 Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0);
569 Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1);
570 Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2);
573 SmallVector<Value> mIndices = unrollIndex(
574 nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
575 auto ohIndex = mIndices[0];
576 auto owIndex = mIndices[1];
580 auto fhIndex = kIndices[0];
581 auto fwIndex = kIndices[1];
582 auto icIndex = kIndices[2];
587 convOp.getStrides().getValues<int64_t>()[0]);
590 convOp.getStrides().getValues<int64_t>()[1]);
594 Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input,
596 linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal);
603 bindDims(context, bDim, mDim, nDim, kDim);
606 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
607 SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
608 parallel, reduction};
610 auto genericOp = linalg::GenericOp::create(
611 rewriter, loc, reshapedOutputType,
612 ValueRange{img2ColTensor.getResult(0), reshapedFilter},
613 ValueRange{reshapedOutput},
614 ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
615 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
619 linalg::YieldOp::create(nestedBuilder, nestedLoc, add);
621 Value result = genericOp.getResults().front();
623 auto reshapedResult = tensor::ExpandShapeOp::create(
624 rewriter, loc, outputType, result, outputReassocIndices);
626 rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
628 return std::make_pair(img2ColTensor.getOperation(),
629 reshapedResult.getOperation());
634 class ConvertConv2DNhwcHwcf final
635 :
public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
637 using OpRewritePattern::OpRewritePattern;
639 LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
640 PatternRewriter &rewriter)
const override {
647 class ConvertDepthwiseConv2DNhwcHwc final
648 :
public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
650 using OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp>::OpRewritePattern;
652 LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
653 PatternRewriter &rewriter)
const override {
660 class ConvertConv2DNchwFchw final
661 :
public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
663 using OpRewritePattern::OpRewritePattern;
665 LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
666 PatternRewriter &rewriter)
const override {
673 class ConvertConv2DNhwcFhwc final
674 :
public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
676 using OpRewritePattern::OpRewritePattern;
678 LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
679 PatternRewriter &rewriter)
const override {
689 patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
690 ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
static Value createMul(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates a MulIOp if isInt is true; otherwise creates a MulFOp using operands x and y.
static Value createAdd(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter)
Creates an AddIOp if isInt is true; otherwise creates an arith::AddFOp using operands x and y.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineConstantExpr(int64_t constant)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
FailureOr< SmallVector< Value > > delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef< Value > basis, bool hasOuterBound=true)
Generate the IR to delinearize linearIndex given the basis and return the multi-index.
void bindDims(MLIRContext *ctx)
static SmallVector< Value > unrollIndex(OpBuilder &b, Location loc, Value index, ArrayRef< int64_t > factors)
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns)
Populates patterns to transform linalg.conv_2d_xxx operations into linalg.generic (for img2col packin...
static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder)
static Value createMul(Location loc, Value x, Value y, Type accType, OpBuilder &builder)
static bool hasAllOneValues(DenseIntElementsAttr attr)
static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex, Value fIndex, int64_t stride)
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...