30       attr, [](
const APInt &element) { 
return element.getSExtValue() == 1; });
 
   34   if (isa<IntegerType>(x.
getType()))
 
   35     return arith::AddIOp::create(builder, loc, x, y);
 
   36   if (isa<ComplexType>(x.
getType()))
 
   37     return complex::AddOp::create(builder, loc, x, y);
 
   38   return arith::AddFOp::create(builder, loc, x, y);
 
   48   if (isa<ComplexType>(accType))
 
   49     return complex::MulOp::create(builder, loc, xConvert, yConvert);
 
   50   if (isa<IntegerType>(accType))
 
   51     return arith::MulIOp::create(builder, loc, xConvert, yConvert);
 
   52   return arith::MulFOp::create(builder, loc, xConvert, yConvert);
 
   59                                    bool useSymbols = 
true) {
 
  114   hIndexExpr = hIndexExpr.compose(hIndicesMap);
 
  117   wIndexExpr = wIndexExpr.compose(wIndicesMap);
 
  118   auto cIndexExpr = exprs.
icIndex;
 
  119   return {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr};
 
  122 FailureOr<std::pair<Operation *, Operation *>>
 
  124   auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
 
  125   auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
 
  126   auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
 
  128   if (!convOp.hasPureTensorSemantics())
 
  130         convOp, 
"expected op to have pure tensor semantics");
 
  132   if (!filterType.hasStaticShape())
 
  134         convOp, 
"expected a static shape for the filter");
 
  136   if (!inputType.hasStaticShape())
 
  138                                        "expected a static shape for the input");
 
  143                                        "expected all ones for dilations");
 
  146   Value input = convOp.getInputs()[0];
 
  147   Value filter = convOp.getInputs()[1];
 
  148   Value output = convOp.getOutputs()[0];
 
  153   int64_t n = outputShape[0];
 
  154   int64_t oh = outputShape[1];
 
  155   int64_t ow = outputShape[2];
 
  156   int64_t oc = outputShape[3];
 
  157   int64_t fh = filterShape[0];
 
  158   int64_t fw = filterShape[1];
 
  159   int64_t ic = filterShape[2];
 
  163   assert(isa<RankedTensorType>(filterType) &&
 
  164          "expected filter type to be a ranked tensor");
 
  165   auto tensorFilterType = cast<RankedTensorType>(filterType);
 
  169   auto reshapedFilterType =
 
  171                             tensorFilterType.getEncoding());
 
  172   Value reshapedFilter = tensor::CollapseShapeOp::create(
 
  173       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 
  176   RankedTensorType reshapedOutputType =
 
  178   Value reshapedOutput = tensor::CollapseShapeOp::create(
 
  179       rewriter, loc, reshapedOutputType, output, outputReassocIndices);
 
  182   Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
 
  183                                             inputType.getElementType());
 
  186   auto nloops = colTensorShape.size();
 
  188   auto parallel = utils::IteratorType::parallel;
 
  189   auto reduction = utils::IteratorType::reduction;
 
  199   i2cToOperExprs.
fhIndex = kIndicesExprs[0];
 
  200   i2cToOperExprs.
fwIndex = kIndicesExprs[1];
 
  201   i2cToOperExprs.
icIndex = kIndicesExprs[2];
 
  202   i2cToOperExprs.
ohIndex = mIndicesExprs[0];
 
  203   i2cToOperExprs.
owIndex = mIndicesExprs[1];
 
  207       i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
 
  217   auto img2ColTensor = linalg::GenericOp::create(
 
  218       rewriter, loc, colTensor.
getType(),
 
  219       input, colTensor, img2colIndexingMaps,
 
  222         linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
 
  230   bindDims(context, bDim, mDim, nDim, kDim);
 
  233   auto resultMap = 
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
 
  235                                                        parallel, reduction};
 
  237   auto genericOp = linalg::GenericOp::create(
 
  238       rewriter, loc, reshapedOutputType,
 
  239       ValueRange{img2ColTensor.getResult(0), reshapedFilter},
 
  246         linalg::YieldOp::create(nestedBuilder, nestedLoc, 
add);
 
  248   Value result = genericOp.getResults().front();
 
  250   auto reshapedResult = tensor::ExpandShapeOp::create(
 
  251       rewriter, loc, outputType, result, outputReassocIndices);
 
  255   return std::make_pair(img2ColTensor.getOperation(),
 
  256                         reshapedResult.getOperation());
 
  259 FailureOr<std::pair<Operation *, Operation *>>
 
  261                 linalg::DepthwiseConv2DNhwcHwcOp convOp) {
 
  262   auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
 
  263   auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
 
  264   auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());
 
  266   if (!convOp.hasPureTensorSemantics())
 
  268         convOp, 
