28 attr, [](
const APInt &element) {
return element.getSExtValue() == 1; });
32 if (isa<IntegerType>(x.
getType()))
33 return builder.
create<arith::AddIOp>(loc, x, y);
34 if (isa<ComplexType>(x.
getType()))
35 return builder.
create<complex::AddOp>(loc, x, y);
36 return builder.
create<arith::AddFOp>(loc, x, y);
46 if (isa<ComplexType>(accType))
47 return builder.
create<complex::MulOp>(loc, xConvert, yConvert);
48 if (isa<IntegerType>(accType))
49 return builder.
create<arith::MulIOp>(loc, xConvert, yConvert);
50 return builder.
create<arith::MulFOp>(loc, xConvert, yConvert);
56 assert(!factors.empty() &&
"empty factor list");
58 for (int64_t f : factors)
60 FailureOr<SmallVector<Value>> multiIndex =
62 assert(!failed(multiIndex) &&
"Failed to linearize img2col index");
70 Value fIndex, int64_t stride) {
77 FailureOr<std::pair<Operation *, Operation *>>
79 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
80 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
81 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
83 if (!filterType.hasStaticShape())
85 convOp,
"expected a static shape for the filter");
87 if (!inputType.hasStaticShape())
89 "expected a static shape for the input");
94 "expected all ones for dilations");
97 Value input = convOp.getInputs()[0];
98 Value filter = convOp.getInputs()[1];
99 Value output = convOp.getOutputs()[0];
104 int64_t n = outputShape[0];
105 int64_t oh = outputShape[1];
106 int64_t ow = outputShape[2];
107 int64_t oc = outputShape[3];
108 int64_t fh = filterShape[0];
109 int64_t fw = filterShape[1];
110 int64_t ic = filterShape[2];
116 auto reshapedFilterType =
118 Value reshapedFilter = rewriter.
create<tensor::CollapseShapeOp>(
119 loc, reshapedFilterType, filter, filterReassocIndices);
122 RankedTensorType reshapedOutputType =
124 Value reshapedOutput = rewriter.
create<tensor::CollapseShapeOp>(
125 loc, reshapedOutputType, output, outputReassocIndices);
128 Value colTensor = rewriter.
create<tensor::EmptyOp>(
129 loc, colTensorShape, inputType.getElementType());
132 auto nloops = colTensorShape.size();
134 auto parallel = utils::IteratorType::parallel;
135 auto reduction = utils::IteratorType::reduction;
141 auto img2ColTensor = rewriter.
create<linalg::GenericOp>(
147 Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
148 Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
149 Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
154 auto ohIndex = mIndices[0];
155 auto owIndex = mIndices[1];
159 auto fhIndex = kIndices[0];
160 auto fwIndex = kIndices[1];
161 auto icIndex = kIndices[2];
166 convOp.getStrides().getValues<int64_t>()[0]);
169 convOp.getStrides().getValues<int64_t>()[1]);
173 Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
174 loc, input, extractionIndices);
175 nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
183 bindDims(context, bDim, mDim, nDim, kDim);
186 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
188 parallel, reduction};
190 auto genericOp = rewriter.
create<linalg::GenericOp>(
191 loc, reshapedOutputType,
192 ValueRange{img2ColTensor.getResult(0), reshapedFilter},
199 nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
201 Value result = genericOp.getResults().front();
203 auto reshapedResult = rewriter.
create<tensor::ExpandShapeOp>(
204 loc, outputType, result, outputReassocIndices);
208 return std::make_pair(img2ColTensor.getOperation(),
209 reshapedResult.getOperation());
212 FailureOr<std::pair<Operation *, Operation *>>
214 linalg::DepthwiseConv2DNhwcHwcOp convOp) {
215 auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
216 auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
217 auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());
219 if (!filterType.hasStaticShape())
221 convOp,
"expected a static shape for the filter");
223 if (!inputType.hasStaticShape())
225 "expected a static shape for the input");
230 "expected all ones for dilations");
235 auto operandTensorType = cast<RankedTensorType>(operand.
getType());
236 auto nloops = indices.size();
240 llvm::map_range(indices, [&](int64_t index) ->
AffineExpr {
245 indices, [&](int64_t index) -> int64_t {
return inputShape[index]; }));
247 Value outputTensor = rewriter.
create<tensor::EmptyOp>(
248 loc, targetShape, operandTensorType.getElementType());
251 nloops, utils::IteratorType::parallel);
258 auto transposedOp = rewriter.
create<linalg::GenericOp>(
260 operand, outputTensor, indexingMaps,
263 nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
269 Value input = convOp.getInputs()[0];
270 Value filter = convOp.getInputs()[1];
271 Value output = convOp.getOutputs()[0];
274 Value inputT = transposeOperand(input, {0, 3, 1, 2});
275 Value filterT = transposeOperand(filter, {2, 0, 1});
277 cast<RankedTensorType>(filterT.getType()).getShape();
280 int n = outputShape[0];
281 int oh = outputShape[1];
282 int ow = outputShape[2];
283 int c = outputShape[3];
284 int fh = filterTShape[1];
285 int fw = filterTShape[2];
288 Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
290 AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
294 convOp.getStrides().getValues<int64_t>()[0]);
296 convOp.getStrides().getValues<int64_t>()[1]);
299 owDim * swSym + kwDim};
301 auto nloops = colTensorShape.size();
304 nloops, utils::IteratorType::parallel);
310 Value colTensor = rewriter.
create<tensor::EmptyOp>(
311 loc, colTensorShape, inputType.getElementType());
313 auto img2ColTensor = rewriter.
create<linalg::GenericOp>(
315 inputT, colTensor, indexingMaps,
318 nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
322 {0, 1}, {2, 3}, {4, 5}};
328 {n * c, oh * ow, fh * fw}, inputType.getElementType());
329 auto reshapedFilterTensorType =
331 auto reshapedOutputTensorType =
334 Value reshapedImg2ColTensor = rewriter.
create<tensor::CollapseShapeOp>(
335 loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
336 img2ColTensorReassocIndices);
337 Value reshapedFilterTensor = rewriter.
create<tensor::CollapseShapeOp>(
338 loc, reshapedFilterTensorType, filterT, filterReassociationIndice);
339 Value reshapedoutputTensor = rewriter.
create<tensor::CollapseShapeOp>(
340 loc, reshapedOutputTensorType, transposedOutputTensor,
341 outputReassociationIndice);
343 auto batchMatVecResult = rewriter.
create<linalg::BatchMatvecOp>(
344 loc,
TypeRange{reshapedoutputTensor.getType()},
345 ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
351 auto batchMatVecResultReshaped = rewriter.
create<tensor::ExpandShapeOp>(
352 loc, transposedOutputTensor.
getType(), batchMatVecResult.getResult(0),
353 batchMatVecReassociationIndice);
355 Value transposedResult =
356 transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
359 return std::make_pair(img2ColTensor.getOperation(),
360 transposedResult.getDefiningOp());
363 FailureOr<std::pair<Operation *, Operation *>>
365 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
366 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
367 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
369 if (!filterType.hasStaticShape())
371 convOp,
"expected a static shape for the filter");
373 if (!inputType.hasStaticShape())
375 "expected a static shape for the input");
380 "expected all ones for dilations");
382 Value input = convOp.getInputs()[0];
383 Value filter = convOp.getInputs()[1];
384 Value output = convOp.getOutputs()[0];
386 auto filterShape = filterType.getShape();
387 auto outputShape = outputType.getShape();
389 int64_t n = outputShape[0];
390 int64_t oc = outputShape[1];
391 int64_t oh = outputShape[2];
392 int64_t ow = outputShape[3];
393 int64_t ic = filterShape[1];
394 int64_t fh = filterShape[2];
395 int64_t fw = filterShape[3];
397 auto loc = convOp.
getLoc();
401 auto reshapedFilterType =
403 Value reshapedFilter = rewriter.
create<tensor::CollapseShapeOp>(
404 loc, reshapedFilterType, filter, filterReassocIndices);
407 auto reshapedOutputType =
409 Value reshapedOutput = rewriter.
create<tensor::CollapseShapeOp>(
410 loc, reshapedOutputType, output, outputReassocIndices);
414 Value colTensor = rewriter.
create<tensor::EmptyOp>(
415 loc, colTensorShape, inputType.getElementType());
417 auto nloops = colTensorShape.size();
419 auto parallel = utils::IteratorType::parallel;
420 auto reduction = utils::IteratorType::reduction;
426 auto img2ColTensor = rewriter.
create<linalg::GenericOp>(
432 Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
433 Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
434 Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
439 auto icIndex = kIndices[0];
440 auto fhIndex = kIndices[1];
441 auto fwIndex = kIndices[2];
445 auto ohIndex = nIndices[0];
446 auto owIndex = nIndices[1];
451 convOp.getStrides().getValues<int64_t>()[0]);
454 convOp.getStrides().getValues<int64_t>()[1]);
458 Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
459 loc, input, extractionIndices);
460 nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
468 bindDims(context, bDim, mDim, nDim, kDim);
471 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
473 parallel, reduction};
474 auto genericOp = rewriter.
create<linalg::GenericOp>(
475 loc, reshapedOutputType,
476 ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
483 nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
485 Value result = genericOp.getResults().front();
487 auto reshapedResult = rewriter.
create<tensor::ExpandShapeOp>(
488 loc, outputType, result, outputReassocIndices);
492 return std::make_pair(img2ColTensor.getOperation(),
493 reshapedResult.getOperation());
496 FailureOr<std::pair<Operation *, Operation *>>
498 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
499 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
500 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
502 if (!filterType.hasStaticShape())
504 convOp,
"expected a static shape for the filter");
506 if (!inputType.hasStaticShape())
508 "expected a static shape for the input");
513 "expected all ones for dilations");
516 Value input = convOp.getInputs()[0];
517 Value filter = convOp.getInputs()[1];
518 Value output = convOp.getOutputs()[0];
523 int64_t n = outputShape[0];
524 int64_t oh = outputShape[1];
525 int64_t ow = outputShape[2];
526 int64_t oc = outputShape[3];
527 int64_t fh = filterShape[1];
528 int64_t fw = filterShape[2];
529 int64_t ic = filterShape[3];
536 auto reshapedFilterType =
538 Value reshapedFilter = rewriter.
create<tensor::CollapseShapeOp>(
539 loc, reshapedFilterType, filter, filterReassocIndices);
542 RankedTensorType reshapedOutputType =
544 Value reshapedOutput = rewriter.
create<tensor::CollapseShapeOp>(
545 loc, reshapedOutputType, output, outputReassocIndices);
548 Value colTensor = rewriter.
create<tensor::EmptyOp>(
549 loc, colTensorShape, inputType.getElementType());
552 auto nloops = colTensorShape.size();
554 auto parallel = utils::IteratorType::parallel;
555 auto reduction = utils::IteratorType::reduction;
561 auto img2ColTensor = rewriter.
create<linalg::GenericOp>(
567 Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
568 Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
569 Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
574 auto ohIndex = mIndices[0];
575 auto owIndex = mIndices[1];
579 auto fhIndex = kIndices[0];
580 auto fwIndex = kIndices[1];
581 auto icIndex = kIndices[2];
586 convOp.getStrides().getValues<int64_t>()[0]);
589 convOp.getStrides().getValues<int64_t>()[1]);
593 Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
594 loc, input, extractionIndices);
595 nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
602 bindDims(context, bDim, mDim, nDim, kDim);
605 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
607 parallel, reduction};
609 auto genericOp = rewriter.
create<linalg::GenericOp>(
610 loc, reshapedOutputType,
611 ValueRange{img2ColTensor.getResult(0), reshapedFilter},
618 nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
620 Value result = genericOp.getResults().front();
622 auto reshapedResult = rewriter.
create<tensor::ExpandShapeOp>(
623 loc, outputType, result, outputReassocIndices);
627 return std::make_pair(img2ColTensor.getOperation(),
628 reshapedResult.getOperation());
633 class ConvertConv2DNhwcHwcf final
638 LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
646 class ConvertDepthwiseConv2DNhwcHwc final
647 :
public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
651 LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
652 PatternRewriter &rewriter)
const override {
659 class ConvertConv2DNchwFchw final
660 :
public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
664 LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
665 PatternRewriter &rewriter)
const override {
672 class ConvertConv2DNhwcFhwc final
673 :
public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
677 LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
678 PatternRewriter &rewriter)
const override {
688 patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
689 ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineConstantExpr(int64_t constant)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
FailureOr< SmallVector< Value > > delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef< Value > basis, bool hasOuterBound=true)
Generate the IR to delinearize linearIndex given the basis and return the multi-index.
static SmallVector< Value > unrollIndex(OpBuilder &b, Location loc, Value index, ArrayRef< int64_t > factors)
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns)
Populates patterns to transform linalg.conv_2d_xxx operations into linalg.generic (for img2col packin...
static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder)
static Value createMul(Location loc, Value x, Value y, Type accType, OpBuilder &builder)
static bool hasAllOneValues(DenseIntElementsAttr attr)
static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex, Value fIndex, int64_t stride)
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...