21 #include "llvm/ADT/SetOperations.h"
22 #include "llvm/ADT/SmallBitVector.h"
23 #include "llvm/ADT/SmallVector.h"
30 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
39 for (
auto &opOperand : linalgOp->getOpOperands()) {
40 if (llvm::is_contained(droppedOperands, &opOperand))
42 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
44 if (indexingMaps.empty()) {
47 return linalgOp.getNumLoops() == 0;
58 if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
62 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
64 auto mapRange = linalgOp.getIndexingMapsArray();
65 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
66 !mapRange.back().isIdentity()) {
70 return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
84 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
85 if (!iface || !iface.hasNoEffect())
95 llvm::raw_ostream &errs) {
97 errs <<
"no terminator in the block";
102 errs <<
"expected block with 3 arguments";
108 errs <<
"expected terminator with 1 operand";
115 errs <<
"expected reduction op to be binary";
124 errs <<
"expected reduction to take block argument #2 as one of the "
125 "operands (modulo unary casts)";
130 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
134 errs <<
"expected elementwise op to be binary";
138 if (!isaPair(elementwiseOp, reductionOp)) {
139 errs <<
"expected reduction/elementwise op kind not satisfied";
152 errs <<
"expected elementwise op to apply to block arguments (modulo unary "
159 template <
typename AddOpTy,
typename MulOpTy,
typename... Args>
161 static_assert(
sizeof...(Args) % 2 == 0,
162 "expected an even number of template arguments");
163 if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
166 if constexpr (
sizeof...(Args) > 0)
174 template <
typename... Args>
186 static llvm::SmallDenseSet<int64_t>
188 utils::IteratorType iter) {
189 llvm::SmallDenseSet<int64_t> res;
190 assert(linalgOp == opOperand->
getOwner() &&
"expected linalgOp owner");
191 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
193 if (
auto d = dyn_cast<AffineDimExpr>(e)) {
194 if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
196 return e.isFunctionOfDim(d.getPosition());
198 res.insert(d.getPosition());
205 auto par = utils::IteratorType::parallel;
206 auto red = utils::IteratorType::reduction;
222 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
226 linalgOp, linalgOp.getDpsInputOperand(0), par);
228 linalgOp, linalgOp.getDpsInputOperand(1), par);
230 linalgOp, linalgOp.getDpsInitOperand(0), par);
233 llvm::SmallDenseSet<int64_t> ac = a;
234 llvm::set_intersect(ac, c);
235 llvm::set_subtract(ac, b);
237 llvm::SmallDenseSet<int64_t> bc = b;
238 llvm::set_intersect(bc, c);
239 llvm::set_subtract(bc, a);
241 llvm::SmallDenseSet<int64_t> batches = a;
242 llvm::set_intersect(batches, b);
243 llvm::set_intersect(batches, c);
247 linalgOp, linalgOp.getDpsInputOperand(0), red);
249 linalgOp, linalgOp.getDpsInputOperand(1), red);
250 llvm::set_intersect(ra, rb);
258 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
259 llvm::sort(dimensions.m.begin(), dimensions.m.end());
260 llvm::sort(dimensions.n.begin(), dimensions.n.end());
261 llvm::sort(dimensions.k.begin(), dimensions.k.end());
279 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
281 return MatchContractionResult::NotLinalgOp;
282 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
283 return MatchContractionResult::WrongNumOperands;
284 auto mapRange = linalgOp.getIndexingMapsArray();
285 if (linalgOp.getNumReductionLoops() == 0)
286 return MatchContractionResult::NoReduction;
287 if (llvm::any_of(mapRange,
289 return MatchContractionResult::NotProjectedPermutations;
293 arith::MulFOp, arith::AddFOp,
294 arith::MulIOp, arith::AddIOp,
295 complex::MulOp, complex::AddOp,
296 arith::AndIOp, arith::OrIOp>(
297 *linalgOp.getBlock())) {
298 return MatchContractionResult::NotAddMul;
304 assert(
succeeded(res) &&
"unexpected failure to infer contraction dims");
307 return MatchContractionResult::Success;
313 case MatchContractionResult::NotLinalgOp:
314 return "expected a LinalgOp";
315 case MatchContractionResult::WrongNumOperands:
316 return "expected op with 2 inputs and 1 output";
317 case MatchContractionResult::NoReduction:
318 return "expected at least 1 reduction";
319 case MatchContractionResult::NotProjectedPermutations:
320 return "expected indexing maps to be projected permutations";
321 case MatchContractionResult::NotAddMul:
322 return "expected add/mul op in the body";
323 case MatchContractionResult::Success:
326 llvm_unreachable(
"unhandled MatchContractionResult case");
333 return isa<ContractionOpInterface>(op) ||
353 if (res != MatchContractionResult::Success)
364 template <
typename T>
366 return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) :
nullptr);
378 struct ConvAccessExprWalker
381 llvm::SmallDenseSet<int64_t> convolvedDims;
383 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
385 llvm::SmallDenseSet<int64_t> unConvolvedDims;
387 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
392 for (
int dimPos = 0, e = map.
getNumDims(); dimPos < e; ++dimPos) {
394 return e.isFunctionOfDim(dimPos);
396 convolvedDims.erase(dimPos);
397 unConvolvedDims.erase(dimPos);
400 if (convolvedDimMapping.contains(dimPos)) {
401 int64_t pairedDim = convolvedDimMapping[dimPos];
402 convolvedDims.erase(pairedDim);
403 unConvolvedDims.erase(pairedDim);
404 strideAndDilationMapping.erase(pairedDim);
405 convolvedDimMapping.erase(dimPos);
406 convolvedDimMapping.erase(pairedDim);
414 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
417 unConvolvedDims.insert(position);
429 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getLHS());
430 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getRHS());
433 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
434 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
439 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
441 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
444 strideAndDilationMapping[dim] =
446 convolvedDims.insert(dim);
449 if (
auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
452 auto lhsExpr = symbolMulExpr.getLHS();
453 auto rhsExpr = symbolMulExpr.getRHS();
456 getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
459 mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
461 auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
462 if (!mulExpr || !dimExpr)
465 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
467 strideAndDilationMapping[dim] = mulExpr;
468 convolvedDims.insert(dim);
478 "expected map to have projected permutations");
479 llvm::SmallDenseSet<int64_t> preservedDims;
481 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
482 return preservedDims;
488 for (
auto e : exprs) {
489 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
490 assert(constantExpr &&
"Found non-constant stride/dilation");
491 vals.push_back(constantExpr.getValue());
505 ConvAccessExprWalker &inputExprWalker,
506 bool allowEmptyConvolvedDims) {
508 linalgOp, linalgOp.getDpsInputOperand(1), par);
510 linalgOp, linalgOp.getDpsInitOperand(0), par);
513 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
514 llvm::set_intersect(batch, outputDims);
515 llvm::set_subtract(batch, filterDims);
518 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
519 llvm::set_intersect(oi, outputDims);
522 llvm::SmallDenseSet<int64_t> oc = filterDims;
523 llvm::set_intersect(oc, outputDims);
524 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
527 llvm::SmallDenseSet<int64_t> depth = filterDims;
528 llvm::set_intersect(depth, outputDims);
529 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
531 llvm::SmallDenseSet<int64_t> filterReducedDims =
536 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
537 llvm::set_intersect(fl, filterReducedDims);
540 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
541 llvm::set_intersect(ic, filterReducedDims);
543 if (oi.empty() && !allowEmptyConvolvedDims)
556 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
557 llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
558 llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
559 llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
560 llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
561 llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
565 if (!nativeStrides) {
567 for (
unsigned oiDim : dimensions.outputImage)
568 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
571 dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
573 auto nativeDilations =
575 if (!nativeDilations) {
577 for (
unsigned flDim : dimensions.filterLoop)
578 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
581 dimensions.dilations =
582 llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
613 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
616 auto indexingMaps = linalgOp.getIndexingMapsArray();
619 ConvAccessExprWalker inputExprWalker;
620 for (
AffineExpr expr : indexingMaps[0].getResults())
621 (void)inputExprWalker.visit(expr);
622 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
644 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
646 return MatchConvolutionResult::NotLinalgOp;
647 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
648 return MatchConvolutionResult::WrongNumOperands;
650 auto indexingMaps = linalgOp.getIndexingMapsArray();
653 ConvAccessExprWalker inputExprWalker;
654 if (llvm::any_of(indexingMaps[0].getResults(),
656 return failed(inputExprWalker.visit(expr));
658 return MatchConvolutionResult::WrongInputIndexingMap;
662 if (!indexingMaps[1].isProjectedPermutation() ||
663 !indexingMaps.back().isProjectedPermutation())
664 return MatchConvolutionResult::NotProjectedPermutations;
666 auto iteratorTypes = linalgOp.getIteratorTypesArray();
668 llvm::SmallDenseSet<int64_t> outputDims =
670 llvm::SmallDenseSet<int64_t> filterDims =
getPreservedDims(indexingMaps[1]);
684 llvm::SmallDenseSet<int64_t> allLoopDims;
685 for (
auto outputExpr : indexingMaps.back().getResults()) {
686 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
687 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
688 !filterDims.count(outputDim)) {
690 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
691 return MatchConvolutionResult::OutputDimsNotParallel;
692 allLoopDims.insert(outputDim);
695 if (inputExprWalker.convolvedDims.count(outputDim) &&
696 !filterDims.count(outputDim)) {
698 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
699 return MatchConvolutionResult::OutputDimsNotParallel;
700 allLoopDims.insert(outputDim);
703 if (!inputExprWalker.convolvedDims.count(outputDim) &&
704 !inputExprWalker.unConvolvedDims.count(outputDim) &&
705 filterDims.count(outputDim)) {
707 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
708 return MatchConvolutionResult::OutputDimsNotParallel;
709 allLoopDims.insert(outputDim);
712 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
713 filterDims.count(outputDim)) {
715 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
716 return MatchConvolutionResult::OutputDimsNotParallel;
717 allLoopDims.insert(outputDim);
720 return MatchConvolutionResult::NonConvolutionLoop;
722 for (
auto filterExpr : indexingMaps[1].getResults()) {
723 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
724 if (outputDims.count(filterDim) &&
725 !inputExprWalker.unConvolvedDims.count(filterDim) &&
726 !inputExprWalker.convolvedDims.count(filterDim)) {
730 if (inputExprWalker.convolvedDims.count(filterDim) &&
731 !outputDims.count(filterDim)) {
733 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
734 return MatchConvolutionResult::NonOutputDimNotReduction;
735 if (allLoopDims.count(filterDim))
736 return MatchConvolutionResult::NonConvolutionLoop;
737 allLoopDims.insert(filterDim);
740 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
741 !outputDims.count(filterDim)) {
743 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
744 return MatchConvolutionResult::NonOutputDimNotReduction;
745 if (allLoopDims.count(filterDim))
746 return MatchConvolutionResult::NonConvolutionLoop;
747 allLoopDims.insert(filterDim);
750 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
751 outputDims.count(filterDim)) {
755 return MatchConvolutionResult::NonConvolutionLoop;
758 if (allLoopDims.size() != linalgOp.getNumLoops())
759 return MatchConvolutionResult::NonConvolutionLoop;
765 assert(
succeeded(res) &&
"unexpected failure to infer convolution dims");
769 return MatchConvolutionResult::Success;
775 case MatchConvolutionResult::NotLinalgOp:
776 return "expected a LinalgOp";
777 case MatchConvolutionResult::WrongNumOperands:
778 return "expected op with 2 inputs and 1 output";
779 case MatchConvolutionResult::WrongInputIndexingMap:
780 return "unexpected input index map for convolutions";
781 case MatchConvolutionResult::NotProjectedPermutations:
782 return "expected output/filter indexing maps to be projected permutations";
783 case MatchConvolutionResult::NonConvolutionLoop:
784 return "unexpected loop dimension for convolution op";
785 case MatchConvolutionResult::OutputDimsNotParallel:
786 return "expected all iterators used to access outputs to be parallel";
787 case MatchConvolutionResult::NonOutputDimNotReduction:
788 return "expected all iterators not used to access outputs to be reduction";
789 case MatchConvolutionResult::Success:
792 llvm_unreachable(
"unhandled MatchConvolutionResult case");
802 if (res != MatchConvolutionResult::Success)
819 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
822 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
825 OpOperand *value = linalgOp.getDpsInputOperand(0);
826 if (!linalgOp.isScalar(value))
835 return op->
emitError(
"expected a LinalgOp");
837 return op->
emitError(
"expected op with 1 input and 1 output");
839 return op->
emitError(
"expected op with scalar input");
851 for (
OpOperand &opOperand : getOperation()->getOpOperands()) {
852 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
860 assert(!hasDynamicShape() &&
"expected operands to have static shapes");
861 for (
OpOperand &opOperand : getOperation()->getOpOperands())
862 llvm::append_range(res,
getShape(&opOperand));
869 auto viewSizes = createFlatListOfOperandDims(b, loc);
871 for (
unsigned idx = 0; idx < numRes; ++idx) {
873 if (
auto d = dyn_cast<AffineDimExpr>(result)) {
874 if (res[d.getPosition()].offset)
876 res[d.getPosition()] =
888 for (
unsigned idx = 0; idx < numRes; ++idx) {
890 if (
auto d = dyn_cast<AffineDimExpr>(result))
891 res[d.getPosition()] = allShapeSizes[idx];
901 : positions(std::move(positions)) {}
916 llvm::SmallBitVector positions;
919 static std::pair<int64_t, int64_t>
921 int64_t inputRankSum = 0;
922 int64_t outputRankSum = 0;
923 for (
OpOperand *input : op.getDpsInputOperands())
924 inputRankSum += op.getRank(input);
925 for (
OpOperand &output : op.getDpsInitsMutable())
926 outputRankSum += op.getRank(&output);
927 return {inputRankSum, inputRankSum + outputRankSum};
942 AffineMap loopsToShapesMap = getLoopsToShapesMap();
951 resultShapesSubMapPos.first,
952 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
953 AffineMap resultShapesFromInputShapesMap =
954 loopToResultsShapeMap.
compose(getShapesToLoopsMap());
958 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.
getNumDims());
959 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
961 Location loc = getOperation()->getLoc();
965 rewriter, loc, resultShapesFromInputShapesMap,
966 createFlatListOfOperandDims(b, loc));
969 for (
OpOperand &opOperand : getDpsInitsMutable()) {
971 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
972 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
973 if (!shapedType.isDynamicDim(dim)) {
975 shapes.push_back(b.
getIndexAttr(shapedType.getDimSize(dim)));
980 : allResultDimValues[pos];
985 reifiedReturnShapes.emplace_back(std::move(shapes));
992 int64_t LinalgOp::getIndexingMapIndex(
OpOperand *opOperand) {
994 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
995 if (!dpsIface.isDpsInput(opOperand))
996 return operandNumber;
997 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
998 assert(!dpsIface.isDpsInit(opOperand));
1001 return cast<DestinationStyleOpInterface>(*this->getOperation())
1002 .getNumDpsInputs() +
1003 operandNumber - start;
1007 LinalgOp linalgOp = cast<LinalgOp>(op);
1011 if (linalgOp.hasDynamicIndexingMaps())
1012 if (
failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1016 if (
static_cast<int64_t
>(linalgOp.getIndexingMapsArray().size()) !=
1017 linalgOp->getNumOperands())
1018 return op->
emitOpError(
"expected the number of indexing_map (")
1019 << linalgOp.getIndexingMapsArray().size()
1020 <<
") to be equal to the number of input/output operands ("
1021 << linalgOp->getNumOperands() <<
")";
1023 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1024 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1028 return op->
emitOpError(
"unexpected symbols in indexing_map #")
1032 unsigned numLoops = linalgOp.getNumLoops();
1036 <<
" dim(s) to match the number of loops";
1038 int64_t rank = linalgOp.getRank(&opOperand);
1041 << rank <<
") to match the result rank of indexing_map #"
1047 linalgOp.getReductionDims(redDims);
1049 if (!linalgOp.getShapesToLoopsMap())
1050 return op->
emitOpError(
"expected the shape-to-loops map to be non-null");
1058 if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
1059 for (int64_t &range : endLoopRangeValues)
1061 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1062 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1064 indexingMap.
compose(startLoopRangeValues);
1066 indexingMap.
compose(endLoopRangeValues);
1068 for (
auto dim : llvm::seq<int64_t>(0, shape.size())) {
1070 if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
1083 int64_t inferredDimSize =
1084 std::max(startIndices[dim], endIndices[dim]) + 1;
1085 if (
std::min(startIndices[dim], endIndices[dim]) < 0) {
1088 llvm::raw_string_ostream os(mapStr);
1092 "unexpected result less than 0 at expression #")
1093 << dim <<
" in " << mapStr;
1095 if (dyn_cast<AffineDimExpr>(indexingMap.
getResult(dim))) {
1096 if (inferredDimSize != shape[dim]) {
1097 return op->
emitOpError(
"inferred input/output operand #")
1099 << dim <<
" to be " << inferredDimSize <<
", but found "
1103 if (inferredDimSize > shape[dim]) {
1104 return op->
emitOpError(
"inferred input/output operand #")
1106 << dim <<
" to be greater than or equal to "
1107 << inferredDimSize <<
", but found " << shape[dim];
1115 if (linalgOp->getNumRegions() != 1 ||
1116 !llvm::hasSingleElement(linalgOp->getRegion(0)))
1117 return op->
emitOpError(
"expects to have 1 region with 1 block");
1125 Block &block = linalgOp->getRegion(0).
front();
1127 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.
getNumArguments())
1128 return op->
emitOpError(
"expected as many non-induction variable region "
1129 "arguments as the number of input/output operands");
1131 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1133 if (isa<MemRefType, RankedTensorType>(elementType))
1136 if (elementType != argType)
1137 return op->
emitOpError(
"expected type of bb argument #")
1139 <<
" to match element or self type of the corresponding operand ("
1140 << elementType <<
")";
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the `linalgOp` used by a convolution subcomputation, as captured by `inputExprWalker`.
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free operations, go up the use-def chain until the first value that is not defined by such an op.
static SmallVector< int64_t, 2 > getConstantsFromExprList(SmallVector< AffineExpr, 2 > exprs)
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand, utils::IteratorType iter)
Given a `linalgOp` and one of its `opOperand`, returns the positions of the iterators of type `iter` that index the `opOperand` as a permutation.
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template arguments.
static MatchFillResult isFillInterfaceImpl(Operation *op)
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by the template arguments.
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Affine binary operation expression.
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
IntegerAttr getIndexAttr(int64_t value)
An attribute that represents a reference to a dense integer vector or tensor object.
This class provides support for representing a failure result, or a valid value of type T.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr)
Checks whether `op` conforms to ConvolutionOpInterface and populates `dimensions` with indexes of the different kinds of convolution dimensions.
@ NotProjectedPermutations
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that checks if the given operands can be dropped, i.e. the remaining operands can still compute the loop bounds of the op.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether `op` conforms to ContractionOpInterface and populates `dimensions` with indexes of the different kinds of contraction dimensions.
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
@ NotProjectedPermutations
@ NonOutputDimNotReduction
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a convolution subcomputation within `linalgOp`.
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation within `linalgOp`.
bool isaConvolutionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ConvolutionOpInterface.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Visitor to check if any of the given set of positions from AffineDimExprs are used within an AffineExpr.
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
This class represents an efficient way to signal success or failure.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.