"expected op to have pure tensor semantics");
 
  270   if (!filterType.hasStaticShape())
 
  272         convOp, 
"expected a static shape for the filter");
 
  274   if (!inputType.hasStaticShape())
 
  276                                        "expected a static shape for the input");
 
  281                                        "expected all ones for dilations");
 
  286     auto operandTensorType = cast<RankedTensorType>(operand.
getType());
 
  287     auto nloops = indices.size();
 
  291         llvm::map_range(indices, [&](int64_t index) -> 
AffineExpr {
 
  296         indices, [&](int64_t index) -> int64_t { 
return inputShape[index]; }));
 
  298     Value outputTensor = tensor::EmptyOp::create(
 
  299         rewriter, loc, targetShape, operandTensorType.getElementType());
 
  302         nloops, utils::IteratorType::parallel);
 
  309     auto transposedOp = linalg::GenericOp::create(
 
  310         rewriter, loc, outputTensor.
getType(),
 
  311         operand, outputTensor, indexingMaps,
 
  314           linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
 
  317     return transposedOp.getResult(0);
 
  320   Value input = convOp.getInputs()[0];
 
  321   Value filter = convOp.getInputs()[1];
 
  322   Value output = convOp.getOutputs()[0];
 
  325   Value inputT = transposeOperand(input, {0, 3, 1, 2});
 
  326   Value filterT = transposeOperand(filter, {2, 0, 1});
 
  328       cast<RankedTensorType>(filterT.getType()).getShape();
 
  331   int n = outputShape[0];
 
  332   int oh = outputShape[1];
 
  333   int ow = outputShape[2];
 
  334   int c = outputShape[3];
 
  335   int fh = filterTShape[1];
 
  336   int fw = filterTShape[2];
 
  339   Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
 
  341   AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
 
  345       convOp.getStrides().getValues<int64_t>()[0]);
 
  347       convOp.getStrides().getValues<int64_t>()[1]);
 
  350                                         owDim * swSym + kwDim};
 
  352   auto nloops = colTensorShape.size();
 
  355       nloops, utils::IteratorType::parallel);
 
  361   Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
 
  362                                             inputType.getElementType());
 
  364   auto img2ColTensor = linalg::GenericOp::create(
 
  365       rewriter, loc, colTensor.
getType(),
 
  366       inputT, colTensor, indexingMaps,
 
  369         linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
 
  373       {0, 1}, {2, 3}, {4, 5}};
 
  379       {n * c, oh * ow, fh * fw}, inputType.getElementType());
 
  380   auto reshapedFilterTensorType =
 
  382   auto reshapedOutputTensorType =
 
  385   Value reshapedImg2ColTensor = tensor::CollapseShapeOp::create(
 
  386       rewriter, loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
 
  387       img2ColTensorReassocIndices);
 
  388   Value reshapedFilterTensor =
 
  389       tensor::CollapseShapeOp::create(rewriter, loc, reshapedFilterTensorType,
 
  390                                       filterT, filterReassociationIndice);
 
  391   Value reshapedoutputTensor = tensor::CollapseShapeOp::create(
 
  392       rewriter, loc, reshapedOutputTensorType, transposedOutputTensor,
 
  393       outputReassociationIndice);
 
  395   auto batchMatVecResult = linalg::BatchMatvecOp::create(
 
  397       ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
 
  403   auto batchMatVecResultReshaped = tensor::ExpandShapeOp::create(
 
  404       rewriter, loc, transposedOutputTensor.
getType(),
 
  405       batchMatVecResult.getResult(0), batchMatVecReassociationIndice);
 
  407   Value transposedResult =
 
  408       transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
 
  411   return std::make_pair(img2ColTensor.getOperation(),
 
  415 FailureOr<std::pair<Operation *, Operation *>>
 
  417   auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
 
  418   auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
 
  419   auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
 
  421   if (!convOp.hasPureTensorSemantics())
 
  423         convOp, 
"expected op to have pure tensor semantics");
 
  425   if (!filterType.hasStaticShape())
 
  427         convOp, 
"expected a static shape for the filter");
 
  429   if (!inputType.hasStaticShape())
 
  431                                        "expected a static shape for the input");
 
  436                                        "expected all ones for dilations");
 
  438   Value input = convOp.getInputs()[0];
 
  439   Value filter = convOp.getInputs()[1];
 
  440   Value output = convOp.getOutputs()[0];
 
  442   auto filterShape = filterType.getShape();
 
  443   auto outputShape = outputType.getShape();
 
  445   int64_t n = outputShape[0];
 
  446   int64_t oc = outputShape[1];
 
  447   int64_t oh = outputShape[2];
 
  448   int64_t ow = outputShape[3];
 
  449   int64_t ic = filterShape[1];
 
  450   int64_t fh = filterShape[2];
 
  451   int64_t fw = filterShape[3];
 
  453   auto loc = convOp.getLoc();
 
  456   assert(isa<RankedTensorType>(filterType) &&
 
  457          "expected filter type to be a ranked tensor");
 
  458   auto tensorFilterType = cast<RankedTensorType>(filterType);
 
  461   auto reshapedFilterType =
 
  463                             tensorFilterType.getEncoding());
 
  464   Value reshapedFilter = tensor::CollapseShapeOp::create(
 
  465       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 
  468   auto reshapedOutputType =
 
  470   Value reshapedOutput = tensor::CollapseShapeOp::create(
 
  471       rewriter, loc, reshapedOutputType, output, outputReassocIndices);
 
  475   Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
 
  476                                             inputType.getElementType());
 
  478   auto nloops = colTensorShape.size();
 
  480   auto parallel = utils::IteratorType::parallel;
 
  481   auto reduction = utils::IteratorType::reduction;
 
  492   i2cToOperExprs.
