30 attr, [](
const APInt &element) {
return element.getSExtValue() == 1; });
34 if (isa<IntegerType>(x.
getType()))
35 return arith::AddIOp::create(builder, loc, x, y);
36 if (isa<ComplexType>(x.
getType()))
37 return complex::AddOp::create(builder, loc, x, y);
38 return arith::AddFOp::create(builder, loc, x, y);
48 if (isa<ComplexType>(accType))
49 return complex::MulOp::create(builder, loc, xConvert, yConvert);
50 if (isa<IntegerType>(accType))
51 return arith::MulIOp::create(builder, loc, xConvert, yConvert);
52 return arith::MulFOp::create(builder, loc, xConvert, yConvert);
59 bool useSymbols =
true) {
114 hIndexExpr = hIndexExpr.compose(hIndicesMap);
117 wIndexExpr = wIndexExpr.compose(wIndicesMap);
118 auto cIndexExpr = exprs.
icIndex;
119 return {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr};
122 FailureOr<std::pair<Operation *, Operation *>>
124 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
125 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
126 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
128 if (!convOp.hasPureTensorSemantics())
130 convOp,
"expected op to have pure tensor semantics");
132 if (!filterType.hasStaticShape())
134 convOp,
"expected a static shape for the filter");
136 if (!inputType.hasStaticShape())
138 "expected a static shape for the input");
143 "expected all ones for dilations");
146 Value input = convOp.getInputs()[0];
147 Value filter = convOp.getInputs()[1];
148 Value output = convOp.getOutputs()[0];
153 int64_t n = outputShape[0];
154 int64_t oh = outputShape[1];
155 int64_t ow = outputShape[2];
156 int64_t oc = outputShape[3];
157 int64_t fh = filterShape[0];
158 int64_t fw = filterShape[1];
159 int64_t ic = filterShape[2];
163 assert(isa<RankedTensorType>(filterType) &&
164 "expected filter type to be a ranked tensor");
165 auto tensorFilterType = cast<RankedTensorType>(filterType);
169 auto reshapedFilterType =
171 tensorFilterType.getEncoding());
172 Value reshapedFilter = tensor::CollapseShapeOp::create(
173 rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
176 RankedTensorType reshapedOutputType =
178 Value reshapedOutput = tensor::CollapseShapeOp::create(
179 rewriter, loc, reshapedOutputType, output, outputReassocIndices);
182 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
183 inputType.getElementType());
186 auto nloops = colTensorShape.size();
188 auto parallel = utils::IteratorType::parallel;
189 auto reduction = utils::IteratorType::reduction;
199 i2cToOperExprs.
fhIndex = kIndicesExprs[0];
200 i2cToOperExprs.
fwIndex = kIndicesExprs[1];
201 i2cToOperExprs.
icIndex = kIndicesExprs[2];
202 i2cToOperExprs.
ohIndex = mIndicesExprs[0];
203 i2cToOperExprs.
owIndex = mIndicesExprs[1];
207 i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
217 auto img2ColTensor = linalg::GenericOp::create(
218 rewriter, loc, colTensor.
getType(),
219 input, colTensor, img2colIndexingMaps,
222 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
230 bindDims(context, bDim, mDim, nDim, kDim);
233 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
235 parallel, reduction};
237 auto genericOp = linalg::GenericOp::create(
238 rewriter, loc, reshapedOutputType,
239 ValueRange{img2ColTensor.getResult(0), reshapedFilter},
246 linalg::YieldOp::create(nestedBuilder, nestedLoc,
add);
248 Value result = genericOp.getResults().front();
250 auto reshapedResult = tensor::ExpandShapeOp::create(
251 rewriter, loc, outputType, result, outputReassocIndices);
255 return std::make_pair(img2ColTensor.getOperation(),
256 reshapedResult.getOperation());
259 FailureOr<std::pair<Operation *, Operation *>>
261 linalg::DepthwiseConv2DNhwcHwcOp convOp) {
262 auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
263 auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
264 auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());
266 if (!convOp.hasPureTensorSemantics())
268 convOp,
"expected op to have pure tensor semantics");
270 if (!filterType.hasStaticShape())
272 convOp,
"expected a static shape for the filter");
274 if (!inputType.hasStaticShape())
276 "expected a static shape for the input");
281 "expected all ones for dilations");
286 auto operandTensorType = cast<RankedTensorType>(operand.
getType());
287 auto nloops = indices.size();
291 llvm::map_range(indices, [&](int64_t index) ->
AffineExpr {
296 indices, [&](int64_t index) -> int64_t {
return inputShape[index]; }));
298 Value outputTensor = tensor::EmptyOp::create(
299 rewriter, loc, targetShape, operandTensorType.getElementType());
302 nloops, utils::IteratorType::parallel);
309 auto transposedOp = linalg::GenericOp::create(
310 rewriter, loc, outputTensor.
getType(),
311 operand, outputTensor, indexingMaps,
314 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
317 return transposedOp.getResult(0);
320 Value input = convOp.getInputs()[0];
321 Value filter = convOp.getInputs()[1];
322 Value output = convOp.getOutputs()[0];
325 Value inputT = transposeOperand(input, {0, 3, 1, 2});
326 Value filterT = transposeOperand(filter, {2, 0, 1});
328 cast<RankedTensorType>(filterT.getType()).getShape();
331 int n = outputShape[0];
332 int oh = outputShape[1];
333 int ow = outputShape[2];
334 int c = outputShape[3];
335 int fh = filterTShape[1];
336 int fw = filterTShape[2];
339 Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
341 AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
345 convOp.getStrides().getValues<int64_t>()[0]);
347 convOp.getStrides().getValues<int64_t>()[1]);
350 owDim * swSym + kwDim};
352 auto nloops = colTensorShape.size();
355 nloops, utils::IteratorType::parallel);
361 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
362 inputType.getElementType());
364 auto img2ColTensor = linalg::GenericOp::create(
365 rewriter, loc, colTensor.
getType(),
366 inputT, colTensor, indexingMaps,
369 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
373 {0, 1}, {2, 3}, {4, 5}};
379 {n * c, oh * ow, fh * fw}, inputType.getElementType());
380 auto reshapedFilterTensorType =
382 auto reshapedOutputTensorType =
385 Value reshapedImg2ColTensor = tensor::CollapseShapeOp::create(
386 rewriter, loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
387 img2ColTensorReassocIndices);
388 Value reshapedFilterTensor =
389 tensor::CollapseShapeOp::create(rewriter, loc, reshapedFilterTensorType,
390 filterT, filterReassociationIndice);
391 Value reshapedoutputTensor = tensor::CollapseShapeOp::create(
392 rewriter, loc, reshapedOutputTensorType, transposedOutputTensor,
393 outputReassociationIndice);
395 auto batchMatVecResult = linalg::BatchMatvecOp::create(
397 ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
403 auto batchMatVecResultReshaped = tensor::ExpandShapeOp::create(
404 rewriter, loc, transposedOutputTensor.
getType(),
405 batchMatVecResult.getResult(0), batchMatVecReassociationIndice);
407 Value transposedResult =
408 transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
411 return std::make_pair(img2ColTensor.getOperation(),
415 FailureOr<std::pair<Operation *, Operation *>>
417 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
418 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
419 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
421 if (!convOp.hasPureTensorSemantics())
423 convOp,
"expected op to have pure tensor semantics");
425 if (!filterType.hasStaticShape())
427 convOp,
"expected a static shape for the filter");
429 if (!inputType.hasStaticShape())
431 "expected a static shape for the input");
436 "expected all ones for dilations");
438 Value input = convOp.getInputs()[0];
439 Value filter = convOp.getInputs()[1];
440 Value output = convOp.getOutputs()[0];
442 auto filterShape = filterType.getShape();
443 auto outputShape = outputType.getShape();
445 int64_t n = outputShape[0];
446 int64_t oc = outputShape[1];
447 int64_t oh = outputShape[2];
448 int64_t ow = outputShape[3];
449 int64_t ic = filterShape[1];
450 int64_t fh = filterShape[2];
451 int64_t fw = filterShape[3];
453 auto loc = convOp.getLoc();
456 assert(isa<RankedTensorType>(filterType) &&
457 "expected filter type to be a ranked tensor");
458 auto tensorFilterType = cast<RankedTensorType>(filterType);
461 auto reshapedFilterType =
463 tensorFilterType.getEncoding());
464 Value reshapedFilter = tensor::CollapseShapeOp::create(
465 rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
468 auto reshapedOutputType =
470 Value reshapedOutput = tensor::CollapseShapeOp::create(
471 rewriter, loc, reshapedOutputType, output, outputReassocIndices);
475 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
476 inputType.getElementType());
478 auto nloops = colTensorShape.size();
480 auto parallel = utils::IteratorType::parallel;
481 auto reduction = utils::IteratorType::reduction;
492 i2cToOperExprs.
icIndex = kIndicesExprs[0];
493 i2cToOperExprs.
fhIndex = kIndicesExprs[1];
494 i2cToOperExprs.
fwIndex = kIndicesExprs[2];
495 i2cToOperExprs.
ohIndex = mIndicesExprs[0];
496 i2cToOperExprs.
owIndex = mIndicesExprs[1];
498 i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
508 auto img2ColTensor = linalg::GenericOp::create(
509 rewriter, loc, colTensor.
getType(),
510 input, colTensor, img2colIndexingMaps,
513 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
521 bindDims(context, bDim, mDim, nDim, kDim);
524 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
526 parallel, reduction};
527 auto genericOp = linalg::GenericOp::create(
528 rewriter, loc, reshapedOutputType,
529 ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
536 linalg::YieldOp::create(nestedBuilder, nestedLoc,
add);
538 Value result = genericOp.getResults().front();
540 auto reshapedResult = tensor::ExpandShapeOp::create(
541 rewriter, loc, outputType, result, outputReassocIndices);
545 return std::make_pair(img2ColTensor.getOperation(),
546 reshapedResult.getOperation());
549 FailureOr<std::pair<Operation *, Operation *>>
551 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
552 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
553 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
555 if (!convOp.hasPureTensorSemantics())
557 convOp,
"expected op to have pure tensor semantics");
559 if (!filterType.hasStaticShape())
561 convOp,
"expected a static shape for the filter");
563 if (!inputType.hasStaticShape())
565 "expected a static shape for the input");
570 "expected all ones for dilations");
573 Value input = convOp.getInputs()[0];
574 Value filter = convOp.getInputs()[1];
575 Value output = convOp.getOutputs()[0];
580 int64_t n = outputShape[0];
581 int64_t oh = outputShape[1];
582 int64_t ow = outputShape[2];
583 int64_t oc = outputShape[3];
584 int64_t fh = filterShape[1];
585 int64_t fw = filterShape[2];
586 int64_t ic = filterShape[3];
590 assert(isa<RankedTensorType>(filterType) &&
591 "expected filter type to be a ranked tensor");
592 auto tensorFilterType = cast<RankedTensorType>(filterType);
597 auto reshapedFilterType =
599 tensorFilterType.getEncoding());
600 Value reshapedFilter = tensor::CollapseShapeOp::create(
601 rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
604 RankedTensorType reshapedOutputType =
606 Value reshapedOutput = tensor::CollapseShapeOp::create(
607 rewriter, loc, reshapedOutputType, output, outputReassocIndices);
611 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
612 inputType.getElementType());
615 auto nloops = colTensorShape.size();
617 auto parallel = utils::IteratorType::parallel;
618 auto reduction = utils::IteratorType::reduction;
628 i2cToOperExprs.
fhIndex = kIndicesExprs[0];
629 i2cToOperExprs.
fwIndex = kIndicesExprs[1];
630 i2cToOperExprs.
icIndex = kIndicesExprs[2];
631 i2cToOperExprs.
ohIndex = mIndicesExprs[0];
632 i2cToOperExprs.
owIndex = mIndicesExprs[1];
636 i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
645 auto img2ColTensor = linalg::GenericOp::create(
646 rewriter, loc, colTensor.
getType(),
647 input, colTensor, img2colIndexingMaps,
650 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
657 bindDims(context, bDim, mDim, nDim, kDim);
660 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
662 parallel, reduction};
664 auto genericOp = linalg::GenericOp::create(
665 rewriter, loc, reshapedOutputType,
666 ValueRange{img2ColTensor.getResult(0), reshapedFilter},
673 linalg::YieldOp::create(nestedBuilder, nestedLoc,
add);
675 Value result = genericOp.getResults().front();
677 auto reshapedResult = tensor::ExpandShapeOp::create(
678 rewriter, loc, outputType, result, outputReassocIndices);
682 return std::make_pair(img2ColTensor.getOperation(),
683 reshapedResult.getOperation());
688 class ConvertConv2DNhwcHwcf final
693 LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
701 class ConvertDepthwiseConv2DNhwcHwc final
702 :
public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
706 LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
707 PatternRewriter &rewriter)
const override {
714 class ConvertConv2DNchwFchw final
715 :
public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
719 LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
720 PatternRewriter &rewriter)
const override {
727 class ConvertConv2DNhwcFhwc final
728 :
public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
732 LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
733 PatternRewriter &rewriter)
const override {
743 patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
744 ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
Base type for affine expression.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the largest dim in exprs and as many symbols as the largest symbol in exprs.
AffineExpr getAffineConstantExpr(int64_t constant)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns)
Populates patterns to transform linalg.conv_2d_xxx operations into linalg.generic (for img2col packin...
static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder)
static Value createMul(Location loc, Value x, Value y, Type accType, OpBuilder &builder)
static Im2ColToInputDimsExprs getIm2ColInputExpressions(Im2ColToOperandsExprs exprs, ArrayRef< int64_t > strides, RewriterBase &rewriter)
Construct the affine expressions that map the indices of the im2col matrix to the corresponding input...
static bool hasAllOneValues(DenseIntElementsAttr attr)
static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride, bool useSymbols=true)
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...