29 attr, [](
const APInt &element) {
return element.getSExtValue() == 1; });
33 if (isa<IntegerType>(x.
getType()))
34 return builder.
create<arith::AddIOp>(loc, x, y);
35 if (isa<ComplexType>(x.
getType()))
36 return builder.
create<complex::AddOp>(loc, x, y);
37 return builder.
create<arith::AddFOp>(loc, x, y);
47 if (isa<ComplexType>(accType))
48 return builder.
create<complex::MulOp>(loc, xConvert, yConvert);
49 if (isa<IntegerType>(accType))
50 return builder.
create<arith::MulIOp>(loc, xConvert, yConvert);
51 return builder.
create<arith::MulFOp>(loc, xConvert, yConvert);
57 assert(!factors.empty() &&
"empty factor list");
59 for (int64_t f : factors)
61 FailureOr<SmallVector<Value>> multiIndex =
63 assert(!failed(multiIndex) &&
"Failed to linearize img2col index");
71 Value fIndex, int64_t stride) {
78 FailureOr<std::pair<Operation *, Operation *>>
80 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
81 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
82 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
84 if (!filterType.hasStaticShape())
86 convOp,
"expected a static shape for the filter");
88 if (!inputType.hasStaticShape())
90 "expected a static shape for the input");
95 "expected all ones for dilations");
98 Value input = convOp.getInputs()[0];
99 Value filter = convOp.getInputs()[1];
100 Value output = convOp.getOutputs()[0];
105 int64_t n = outputShape[0];
106 int64_t oh = outputShape[1];
107 int64_t ow = outputShape[2];
108 int64_t oc = outputShape[3];
109 int64_t fh = filterShape[0];
110 int64_t fw = filterShape[1];
111 int64_t ic = filterShape[2];
117 auto reshapedFilterType =
119 Value reshapedFilter = rewriter.
create<tensor::CollapseShapeOp>(
120 loc, reshapedFilterType, filter, filterReassocIndices);
123 RankedTensorType reshapedOutputType =
125 Value reshapedOutput = rewriter.
create<tensor::CollapseShapeOp>(
126 loc, reshapedOutputType, output, outputReassocIndices);
129 Value colTensor = rewriter.
create<tensor::EmptyOp>(
130 loc, colTensorShape, inputType.getElementType());
133 auto nloops = colTensorShape.size();
135 auto parallel = utils::IteratorType::parallel;
136 auto reduction = utils::IteratorType::reduction;
142 auto img2ColTensor = rewriter.
create<linalg::GenericOp>(
148 Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
149 Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
150 Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
155 auto ohIndex = mIndices[0];
156 auto owIndex = mIndices[1];
160 auto fhIndex = kIndices[0];
161 auto fwIndex = kIndices[1];
162 auto icIndex = kIndices[2];
167 convOp.getStrides().getValues<int64_t>()[0]);
170 convOp.getStrides().getValues<int64_t>()[1]);
174 Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
175 loc, input, extractionIndices);
176 nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
184 bindDims(context, bDim, mDim, nDim, kDim);
187 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
189 parallel, reduction};
191 auto genericOp = rewriter.
create<linalg::GenericOp>(
192 loc, reshapedOutputType,
193 ValueRange{img2ColTensor.getResult(0), reshapedFilter},
200 nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
202 Value result = genericOp.getResults().front();
204 auto reshapedResult = rewriter.
create<tensor::ExpandShapeOp>(
205 loc, outputType, result, outputReassocIndices);
209 return std::make_pair(img2ColTensor.getOperation(),
210 reshapedResult.getOperation());
213 FailureOr<std::pair<Operation *, Operation *>>
215 linalg::DepthwiseConv2DNhwcHwcOp convOp) {
216 auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
217 auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
218 auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());
220 if (!filterType.hasStaticShape())
222 convOp,
"expected a static shape for the filter");
224 if (!inputType.hasStaticShape())
226 "expected a static shape for the input");
231 "expected all ones for dilations");
236 auto operandTensorType = cast<RankedTensorType>(operand.
getType());
237 auto nloops = indices.size();
241 llvm::map_range(indices, [&](int64_t index) ->
AffineExpr {
246 indices, [&](int64_t index) -> int64_t {
return inputShape[index]; }));
248 Value outputTensor = rewriter.
create<tensor::EmptyOp>(
249 loc, targetShape, operandTensorType.getElementType());
252 nloops, utils::IteratorType::parallel);
259 auto transposedOp = rewriter.
create<linalg::GenericOp>(
261 operand, outputTensor, indexingMaps,
264 nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
270 Value input = convOp.getInputs()[0];
271 Value filter = convOp.getInputs()[1];
272 Value output = convOp.getOutputs()[0];
275 Value inputT = transposeOperand(input, {0, 3, 1, 2});
276 Value filterT = transposeOperand(filter, {2, 0, 1});
278 cast<RankedTensorType>(filterT.getType()).getShape();
281 int n = outputShape[0];
282 int oh = outputShape[1];
283 int ow = outputShape[2];
284 int c = outputShape[3];
285 int fh = filterTShape[1];
286 int fw = filterTShape[2];
289 Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
291 AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
295 convOp.getStrides().getValues<int64_t>()[0]);
297 convOp.getStrides().getValues<int64_t>()[1]);
300 owDim * swSym + kwDim};
302 auto nloops = colTensorShape.size();
305 nloops, utils::IteratorType::parallel);
311 Value colTensor = rewriter.
create<tensor::EmptyOp>(
312 loc, colTensorShape, inputType.getElementType());
314 auto img2ColTensor = rewriter.
create<linalg::GenericOp>(
316 inputT, colTensor, indexingMaps,
319 nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
323 {0, 1}, {2, 3}, {4, 5}};
329 {n * c, oh * ow, fh * fw}, inputType.getElementType());
330 auto reshapedFilterTensorType =
332 auto reshapedOutputTensorType =
335 Value reshapedImg2ColTensor = rewriter.
create<tensor::CollapseShapeOp>(
336 loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
337 img2ColTensorReassocIndices);
338 Value reshapedFilterTensor = rewriter.
create<tensor::CollapseShapeOp>(
339 loc, reshapedFilterTensorType, filterT, filterReassociationIndice);
340 Value reshapedoutputTensor = rewriter.
create<tensor::CollapseShapeOp>(
341 loc, reshapedOutputTensorType, transposedOutputTensor,
342 outputReassociationIndice);
344 auto batchMatVecResult = rewriter.
create<linalg::BatchMatvecOp>(
345 loc,
TypeRange{reshapedoutputTensor.getType()},
346 ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
352 auto batchMatVecResultReshaped = rewriter.
create<tensor::ExpandShapeOp>(
353 loc, transposedOutputTensor.
getType(), batchMatVecResult.getResult(0),
354 batchMatVecReassociationIndice);
356 Value transposedResult =
357 transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
360 return std::make_pair(img2ColTensor.getOperation(),
361 transposedResult.getDefiningOp());
364 FailureOr<std::pair<Operation *, Operation *>>
366 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
367 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
368 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
370 if (!filterType.hasStaticShape())
372 convOp,
"expected a static shape for the filter");
374 if (!inputType.hasStaticShape())
376 "expected a static shape for the input");
381 "expected all ones for dilations");
383 Value input = convOp.getInputs()[0];
384 Value filter = convOp.getInputs()[1];
385 Value output = convOp.getOutputs()[0];
387 auto filterShape = filterType.getShape();
388 auto outputShape = outputType.getShape();
390 int64_t n = outputShape[0];
391 int64_t oc = outputShape[1];
392 int64_t oh = outputShape[2];
393 int64_t ow = outputShape[3];
394 int64_t ic = filterShape[1];
395 int64_t fh = filterShape[2];
396 int64_t fw = filterShape[3];
398 auto loc = convOp.
getLoc();
402 auto reshapedFilterType =
404 Value reshapedFilter = rewriter.
create<tensor::CollapseShapeOp>(
405 loc, reshapedFilterType, filter, filterReassocIndices);
408 auto reshapedOutputType =
410 Value reshapedOutput = rewriter.
create<tensor::CollapseShapeOp>(
411 loc, reshapedOutputType, output, outputReassocIndices);
415 Value colTensor = rewriter.
create<tensor::EmptyOp>(
416 loc, colTensorShape, inputType.getElementType());
418 auto nloops = colTensorShape.size();
420 auto parallel = utils::IteratorType::parallel;
421 auto reduction = utils::IteratorType::reduction;
427 auto img2ColTensor = rewriter.
create<linalg::GenericOp>(
433 Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
434 Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
435 Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
440 auto icIndex = kIndices[0];
441 auto fhIndex = kIndices[1];
442 auto fwIndex = kIndices[2];
446 auto ohIndex = nIndices[0];
447 auto owIndex = nIndices[1];
452 convOp.getStrides().getValues<int64_t>()[0]);
455 convOp.getStrides().getValues<int64_t>()[1]);
459 Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
460 loc, input, extractionIndices);
461 nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
469 bindDims(context, bDim, mDim, nDim, kDim);
472 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
474 parallel, reduction};
475 auto genericOp = rewriter.
create<linalg::GenericOp>(
476 loc, reshapedOutputType,
477 ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
484 nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
486 Value result = genericOp.getResults().front();
488 auto reshapedResult = rewriter.
create<tensor::ExpandShapeOp>(
489 loc, outputType, result, outputReassocIndices);
493 return std::make_pair(img2ColTensor.getOperation(),
494 reshapedResult.getOperation());
497 FailureOr<std::pair<Operation *, Operation *>>
499 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
500 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
501 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
503 if (!filterType.hasStaticShape())
505 convOp,
"expected a static shape for the filter");
507 if (!inputType.hasStaticShape())
509 "expected a static shape for the input");
514 "expected all ones for dilations");
517 Value input = convOp.getInputs()[0];
518 Value filter = convOp.getInputs()[1];
519 Value output = convOp.getOutputs()[0];
524 int64_t n = outputShape[0];
525 int64_t oh = outputShape[1];
526 int64_t ow = outputShape[2];
527 int64_t oc = outputShape[3];
528 int64_t fh = filterShape[1];
529 int64_t fw = filterShape[2];
530 int64_t ic = filterShape[3];
537 auto reshapedFilterType =
539 Value reshapedFilter = rewriter.
create<tensor::CollapseShapeOp>(
540 loc, reshapedFilterType, filter, filterReassocIndices);
543 RankedTensorType reshapedOutputType =
545 Value reshapedOutput = rewriter.
create<tensor::CollapseShapeOp>(
546 loc, reshapedOutputType, output, outputReassocIndices);
549 Value colTensor = rewriter.
create<tensor::EmptyOp>(
550 loc, colTensorShape, inputType.getElementType());
553 auto nloops = colTensorShape.size();
555 auto parallel = utils::IteratorType::parallel;
556 auto reduction = utils::IteratorType::reduction;
562 auto img2ColTensor = rewriter.
create<linalg::GenericOp>(
568 Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
569 Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
570 Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
575 auto ohIndex = mIndices[0];
576 auto owIndex = mIndices[1];
580 auto fhIndex = kIndices[0];
581 auto fwIndex = kIndices[1];
582 auto icIndex = kIndices[2];
587 convOp.getStrides().getValues<int64_t>()[0]);
590 convOp.getStrides().getValues<int64_t>()[1]);
594 Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
595 loc, input, extractionIndices);
596 nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
603 bindDims(context, bDim, mDim, nDim, kDim);
606 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
608 parallel, reduction};
610 auto genericOp = rewriter.
create<linalg::GenericOp>(
611 loc, reshapedOutputType,
612 ValueRange{img2ColTensor.getResult(0), reshapedFilter},
619 nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
621 Value result = genericOp.getResults().front();
623 auto reshapedResult = rewriter.
create<tensor::ExpandShapeOp>(
624 loc, outputType, result, outputReassocIndices);
628 return std::make_pair(img2ColTensor.getOperation(),
629 reshapedResult.getOperation());
634 class ConvertConv2DNhwcHwcf final
639 LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
647 class ConvertDepthwiseConv2DNhwcHwc final
648 :
public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
652 LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
653 PatternRewriter &rewriter)
const override {
660 class ConvertConv2DNchwFchw final
661 :
public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
665 LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
666 PatternRewriter &rewriter)
const override {
673 class ConvertConv2DNhwcFhwc final
674 :
public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
678 LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
679 PatternRewriter &rewriter)
const override {
689 patterns.
insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
690 ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineConstantExpr(int64_t constant)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
FailureOr< SmallVector< Value > > delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef< Value > basis)
Generate the IR to delinearize linearIndex given the basis and return the multi-index.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
static SmallVector< Value > unrollIndex(OpBuilder &b, Location loc, Value index, ArrayRef< int64_t > factors)
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns)
Populates patterns to transform linalg.conv_2d_xxx operations into linalg.generic (for img2col packin...
static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder)
static Value createMul(Location loc, Value x, Value y, Type accType, OpBuilder &builder)
static bool hasAllOneValues(DenseIntElementsAttr attr)
static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex, Value fIndex, int64_t stride)
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...