icIndex = kIndicesExprs[0];
 
  493   i2cToOperExprs.
fhIndex = kIndicesExprs[1];
 
  494   i2cToOperExprs.
fwIndex = kIndicesExprs[2];
 
  495   i2cToOperExprs.
ohIndex = mIndicesExprs[0];
 
  496   i2cToOperExprs.
owIndex = mIndicesExprs[1];
 
  498       i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
 
  508   auto img2ColTensor = linalg::GenericOp::create(
 
  509       rewriter, loc, colTensor.
getType(),
 
  510       input, colTensor, img2colIndexingMaps,
 
  513         linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
 
  521   bindDims(context, bDim, mDim, nDim, kDim);
 
  524   auto resultMap = 
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
 
  526                                                        parallel, reduction};
 
  527   auto genericOp = linalg::GenericOp::create(
 
  528       rewriter, loc, reshapedOutputType,
 
  529       ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
 
  536         linalg::YieldOp::create(nestedBuilder, nestedLoc, 
add);
 
  538   Value result = genericOp.getResults().front();
 
  540   auto reshapedResult = tensor::ExpandShapeOp::create(
 
  541       rewriter, loc, outputType, result, outputReassocIndices);
 
  545   return std::make_pair(img2ColTensor.getOperation(),
 
  546                         reshapedResult.getOperation());
 
  549 FailureOr<std::pair<Operation *, Operation *>>
 
  551   auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
 
  552   auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
 
  553   auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
 
  555   if (!convOp.hasPureTensorSemantics())
 
  557         convOp, 
"expected op to have pure tensor semantics");
 
  559   if (!filterType.hasStaticShape())
 
  561         convOp, 
