29 attr, [](
const APInt &element) {
return element.getSExtValue() == 1; });
33 if (isa<IntegerType>(x.
getType()))
34 return arith::AddIOp::create(builder, loc, x, y);
35 if (isa<ComplexType>(x.
getType()))
36 return complex::AddOp::create(builder, loc, x, y);
37 return arith::AddFOp::create(builder, loc, x, y);
47 if (isa<ComplexType>(accType))
48 return complex::MulOp::create(builder, loc, xConvert, yConvert);
49 if (isa<IntegerType>(accType))
50 return arith::MulIOp::create(builder, loc, xConvert, yConvert);
51 return arith::MulFOp::create(builder, loc, xConvert, yConvert);
58 bool useSymbols =
true) {
113 hIndexExpr = hIndexExpr.compose(hIndicesMap);
116 wIndexExpr = wIndexExpr.compose(wIndicesMap);
117 auto cIndexExpr = exprs.
icIndex;
118 return {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr};
121 FailureOr<std::pair<Operation *, Operation *>>
123 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
124 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
125 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
127 if (!filterType.hasStaticShape())
129 convOp,
"expected a static shape for the filter");
131 if (!inputType.hasStaticShape())
133 "expected a static shape for the input");
138 "expected all ones for dilations");
141 Value input = convOp.getInputs()[0];
142 Value filter = convOp.getInputs()[1];
143 Value output = convOp.getOutputs()[0];
148 int64_t n = outputShape[0];
149 int64_t oh = outputShape[1];
150 int64_t ow = outputShape[2];
151 int64_t oc = outputShape[3];
152 int64_t fh = filterShape[0];
153 int64_t fw = filterShape[1];
154 int64_t ic = filterShape[2];
160 auto reshapedFilterType =
162 Value reshapedFilter = tensor::CollapseShapeOp::create(
163 rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
166 RankedTensorType reshapedOutputType =
168 Value reshapedOutput = tensor::CollapseShapeOp::create(
169 rewriter, loc, reshapedOutputType, output, outputReassocIndices);
172 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
173 inputType.getElementType());
176 auto nloops = colTensorShape.size();
178 auto parallel = utils::IteratorType::parallel;
179 auto reduction = utils::IteratorType::reduction;
189 i2cToOperExprs.
fhIndex = kIndicesExprs[0];
190 i2cToOperExprs.
fwIndex = kIndicesExprs[1];
191 i2cToOperExprs.
icIndex = kIndicesExprs[2];
192 i2cToOperExprs.
ohIndex = mIndicesExprs[0];
193 i2cToOperExprs.
owIndex = mIndicesExprs[1];
197 i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
207 auto img2ColTensor = linalg::GenericOp::create(
208 rewriter, loc, colTensor.
getType(),
209 input, colTensor, img2colIndexingMaps,
212 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
220 bindDims(context, bDim, mDim, nDim, kDim);
223 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
225 parallel, reduction};
227 auto genericOp = linalg::GenericOp::create(
228 rewriter, loc, reshapedOutputType,
229 ValueRange{img2ColTensor.getResult(0), reshapedFilter},
236 linalg::YieldOp::create(nestedBuilder, nestedLoc,
add);
238 Value result = genericOp.getResults().front();
240 auto reshapedResult = tensor::ExpandShapeOp::create(
241 rewriter, loc, outputType, result, outputReassocIndices);
245 return std::make_pair(img2ColTensor.getOperation(),
246 reshapedResult.getOperation());
249 FailureOr<std::pair<Operation *, Operation *>>
251 linalg::DepthwiseConv2DNhwcHwcOp convOp) {
252 auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
253 auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
254 auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());
256 if (!filterType.hasStaticShape())
258 convOp,
"expected a static shape for the filter");
260 if (!inputType.hasStaticShape())
262 "expected a static shape for the input");
267 "expected all ones for dilations");
272 auto operandTensorType = cast<RankedTensorType>(operand.
getType());
273 auto nloops = indices.size();
277 llvm::map_range(indices, [&](int64_t index) ->
AffineExpr {
282 indices, [&](int64_t index) -> int64_t {
return inputShape[index]; }));
284 Value outputTensor = tensor::EmptyOp::create(
285 rewriter, loc, targetShape, operandTensorType.getElementType());
288 nloops, utils::IteratorType::parallel);
295 auto transposedOp = linalg::GenericOp::create(
296 rewriter, loc, outputTensor.
getType(),
297 operand, outputTensor, indexingMaps,
300 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
303 return transposedOp.getResult(0);
306 Value input = convOp.getInputs()[0];
307 Value filter = convOp.getInputs()[1];
308 Value output = convOp.getOutputs()[0];
311 Value inputT = transposeOperand(input, {0, 3, 1, 2});
312 Value filterT = transposeOperand(filter, {2, 0, 1});
314 cast<RankedTensorType>(filterT.getType()).getShape();
317 int n = outputShape[0];
318 int oh = outputShape[1];
319 int ow = outputShape[2];
320 int c = outputShape[3];
321 int fh = filterTShape[1];
322 int fw = filterTShape[2];
325 Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
327 AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
331 convOp.getStrides().getValues<int64_t>()[0]);
333 convOp.getStrides().getValues<int64_t>()[1]);
336 owDim * swSym + kwDim};
338 auto nloops = colTensorShape.size();
341 nloops, utils::IteratorType::parallel);
347 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
348 inputType.getElementType());
350 auto img2ColTensor = linalg::GenericOp::create(
351 rewriter, loc, colTensor.
getType(),
352 inputT, colTensor, indexingMaps,
355 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
359 {0, 1}, {2, 3}, {4, 5}};
365 {n * c, oh * ow, fh * fw}, inputType.getElementType());
366 auto reshapedFilterTensorType =
368 auto reshapedOutputTensorType =
371 Value reshapedImg2ColTensor = tensor::CollapseShapeOp::create(
372 rewriter, loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
373 img2ColTensorReassocIndices);
374 Value reshapedFilterTensor =
375 tensor::CollapseShapeOp::create(rewriter, loc, reshapedFilterTensorType,
376 filterT, filterReassociationIndice);
377 Value reshapedoutputTensor = tensor::CollapseShapeOp::create(
378 rewriter, loc, reshapedOutputTensorType, transposedOutputTensor,
379 outputReassociationIndice);
381 auto batchMatVecResult = linalg::BatchMatvecOp::create(
383 ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
389 auto batchMatVecResultReshaped = tensor::ExpandShapeOp::create(
390 rewriter, loc, transposedOutputTensor.
getType(),
391 batchMatVecResult.getResult(0), batchMatVecReassociationIndice);
393 Value transposedResult =
394 transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
397 return std::make_pair(img2ColTensor.getOperation(),
401 FailureOr<std::pair<Operation *, Operation *>>
403 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
404 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
405 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
407 if (!filterType.hasStaticShape())
409 convOp,
"expected a static shape for the filter");
411 if (!inputType.hasStaticShape())
413 "expected a static shape for the input");
418 "expected all ones for dilations");
420 Value input = convOp.getInputs()[0];
421 Value filter = convOp.getInputs()[1];
422 Value output = convOp.getOutputs()[0];
424 auto filterShape = filterType.getShape();
425 auto outputShape = outputType.getShape();
427 int64_t n = outputShape[0];
428 int64_t oc = outputShape[1];
429 int64_t oh = outputShape[2];
430 int64_t ow = outputShape[3];
431 int64_t ic = filterShape[1];
432 int64_t fh = filterShape[2];
433 int64_t fw = filterShape[3];
435 auto loc = convOp.
getLoc();
439 auto reshapedFilterType =
441 Value reshapedFilter = tensor::CollapseShapeOp::create(
442 rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
445 auto reshapedOutputType =
447 Value reshapedOutput = tensor::CollapseShapeOp::create(
448 rewriter, loc, reshapedOutputType, output, outputReassocIndices);
452 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
453 inputType.getElementType());
455 auto nloops = colTensorShape.size();
457 auto parallel = utils::IteratorType::parallel;
458 auto reduction = utils::IteratorType::reduction;
469 i2cToOperExprs.
icIndex = kIndicesExprs[0];
470 i2cToOperExprs.
fhIndex = kIndicesExprs[1];
471 i2cToOperExprs.
fwIndex = kIndicesExprs[2];
472 i2cToOperExprs.
ohIndex = mIndicesExprs[0];
473 i2cToOperExprs.
owIndex = mIndicesExprs[1];
475 i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
485 auto img2ColTensor = linalg::GenericOp::create(
486 rewriter, loc, colTensor.
getType(),
487 input, colTensor, img2colIndexingMaps,
490 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
498 bindDims(context, bDim, mDim, nDim, kDim);
501 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
503 parallel, reduction};
504 auto genericOp = linalg::GenericOp::create(
505 rewriter, loc, reshapedOutputType,
506 ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
513 linalg::YieldOp::create(nestedBuilder, nestedLoc,
add);
515 Value result = genericOp.getResults().front();
517 auto reshapedResult = tensor::ExpandShapeOp::create(
518 rewriter, loc, outputType, result, outputReassocIndices);
522 return std::make_pair(img2ColTensor.getOperation(),
523 reshapedResult.getOperation());
526 FailureOr<std::pair<Operation *, Operation *>>
528 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
529 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
530 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
532 if (!filterType.hasStaticShape())
534 convOp,
"expected a static shape for the filter");
536 if (!inputType.hasStaticShape())
538 "expected a static shape for the input");
543 "expected all ones for dilations");
546 Value input = convOp.getInputs()[0];
547 Value filter = convOp.getInputs()[1];
548 Value output = convOp.getOutputs()[0];
553 int64_t n = outputShape[0];
554 int64_t oh = outputShape[1];
555 int64_t ow = outputShape[2];
556 int64_t oc = outputShape[3];
557 int64_t fh = filterShape[1];
558 int64_t fw = filterShape[2];
559 int64_t ic = filterShape[3];
566 auto reshapedFilterType =
568 Value reshapedFilter = tensor::CollapseShapeOp::create(
569 rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
572 RankedTensorType reshapedOutputType =
574 Value reshapedOutput = tensor::CollapseShapeOp::create(
575 rewriter, loc, reshapedOutputType, output, outputReassocIndices);
579 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
580 inputType.getElementType());
583 auto nloops = colTensorShape.size();
585 auto parallel = utils::IteratorType::parallel;
586 auto reduction = utils::IteratorType::reduction;
596 i2cToOperExprs.
fhIndex = kIndicesExprs[0];
597 i2cToOperExprs.
fwIndex = kIndicesExprs[1];
598 i2cToOperExprs.
icIndex = kIndicesExprs[2];
599 i2cToOperExprs.
ohIndex = mIndicesExprs[0];
600 i2cToOperExprs.
owIndex = mIndicesExprs[1];
604 i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()),
613 auto img2ColTensor = linalg::GenericOp::create(
614 rewriter, loc, colTensor.
getType(),
615 input, colTensor, img2colIndexingMaps,
618 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
625 bindDims(context, bDim, mDim, nDim, kDim);
628 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
630 parallel, reduction};
632 auto genericOp = linalg::GenericOp::create(
633 rewriter, loc, reshapedOutputType,
634 ValueRange{img2ColTensor.getResult(0), reshapedFilter},
641 linalg::YieldOp::create(nestedBuilder, nestedLoc,
add);
643 Value result = genericOp.getResults().front();
645 auto reshapedResult = tensor::ExpandShapeOp::create(
646 rewriter, loc, outputType, result, outputReassocIndices);
650 return std::make_pair(img2ColTensor.getOperation(),
651 reshapedResult.getOperation());
656 class ConvertConv2DNhwcHwcf final
661 LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
669 class ConvertDepthwiseConv2DNhwcHwc final
670 :
public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
674 LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
675 PatternRewriter &rewriter)
const override {
682 class ConvertConv2DNchwFchw final
683 :
public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
687 LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
688 PatternRewriter &rewriter)
const override {
695 class ConvertConv2DNhwcFhwc final
696 :
public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
700 LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
701 PatternRewriter &rewriter)
const override {
711 patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
712 ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
Base type for affine expression.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getAffineConstantExpr(int64_t constant)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns)
Populates patterns to transform linalg.conv_2d_xxx operations into linalg.generic (for img2col packin...
static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder)
static Value createMul(Location loc, Value x, Value y, Type accType, OpBuilder &builder)
static Im2ColToInputDimsExprs getIm2ColInputExpressions(Im2ColToOperandsExprs exprs, ArrayRef< int64_t > strides, RewriterBase &rewriter)
Construct the affine expressions that map the indices of the im2col matrix to the corresponding input...
static bool hasAllOneValues(DenseIntElementsAttr attr)
static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride, bool useSymbols=true)
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...