22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SetOperations.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/Support/Casting.h"
27 #include "llvm/Support/raw_ostream.h"
36 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
45 for (
auto &opOperand : linalgOp->getOpOperands()) {
46 if (llvm::is_contained(droppedOperands, &opOperand))
48 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
50 if (indexingMaps.empty()) {
53 return linalgOp.getNumLoops() == 0;
56 indexingMaps, linalgOp.getContext())) !=
AffineMap();
65 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
68 auto mapRange = op.getIndexingMapsArray();
69 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
70 !mapRange.back().isIdentity()) {
74 return llvm::hasSingleElement(op.getBlock()->getOperations());
82 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
83 !op.isSingleYieldOp())
87 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
88 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
91 OpOperand *value = op.getDpsInputOperand(0);
92 if (!op.isScalar(value))
100 std::optional<SmallVector<int64_t>>
103 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
104 !op.isSingleYieldOp())
107 auto srcTy = op.getDpsInputOperand(0)->get().getType();
108 auto dstTy = op.getDpsInitOperand(0)->get().getType();
109 if (!isa<MemRefType, RankedTensorType>(srcTy) ||
110 !isa<MemRefType, RankedTensorType>(dstTy))
116 auto dstMap = op.getIndexingMapsArray()[1];
117 if (!dstMap.isIdentity())
121 auto srcMap = op.getIndexingMapsArray()[0];
123 if (srcMap.getResults().size() >= dstMap.getResults().size())
127 for (
unsigned i = 0; i < srcMap.getNumResults(); ++i) {
128 auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
131 int64_t pos = expr.getPosition();
132 if (i > 0 && pos <= position[i - 1])
134 position.push_back(expr.getPosition());
138 auto numDims = srcMap.getNumDims();
140 for (
auto dim : llvm::seq<int64_t>(0, numDims)) {
141 if (!llvm::is_contained(position, dim))
142 broadcastedDims.push_back(dim);
144 return broadcastedDims;
150 std::optional<SmallVector<int64_t>>
155 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
156 !op.isSingleYieldOp())
159 auto mapRange = op.getIndexingMapsArray();
160 if (mapRange.size() != 2)
163 auto mapOfInput = mapRange.front();
164 auto mapOfResult = mapRange.back();
168 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
172 for (
unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
173 auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
174 permutation[expr.getPosition()] = i;
185 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
189 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
190 !llvm::all_of(op.getIndexingMapsArray(),
191 [](
AffineMap map) { return map.isIdentity(); }))
195 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
202 Block *body = op.getBody();
210 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
211 if (!yieldOp || yieldOp.getNumOperands() != 1 ||
212 yieldOp->getOperand(0).getDefiningOp() != oper)
223 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
233 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
234 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
235 if (!op.payloadUsesValueFromOperand(inputOpOperand0) ||
236 !op.payloadUsesValueFromOperand(inputOpOperand1))
252 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
253 if (!iface || !iface.hasNoEffect())
263 llvm::raw_ostream &errs) {
265 errs <<
"no terminator in the block";
270 errs <<
"expected block with 3 arguments";
276 errs <<
"expected terminator with 1 operand";
283 errs <<
"expected reduction op to be binary";
292 errs <<
"expected reduction to take block argument #2 as one of the "
293 "operands (modulo unary casts)";
298 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
302 errs <<
"expected elementwise op to be binary";
306 if (!isaPair(elementwiseOp, reductionOp)) {
307 errs <<
"expected reduction/elementwise op kind not satisfied";
320 errs <<
"expected elementwise op to apply to block arguments (modulo unary "
327 template <
typename AddOpTy,
typename MulOpTy,
typename... Args>
329 static_assert(
sizeof...(Args) % 2 == 0,
330 "expected an even number of template arguments");
331 if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
334 if constexpr (
sizeof...(Args) > 0)
342 template <
typename... Args>
354 static llvm::SmallDenseSet<int64_t>
357 utils::IteratorType iter) {
358 assert(iterators.size() == indexingMap.
getNumDims());
359 llvm::SmallDenseSet<int64_t> res;
361 if (
auto d = dyn_cast<AffineDimExpr>(e)) {
362 if (iterators[d.getPosition()] == iter &&
364 return e.isFunctionOfDim(d.getPosition());
366 res.insert(d.getPosition());
373 auto par = utils::IteratorType::parallel;
374 auto red = utils::IteratorType::reduction;
381 static FailureOr<SmallVector<utils::IteratorType>>
387 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
388 iterators[dim.getPosition()] = par;
403 static FailureOr<ContractionDimensions>
406 llvm::SmallDenseSet<int64_t> a =
408 llvm::SmallDenseSet<int64_t> b =
410 llvm::SmallDenseSet<int64_t> c =
414 llvm::SmallDenseSet<int64_t> ac = a;
415 llvm::set_intersect(ac, c);
416 llvm::set_subtract(ac, b);
418 llvm::SmallDenseSet<int64_t> bc = b;
419 llvm::set_intersect(bc, c);
420 llvm::set_subtract(bc, a);
422 llvm::SmallDenseSet<int64_t> batches = a;
423 llvm::set_intersect(batches, b);
424 llvm::set_intersect(batches, c);
427 llvm::SmallDenseSet<int64_t> ra =
429 llvm::SmallDenseSet<int64_t> rb =
431 llvm::set_intersect(ra, rb);
439 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
440 llvm::sort(dimensions.m.begin(), dimensions.m.end());
441 llvm::sort(dimensions.n.begin(), dimensions.n.end());
442 llvm::sort(dimensions.k.begin(), dimensions.k.end());
446 FailureOr<ContractionDimensions>
448 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
451 linalgOp.getIteratorTypesArray());
454 FailureOr<ContractionDimensions>
456 if (indexingMaps.size() != 3)
459 if (failed(iterators))
478 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
480 return MatchContractionResult::NotLinalgOp;
481 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
482 return MatchContractionResult::WrongNumOperands;
483 auto mapRange = linalgOp.getIndexingMapsArray();
484 if (linalgOp.getNumReductionLoops() == 0)
485 return MatchContractionResult::NoReduction;
486 if (llvm::any_of(mapRange,
487 [](
AffineMap m) {
return !m.isProjectedPermutation(); }))
488 return MatchContractionResult::NotProjectedPermutations;
492 arith::MulFOp, arith::AddFOp,
493 arith::MulIOp, arith::AddIOp,
494 complex::MulOp, complex::AddOp,
495 arith::AndIOp, arith::OrIOp>(
496 *linalgOp.getBlock())) {
497 return MatchContractionResult::NotAddMul;
503 assert(succeeded(res) &&
"unexpected failure to infer contraction dims");
506 return MatchContractionResult::Success;
512 case MatchContractionResult::NotLinalgOp:
513 return "expected a LinalgOp";
514 case MatchContractionResult::WrongNumOperands:
515 return "expected op with 2 inputs and 1 output";
516 case MatchContractionResult::NoReduction:
517 return "expected at least 1 reduction";
518 case MatchContractionResult::NotProjectedPermutations:
519 return "expected indexing maps to be projected permutations";
520 case MatchContractionResult::NotAddMul:
521 return "expected add/mul op in the body";
522 case MatchContractionResult::Success:
525 llvm_unreachable(
"unhandled MatchContractionResult case");
532 return isa<ContractionOpInterface>(op) ||
552 if (res != MatchContractionResult::Success)
563 template <
typename T>
565 return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) :
nullptr);
577 struct ConvAccessExprWalker
580 llvm::SmallDenseSet<int64_t> convolvedDims;
582 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
584 llvm::SmallDenseSet<int64_t> unConvolvedDims;
586 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
591 for (
int dimPos = 0, e = map.
getNumDims(); dimPos < e; ++dimPos) {
593 return e.isFunctionOfDim(dimPos);
595 convolvedDims.erase(dimPos);
596 unConvolvedDims.erase(dimPos);
599 auto it = convolvedDimMapping.find(dimPos);
600 if (it != convolvedDimMapping.end()) {
601 int64_t pairedDim = it->second;
602 convolvedDims.erase(pairedDim);
603 unConvolvedDims.erase(pairedDim);
604 strideAndDilationMapping.erase(pairedDim);
605 convolvedDimMapping.erase(dimPos);
606 convolvedDimMapping.erase(pairedDim);
614 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
617 unConvolvedDims.insert(position);
629 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getLHS());
630 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getRHS());
631 if (failed(lhsDimPos) || failed(rhsDimPos))
633 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
634 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
638 FailureOr<int64_t> getDimExprOrMulExprDimPos(
AffineExpr expr) {
639 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
641 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
644 strideAndDilationMapping[dim] =
646 convolvedDims.insert(dim);
649 if (
auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
652 auto lhsExpr = symbolMulExpr.getLHS();
653 auto rhsExpr = symbolMulExpr.getRHS();
656 getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
659 mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
661 auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
662 if (!mulExpr || !dimExpr)
665 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
667 strideAndDilationMapping[dim] = mulExpr;
668 convolvedDims.insert(dim);
678 "expected map to have projected permutations");
679 llvm::SmallDenseSet<int64_t> preservedDims;
681 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
682 return preservedDims;
688 for (
auto e : exprs) {
689 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
690 assert(constantExpr &&
"Found non-constant stride/dilation");
691 vals.push_back(constantExpr.getValue());
703 static FailureOr<ConvolutionDimensions>
705 ConvAccessExprWalker &inputExprWalker,
706 bool allowEmptyConvolvedDims) {
708 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
710 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
712 filterMap, linalgOp.getIteratorTypesArray(), par);
714 outputMap, linalgOp.getIteratorTypesArray(), par);
717 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
718 llvm::set_intersect(batch, outputDims);
719 llvm::set_subtract(batch, filterDims);
722 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
723 llvm::set_intersect(oi, outputDims);
726 llvm::SmallDenseSet<int64_t> oc = filterDims;
727 llvm::set_intersect(oc, outputDims);
728 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
731 llvm::SmallDenseSet<int64_t> depth = filterDims;
732 llvm::set_intersect(depth, outputDims);
733 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
735 llvm::SmallDenseSet<int64_t> filterReducedDims =
737 linalgOp.getIteratorTypesArray(), red);
740 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
741 llvm::set_intersect(fl, filterReducedDims);
744 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
745 llvm::set_intersect(ic, filterReducedDims);
747 if (oi.empty() && !allowEmptyConvolvedDims)
760 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
761 llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
762 llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
763 llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
764 llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
765 llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
769 if (!nativeStrides) {
771 for (
unsigned oiDim : dimensions.outputImage)
772 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
775 dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
777 auto nativeDilations =
779 if (!nativeDilations) {
781 for (
unsigned flDim : dimensions.filterLoop)
782 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
785 dimensions.dilations =
786 llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
815 FailureOr<ConvolutionDimensions>
817 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
820 auto indexingMaps = linalgOp.getIndexingMapsArray();
823 ConvAccessExprWalker inputExprWalker;
824 for (
AffineExpr expr : indexingMaps[0].getResults())
825 (void)inputExprWalker.visit(expr);
826 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
849 bool allowEmptyConvolvedDims) {
850 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
852 return MatchConvolutionResult::NotLinalgOp;
853 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
854 return MatchConvolutionResult::WrongNumOperands;
856 auto indexingMaps = linalgOp.getIndexingMapsArray();
859 ConvAccessExprWalker inputExprWalker;
860 if (llvm::any_of(indexingMaps[0].getResults(),
862 return failed(inputExprWalker.visit(expr));
864 return MatchConvolutionResult::WrongInputIndexingMap;
868 if (!indexingMaps[1].isProjectedPermutation() ||
869 !indexingMaps.back().isProjectedPermutation())
870 return MatchConvolutionResult::NotProjectedPermutations;
872 auto iteratorTypes = linalgOp.getIteratorTypesArray();
874 llvm::SmallDenseSet<int64_t> outputDims =
876 llvm::SmallDenseSet<int64_t> filterDims =
getPreservedDims(indexingMaps[1]);
890 llvm::SmallDenseSet<int64_t> allLoopDims;
891 for (
auto outputExpr : indexingMaps.back().getResults()) {
892 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
893 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
894 !filterDims.count(outputDim)) {
896 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
897 return MatchConvolutionResult::OutputDimsNotParallel;
898 allLoopDims.insert(outputDim);
901 if (inputExprWalker.convolvedDims.count(outputDim) &&
902 !filterDims.count(outputDim)) {
904 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
905 return MatchConvolutionResult::OutputDimsNotParallel;
906 allLoopDims.insert(outputDim);
909 if (!inputExprWalker.convolvedDims.count(outputDim) &&
910 !inputExprWalker.unConvolvedDims.count(outputDim) &&
911 filterDims.count(outputDim)) {
913 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
914 return MatchConvolutionResult::OutputDimsNotParallel;
915 allLoopDims.insert(outputDim);
918 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
919 filterDims.count(outputDim)) {
921 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
922 return MatchConvolutionResult::OutputDimsNotParallel;
923 allLoopDims.insert(outputDim);
926 return MatchConvolutionResult::NonConvolutionLoop;
928 for (
auto filterExpr : indexingMaps[1].getResults()) {
929 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
930 if (outputDims.count(filterDim) &&
931 !inputExprWalker.unConvolvedDims.count(filterDim) &&
932 !inputExprWalker.convolvedDims.count(filterDim)) {
936 if (inputExprWalker.convolvedDims.count(filterDim) &&
937 !outputDims.count(filterDim)) {
939 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
940 return MatchConvolutionResult::NonOutputDimNotReduction;
941 if (allLoopDims.count(filterDim))
942 return MatchConvolutionResult::NonConvolutionLoop;
943 allLoopDims.insert(filterDim);
946 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
947 !outputDims.count(filterDim)) {
949 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
950 return MatchConvolutionResult::NonOutputDimNotReduction;
951 if (allLoopDims.count(filterDim))
952 return MatchConvolutionResult::NonConvolutionLoop;
953 allLoopDims.insert(filterDim);
956 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
957 outputDims.count(filterDim)) {
961 return MatchConvolutionResult::NonConvolutionLoop;
964 if (allLoopDims.size() != linalgOp.getNumLoops())
965 return MatchConvolutionResult::NonConvolutionLoop;
967 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
968 return MatchConvolutionResult::EmptyConvolvedDims;
972 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
973 assert(succeeded(res) &&
"unexpected failure to infer convolution dims");
977 return MatchConvolutionResult::Success;
983 case MatchConvolutionResult::NotLinalgOp:
984 return "expected a LinalgOp";
985 case MatchConvolutionResult::WrongNumOperands:
986 return "expected op with 2 inputs and 1 output";
987 case MatchConvolutionResult::WrongInputIndexingMap:
988 return "unexpected input index map for convolutions";
989 case MatchConvolutionResult::NotProjectedPermutations:
990 return "expected output/filter indexing maps to be projected permutations";
991 case MatchConvolutionResult::NonConvolutionLoop:
992 return "unexpected loop dimension for convolution op";
993 case MatchConvolutionResult::OutputDimsNotParallel:
994 return "expected all iterators used to access outputs to be parallel";
995 case MatchConvolutionResult::NonOutputDimNotReduction:
996 return "expected all iterators not used to access outputs to be reduction";
997 case MatchConvolutionResult::EmptyConvolvedDims:
998 return "expected convolved dim to be non-empty";
999 case MatchConvolutionResult::Success:
1002 llvm_unreachable(
"unhandled MatchConvolutionResult case");
1006 bool allowEmptyConvolvedDims) {
1008 linalgOp.getOperation(),
nullptr, allowEmptyConvolvedDims) ==
1014 if (res != MatchConvolutionResult::Success)
1031 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1034 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1037 OpOperand *value = linalgOp.getDpsInputOperand(0);
1038 if (!linalgOp.isScalar(value))
1047 return op->
emitError(
"expected a LinalgOp");
1049 return op->
emitError(
"expected op with 1 input and 1 output");
1051 return op->
emitError(
"expected op with scalar input");
1063 for (
OpOperand &opOperand : getOperation()->getOpOperands()) {
1064 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1072 assert(!hasDynamicShape() &&
"expected operands to have static shapes");
1073 for (
OpOperand &opOperand : getOperation()->getOpOperands())
1074 llvm::append_range(res,
getShape(&opOperand));
1081 auto viewSizes = createFlatListOfOperandDims(b, loc);
1083 for (
unsigned idx = 0; idx < numRes; ++idx) {
1085 if (
auto d = dyn_cast<AffineDimExpr>(result)) {
1086 if (res[d.getPosition()].offset)
1088 res[d.getPosition()] =
1100 : positions(std::move(positions)) {}
1115 llvm::SmallBitVector positions;
1118 static std::pair<int64_t, int64_t>
1120 int64_t inputRankSum = 0;
1121 int64_t outputRankSum = 0;
1122 for (
OpOperand *input : op.getDpsInputOperands())
1123 inputRankSum += op.getRank(input);
1124 for (
OpOperand &output : op.getDpsInitsMutable())
1125 outputRankSum += op.getRank(&output);
1126 return {inputRankSum, inputRankSum + outputRankSum};
1141 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1150 resultShapesSubMapPos.first,
1151 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1152 AffineMap resultShapesFromInputShapesMap =
1153 loopToResultsShapeMap.
compose(getShapesToLoopsMap());
1157 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.
getNumDims());
1158 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1160 Location loc = getOperation()->getLoc();
1164 rewriter, loc, resultShapesFromInputShapesMap,
1165 createFlatListOfOperandDims(b, loc));
1168 for (
OpOperand &opOperand : getDpsInitsMutable()) {
1170 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1171 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1172 if (!shapedType.isDynamicDim(dim)) {
1174 shapes.push_back(b.
getIndexAttr(shapedType.getDimSize(dim)));
1179 : allResultDimValues[pos];
1184 reifiedReturnShapes.emplace_back(std::move(shapes));
1191 int64_t LinalgOp::getIndexingMapIndex(
OpOperand *opOperand) {
1193 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1194 if (!dpsIface.isDpsInput(opOperand))
1195 return operandNumber;
1196 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1197 assert(!dpsIface.isDpsInit(opOperand));
1200 return cast<DestinationStyleOpInterface>(*this->getOperation())
1201 .getNumDpsInputs() +
1202 operandNumber - start;
1206 LinalgOp linalgOp = cast<LinalgOp>(op);
1208 if (!linalgOp.hasPureTensorSemantics() &&
1210 return op->
emitOpError(
"expected to have pure tensor or buffer semantics");
1214 if (linalgOp.hasDynamicIndexingMaps())
1215 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1219 if (
static_cast<int64_t
>(linalgOp.getIndexingMapsArray().size()) !=
1220 linalgOp->getNumOperands())
1221 return op->
emitOpError(
"expected the number of indexing_map (")
1222 << linalgOp.getIndexingMapsArray().size()
1223 <<
") to be equal to the number of input/output operands ("
1224 << linalgOp->getNumOperands() <<
")";
1228 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1229 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1233 return op->
emitOpError(
"unexpected symbols in indexing_map #")
1237 unsigned numLoops = linalgOp.getNumLoops();
1241 <<
" dim(s) to match the number of loops";
1243 int64_t rank = linalgOp.getRank(&opOperand);
1247 << rank <<
") to match the result rank of indexing_map #"
1252 linalgOp.getReductionDims(redDims);
1254 if (!linalgOp.getShapesToLoopsMap())
1255 return op->
emitOpError(
"expected the shape-to-loops map to be non-null");
1262 if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
1263 for (int64_t &range : endLoopRangeValues)
1265 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1266 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1268 indexingMap.
compose(startLoopRangeValues);
1270 indexingMap.
compose(endLoopRangeValues);
1272 for (
auto dim : llvm::seq<int64_t>(0, shape.size())) {
1274 if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
1287 int64_t inferredDimSize =
1288 std::max(startIndices[dim], endIndices[dim]) + 1;
1289 if (
std::min(startIndices[dim], endIndices[dim]) < 0) {
1292 llvm::raw_string_ostream os(mapStr);
1296 "unexpected result less than 0 at expression #")
1297 << dim <<
" in " << mapStr;
1299 if (dyn_cast<AffineDimExpr>(indexingMap.
getResult(dim))) {
1300 if (inferredDimSize != shape[dim]) {
1301 return op->
emitOpError(
"inferred input/output operand #")
1303 << dim <<
" to be " << inferredDimSize <<
", but found "
1307 if (inferredDimSize > shape[dim]) {
1308 return op->
emitOpError(
"inferred input/output operand #")
1310 << dim <<
" to be greater than or equal to "
1311 << inferredDimSize <<
", but found " << shape[dim];
1319 if (linalgOp->getNumRegions() != 1 ||
1320 !llvm::hasSingleElement(linalgOp->getRegion(0)))
1321 return op->
emitOpError(
"expects to have 1 region with 1 block");
1329 Block &block = linalgOp->getRegion(0).
front();
1331 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.
getNumArguments())
1332 return op->
emitOpError(
"expected as many non-induction variable region "
1333 "arguments as the number of input/output operands");
1335 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1337 if (isa<MemRefType, RankedTensorType>(elementType))
1340 if (elementType != argType)
1341 return op->
emitOpError(
"expected type of bb argument #")
1343 <<
" to match element or self type of the corresponding operand ("
1344 << elementType <<
")";
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free, go up the use-def chain until the first...
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static MatchFillResult isFillInterfaceImpl(Operation *op)
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Affine binary operation expression.
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
IntegerAttr getIndexAttr(int64_t value)
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
@ NotProjectedPermutations
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that check if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
@ NotProjectedPermutations
@ NonOutputDimNotReduction
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalgelementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.broadcast.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Visitor to check if any of the given set of positions from AffineDimExprs are used within an AffineEx...
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.