"expected a static shape for the filter");
 
  563   if (!inputType.hasStaticShape())
 
  565                                        "expected a static shape for the input");
 
  570                                        "expected all ones for dilations");
 
  573   Value input = convOp.getInputs()[0];
 
  574   Value filter = convOp.getInputs()[1];
 
  575   Value output = convOp.getOutputs()[0];
 
  580   int64_t n = outputShape[0];
 
  581   int64_t oh = outputShape[1];
 
  582   int64_t ow = outputShape[2];
 
  583   int64_t oc = outputShape[3];
 
  584   int64_t fh = filterShape[1];
 
  585   int64_t fw = filterShape[2];
 
  586   int64_t ic = filterShape[3];
 
  590   assert(isa<RankedTensorType>(filterType) &&
 
  591          "expected filter type to be a ranked tensor");
 
  592   auto tensorFilterType = cast<RankedTensorType>(filterType);
 
  597   auto reshapedFilterType =
 
  599                             tensorFilterType.getEncoding());
 
  600   Value reshapedFilter = tensor::CollapseShapeOp::create(
 
  601       rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
 
  604   RankedTensorType reshapedOutputType =
 
  606   Value reshapedOutput = tensor::CollapseShapeOp::create(
 
  607       rewriter, loc, reshapedOutputType, output, outputReassocIndices);
 
  611   Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
 
  612                                             inputType.getElementType());
 
  615   auto nloops = colTensorShape.size();
 
  617   auto parallel = utils::IteratorType::parallel;
 
  618   auto reduction = utils::IteratorType::reduction;
 
  628   i2cToOperExprs.
fhIndex = kIndicesExprs[0];
 
  629   i2cToOperExprs.
fwIndex = kIndicesExprs[1];
 
  630   i2cToOperExprs.
icIndex = kIndicesExprs[2];
 
  631   i2cToOperExprs.
ohIndex = mIndicesExprs[0];
 
  632   i2cToOperExprs.
owIndex = mIndicesExprs[1];
 
  636       i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
 
  645   auto img2ColTensor = linalg::GenericOp::create(
 
  646       rewriter, loc, colTensor.
getType(),
 
  647       input, colTensor, img2colIndexingMaps,
 
  650         linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
 
  657   bindDims(context, bDim, mDim, nDim, kDim);
 
  660   auto resultMap = 
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
 
  662                                                        parallel, reduction};
 
  664   auto genericOp = linalg::GenericOp::create(
 
  665       rewriter, loc, reshapedOutputType,
 
  666       ValueRange{img2ColTensor.getResult(0), reshapedFilter},
 
  673         linalg::YieldOp::create(nestedBuilder, nestedLoc, 
add);
 
  675   Value result = genericOp.getResults().front();
 
  677   auto reshapedResult = tensor::ExpandShapeOp::create(
 
  678       rewriter, loc, outputType, result, outputReassocIndices);
 
  682   return std::make_pair(img2ColTensor.getOperation(),
 
  683                         reshapedResult.getOperation());
 
  688 class ConvertConv2DNhwcHwcf final
 
  693   LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
 
  701 class ConvertDepthwiseConv2DNhwcHwc final
 
  702     : 
public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
 
  706   LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
 
  707                                 PatternRewriter &rewriter)
 const override {
 
  714 class ConvertConv2DNchwFchw final
 
  715     : 
public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
 
  719   LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
 
  720                                 PatternRewriter &rewriter)
 const override {
 
  727 class ConvertConv2DNhwcFhwc final
 
  728     : 
public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
 
  732   LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
 
  733                                 PatternRewriter &rewriter)
 const override {
 
  743   patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
 
  744                   ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
 
Base type for affine expression.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getAffineConstantExpr(int64_t constant)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns)
Populates patterns to transform linalg.conv_2d_xxx operations into linalg.generic (for img2col packin...
static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder)
static Value createMul(Location loc, Value x, Value y, Type accType, OpBuilder &builder)
static Im2ColToInputDimsExprs getIm2ColInputExpressions(Im2ColToOperandsExprs exprs, ArrayRef< int64_t > strides, RewriterBase &rewriter)
Construct the affine expressions that map the indices of the im2col matrix to the corresponding input...
static bool hasAllOneValues(DenseIntElementsAttr attr)
static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride, bool useSymbols=true)
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...