24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SetOperations.h"
26 #include "llvm/ADT/SmallBitVector.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/Support/Casting.h"
29 #include "llvm/Support/raw_ostream.h"
38 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
47 for (
auto &opOperand : linalgOp->getOpOperands()) {
48 if (llvm::is_contained(droppedOperands, &opOperand))
50 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
52 if (indexingMaps.empty()) {
55 return linalgOp.getNumLoops() == 0;
58 indexingMaps, linalgOp.getContext())) !=
AffineMap();
67 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
70 auto mapRange = op.getIndexingMapsArray();
71 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
72 !mapRange.back().isIdentity()) {
76 return llvm::hasSingleElement(op.getBlock()->getOperations());
84 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
85 !op.isSingleYieldOp())
89 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
90 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
93 OpOperand *value = op.getDpsInputOperand(0);
94 if (!op.isScalar(value))
102 std::optional<SmallVector<int64_t>>
105 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
106 !op.isSingleYieldOp())
109 auto srcTy = op.getDpsInputOperand(0)->get().getType();
110 auto dstTy = op.getDpsInitOperand(0)->get().getType();
111 if (!isa<MemRefType, RankedTensorType>(srcTy) ||
112 !isa<MemRefType, RankedTensorType>(dstTy))
118 auto dstMap = op.getIndexingMapsArray()[1];
119 if (!dstMap.isIdentity())
123 auto srcMap = op.getIndexingMapsArray()[0];
125 if (srcMap.getResults().size() >= dstMap.getResults().size())
129 for (
unsigned i = 0; i < srcMap.getNumResults(); ++i) {
130 auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
133 int64_t pos = expr.getPosition();
134 if (i > 0 && pos <= position[i - 1])
136 position.push_back(expr.getPosition());
140 auto numDims = srcMap.getNumDims();
142 for (
auto dim : llvm::seq<int64_t>(0, numDims)) {
143 if (!llvm::is_contained(position, dim))
144 broadcastedDims.push_back(dim);
146 return broadcastedDims;
152 std::optional<SmallVector<int64_t>>
157 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
158 !op.isSingleYieldOp())
161 auto mapRange = op.getIndexingMapsArray();
162 if (mapRange.size() != 2)
165 auto mapOfInput = mapRange.front();
166 auto mapOfResult = mapRange.back();
170 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
174 for (
unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
175 auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
176 permutation[expr.getPosition()] = i;
187 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
191 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
192 !llvm::all_of(op.getIndexingMapsArray(),
193 [](
AffineMap map) { return map.isIdentity(); }))
197 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
204 Block *body = op.getBody();
212 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
213 if (!yieldOp || yieldOp.getNumOperands() != 1 ||
214 yieldOp->getOperand(0).getDefiningOp() != oper)
225 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
235 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
236 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
237 if (!op.payloadUsesValueFromOperand(inputOpOperand0) ||
238 !op.payloadUsesValueFromOperand(inputOpOperand1))
254 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
255 if (!iface || !iface.hasNoEffect())
265 llvm::raw_ostream &errs) {
267 errs <<
"no terminator in the block";
272 errs <<
"expected block with 3 arguments";
278 errs <<
"expected terminator with 1 operand";
285 errs <<
"expected reduction op to be binary";
294 errs <<
"expected reduction to take block argument #2 as one of the "
295 "operands (modulo unary casts)";
300 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
304 errs <<
"expected elementwise op to be binary";
308 if (!isaPair(elementwiseOp, reductionOp)) {
309 errs <<
"expected reduction/elementwise op kind not satisfied";
322 errs <<
"expected elementwise op to apply to block arguments (modulo unary "
329 template <
typename AddOpTy,
typename MulOpTy,
typename... Args>
331 static_assert(
sizeof...(Args) % 2 == 0,
332 "expected an even number of template arguments");
333 if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
336 if constexpr (
sizeof...(Args) > 0)
344 template <
typename... Args>
356 static llvm::SmallDenseSet<int64_t>
359 utils::IteratorType iter) {
360 assert(iterators.size() == indexingMap.
getNumDims());
361 llvm::SmallDenseSet<int64_t> res;
363 if (
auto d = dyn_cast<AffineDimExpr>(e)) {
364 if (iterators[d.getPosition()] == iter &&
366 return e.isFunctionOfDim(d.getPosition());
368 res.insert(d.getPosition());
375 auto par = utils::IteratorType::parallel;
376 auto red = utils::IteratorType::reduction;
383 static FailureOr<SmallVector<utils::IteratorType>>
389 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
390 iterators[dim.getPosition()] = par;
405 static FailureOr<ContractionDimensions>
408 llvm::SmallDenseSet<int64_t> a =
410 llvm::SmallDenseSet<int64_t> b =
412 llvm::SmallDenseSet<int64_t> c =
416 llvm::SmallDenseSet<int64_t> ac = a;
417 llvm::set_intersect(ac, c);
418 llvm::set_subtract(ac, b);
420 llvm::SmallDenseSet<int64_t> bc = b;
421 llvm::set_intersect(bc, c);
422 llvm::set_subtract(bc, a);
424 llvm::SmallDenseSet<int64_t> batches = a;
425 llvm::set_intersect(batches, b);
426 llvm::set_intersect(batches, c);
429 llvm::SmallDenseSet<int64_t> ra =
431 llvm::SmallDenseSet<int64_t> rb =
433 llvm::set_intersect(ra, rb);
441 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
442 llvm::sort(dimensions.m.begin(), dimensions.m.end());
443 llvm::sort(dimensions.n.begin(), dimensions.n.end());
444 llvm::sort(dimensions.k.begin(), dimensions.k.end());
448 FailureOr<ContractionDimensions>
450 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
453 linalgOp.getIteratorTypesArray());
456 FailureOr<ContractionDimensions>
458 if (indexingMaps.size() != 3)
461 if (failed(iterators))
480 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
482 return MatchContractionResult::NotLinalgOp;
483 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
484 return MatchContractionResult::WrongNumOperands;
485 auto mapRange = linalgOp.getIndexingMapsArray();
486 if (linalgOp.getNumReductionLoops() == 0)
487 return MatchContractionResult::NoReduction;
488 if (llvm::any_of(mapRange,
489 [](
AffineMap m) {
return !m.isProjectedPermutation(); }))
490 return MatchContractionResult::NotProjectedPermutations;
494 arith::MulFOp, arith::AddFOp,
495 arith::MulIOp, arith::AddIOp,
496 complex::MulOp, complex::AddOp,
497 arith::AndIOp, arith::OrIOp>(
498 *linalgOp.getBlock())) {
499 return MatchContractionResult::NotAddMul;
505 assert(succeeded(res) &&
"unexpected failure to infer contraction dims");
508 return MatchContractionResult::Success;
514 case MatchContractionResult::NotLinalgOp:
515 return "expected a LinalgOp";
516 case MatchContractionResult::WrongNumOperands:
517 return "expected op with 2 inputs and 1 output";
518 case MatchContractionResult::NoReduction:
519 return "expected at least 1 reduction";
520 case MatchContractionResult::NotProjectedPermutations:
521 return "expected indexing maps to be projected permutations";
522 case MatchContractionResult::NotAddMul:
523 return "expected add/mul op in the body";
524 case MatchContractionResult::Success:
527 llvm_unreachable(
"unhandled MatchContractionResult case");
534 return isa<ContractionOpInterface>(op) ||
554 if (res != MatchContractionResult::Success)
565 template <
typename T>
567 return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) :
nullptr);
579 struct ConvAccessExprWalker
582 llvm::SmallDenseSet<int64_t> convolvedDims;
584 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
586 llvm::SmallDenseSet<int64_t> unConvolvedDims;
588 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
593 for (
int dimPos = 0, e = map.
getNumDims(); dimPos < e; ++dimPos) {
595 return e.isFunctionOfDim(dimPos);
597 convolvedDims.erase(dimPos);
598 unConvolvedDims.erase(dimPos);
601 auto it = convolvedDimMapping.find(dimPos);
602 if (it != convolvedDimMapping.end()) {
603 int64_t pairedDim = it->second;
604 convolvedDims.erase(pairedDim);
605 unConvolvedDims.erase(pairedDim);
606 strideAndDilationMapping.erase(pairedDim);
607 convolvedDimMapping.erase(dimPos);
608 convolvedDimMapping.erase(pairedDim);
616 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
619 unConvolvedDims.insert(position);
631 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getLHS());
632 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getRHS());
633 if (failed(lhsDimPos) || failed(rhsDimPos))
635 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
636 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
640 FailureOr<int64_t> getDimExprOrMulExprDimPos(
AffineExpr expr) {
641 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
643 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
646 strideAndDilationMapping[dim] =
648 convolvedDims.insert(dim);
651 if (
auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
654 auto lhsExpr = symbolMulExpr.getLHS();
655 auto rhsExpr = symbolMulExpr.getRHS();
658 getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
661 mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
663 auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
664 if (!mulExpr || !dimExpr)
667 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
669 strideAndDilationMapping[dim] = mulExpr;
670 convolvedDims.insert(dim);
680 "expected map to have projected permutations");
681 llvm::SmallDenseSet<int64_t> preservedDims;
683 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
684 return preservedDims;
690 for (
auto e : exprs) {
691 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
692 assert(constantExpr &&
"Found non-constant stride/dilation");
693 vals.push_back(constantExpr.getValue());
705 static FailureOr<ConvolutionDimensions>
707 ConvAccessExprWalker &inputExprWalker,
708 bool allowEmptyConvolvedDims) {
710 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
712 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
714 filterMap, linalgOp.getIteratorTypesArray(), par);
716 outputMap, linalgOp.getIteratorTypesArray(), par);
719 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
720 llvm::set_intersect(batch, outputDims);
721 llvm::set_subtract(batch, filterDims);
724 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
725 llvm::set_intersect(oi, outputDims);
728 llvm::SmallDenseSet<int64_t> oc = filterDims;
729 llvm::set_intersect(oc, outputDims);
730 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
733 llvm::SmallDenseSet<int64_t> depth = filterDims;
734 llvm::set_intersect(depth, outputDims);
735 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
737 llvm::SmallDenseSet<int64_t> filterReducedDims =
739 linalgOp.getIteratorTypesArray(), red);
742 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
743 llvm::set_intersect(fl, filterReducedDims);
746 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
747 llvm::set_intersect(ic, filterReducedDims);
749 if (oi.empty() && !allowEmptyConvolvedDims)
762 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
763 llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
764 llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
765 llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
766 llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
767 llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
771 if (!nativeStrides) {
773 for (
unsigned oiDim : dimensions.outputImage)
774 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
777 dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
779 auto nativeDilations =
781 if (!nativeDilations) {
783 for (
unsigned flDim : dimensions.filterLoop)
784 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
787 dimensions.dilations =
788 llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
817 FailureOr<ConvolutionDimensions>
819 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
822 auto indexingMaps = linalgOp.getIndexingMapsArray();
825 ConvAccessExprWalker inputExprWalker;
826 for (
AffineExpr expr : indexingMaps[0].getResults())
827 (void)inputExprWalker.visit(expr);
828 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
851 bool allowEmptyConvolvedDims) {
852 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
854 return MatchConvolutionResult::NotLinalgOp;
855 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
856 return MatchConvolutionResult::WrongNumOperands;
858 auto indexingMaps = linalgOp.getIndexingMapsArray();
861 ConvAccessExprWalker inputExprWalker;
862 if (llvm::any_of(indexingMaps[0].getResults(),
864 return failed(inputExprWalker.visit(expr));
866 return MatchConvolutionResult::WrongInputIndexingMap;
870 if (!indexingMaps[1].isProjectedPermutation() ||
871 !indexingMaps.back().isProjectedPermutation())
872 return MatchConvolutionResult::NotProjectedPermutations;
874 auto iteratorTypes = linalgOp.getIteratorTypesArray();
876 llvm::SmallDenseSet<int64_t> outputDims =
878 llvm::SmallDenseSet<int64_t> filterDims =
getPreservedDims(indexingMaps[1]);
892 llvm::SmallDenseSet<int64_t> allLoopDims;
893 for (
auto outputExpr : indexingMaps.back().getResults()) {
894 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
895 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
896 !filterDims.count(outputDim)) {
898 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
899 return MatchConvolutionResult::OutputDimsNotParallel;
900 allLoopDims.insert(outputDim);
903 if (inputExprWalker.convolvedDims.count(outputDim) &&
904 !filterDims.count(outputDim)) {
906 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
907 return MatchConvolutionResult::OutputDimsNotParallel;
908 allLoopDims.insert(outputDim);
911 if (!inputExprWalker.convolvedDims.count(outputDim) &&
912 !inputExprWalker.unConvolvedDims.count(outputDim) &&
913 filterDims.count(outputDim)) {
915 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
916 return MatchConvolutionResult::OutputDimsNotParallel;
917 allLoopDims.insert(outputDim);
920 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
921 filterDims.count(outputDim)) {
923 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
924 return MatchConvolutionResult::OutputDimsNotParallel;
925 allLoopDims.insert(outputDim);
928 return MatchConvolutionResult::NonConvolutionLoop;
930 for (
auto filterExpr : indexingMaps[1].getResults()) {
931 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
932 if (outputDims.count(filterDim) &&
933 !inputExprWalker.unConvolvedDims.count(filterDim) &&
934 !inputExprWalker.convolvedDims.count(filterDim)) {
938 if (inputExprWalker.convolvedDims.count(filterDim) &&
939 !outputDims.count(filterDim)) {
941 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
942 return MatchConvolutionResult::NonOutputDimNotReduction;
943 if (allLoopDims.count(filterDim))
944 return MatchConvolutionResult::NonConvolutionLoop;
945 allLoopDims.insert(filterDim);
948 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
949 !outputDims.count(filterDim)) {
951 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
952 return MatchConvolutionResult::NonOutputDimNotReduction;
953 if (allLoopDims.count(filterDim))
954 return MatchConvolutionResult::NonConvolutionLoop;
955 allLoopDims.insert(filterDim);
958 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
959 outputDims.count(filterDim)) {
963 return MatchConvolutionResult::NonConvolutionLoop;
966 if (allLoopDims.size() != linalgOp.getNumLoops())
967 return MatchConvolutionResult::NonConvolutionLoop;
969 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
970 return MatchConvolutionResult::EmptyConvolvedDims;
974 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
975 assert(succeeded(res) &&
"unexpected failure to infer convolution dims");
979 return MatchConvolutionResult::Success;
985 case MatchConvolutionResult::NotLinalgOp:
986 return "expected a LinalgOp";
987 case MatchConvolutionResult::WrongNumOperands:
988 return "expected op with 2 inputs and 1 output";
989 case MatchConvolutionResult::WrongInputIndexingMap:
990 return "unexpected input index map for convolutions";
991 case MatchConvolutionResult::NotProjectedPermutations:
992 return "expected output/filter indexing maps to be projected permutations";
993 case MatchConvolutionResult::NonConvolutionLoop:
994 return "unexpected loop dimension for convolution op";
995 case MatchConvolutionResult::OutputDimsNotParallel:
996 return "expected all iterators used to access outputs to be parallel";
997 case MatchConvolutionResult::NonOutputDimNotReduction:
998 return "expected all iterators not used to access outputs to be reduction";
999 case MatchConvolutionResult::EmptyConvolvedDims:
1000 return "expected convolved dim to be non-empty";
1001 case MatchConvolutionResult::Success:
1004 llvm_unreachable(
"unhandled MatchConvolutionResult case");
1008 bool allowEmptyConvolvedDims) {
1010 linalgOp.getOperation(),
nullptr, allowEmptyConvolvedDims) ==
1016 if (res != MatchConvolutionResult::Success)
1033 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1036 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1039 OpOperand *value = linalgOp.getDpsInputOperand(0);
1040 if (!linalgOp.isScalar(value))
1049 return op->
emitError(
"expected a LinalgOp");
1051 return op->
emitError(
"expected op with 1 input and 1 output");
1053 return op->
emitError(
"expected op with scalar input");
1065 for (
OpOperand &opOperand : getOperation()->getOpOperands()) {
1066 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1074 assert(!hasDynamicShape() &&
"expected operands to have static shapes");
1075 for (
OpOperand &opOperand : getOperation()->getOpOperands())
1076 llvm::append_range(res,
getShape(&opOperand));
1083 auto viewSizes = createFlatListOfOperandDims(b, loc);
1085 for (
unsigned idx = 0; idx < numRes; ++idx) {
1087 if (
auto d = dyn_cast<AffineDimExpr>(result)) {
1088 if (res[d.getPosition()].offset)
1090 res[d.getPosition()] =
1102 for (
unsigned idx = 0; idx < numRes; ++idx) {
1104 if (
auto d = dyn_cast<AffineDimExpr>(result))
1105 res[d.getPosition()] = allShapeSizes[idx];
1115 : positions(std::move(positions)) {}
1130 llvm::SmallBitVector positions;
1133 static std::pair<int64_t, int64_t>
1135 int64_t inputRankSum = 0;
1136 int64_t outputRankSum = 0;
1137 for (
OpOperand *input : op.getDpsInputOperands())
1138 inputRankSum += op.getRank(input);
1139 for (
OpOperand &output : op.getDpsInitsMutable())
1140 outputRankSum += op.getRank(&output);
1141 return {inputRankSum, inputRankSum + outputRankSum};
1156 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1165 resultShapesSubMapPos.first,
1166 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1167 AffineMap resultShapesFromInputShapesMap =
1168 loopToResultsShapeMap.
compose(getShapesToLoopsMap());
1172 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.
getNumDims());
1173 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1175 Location loc = getOperation()->getLoc();
1179 rewriter, loc, resultShapesFromInputShapesMap,
1180 createFlatListOfOperandDims(b, loc));
1183 for (
OpOperand &opOperand : getDpsInitsMutable()) {
1185 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1186 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1187 if (!shapedType.isDynamicDim(dim)) {
1189 shapes.push_back(b.
getIndexAttr(shapedType.getDimSize(dim)));
1194 : allResultDimValues[pos];
1199 reifiedReturnShapes.emplace_back(std::move(shapes));
1206 int64_t LinalgOp::getIndexingMapIndex(
OpOperand *opOperand) {
1208 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1209 if (!dpsIface.isDpsInput(opOperand))
1210 return operandNumber;
1211 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1212 assert(!dpsIface.isDpsInit(opOperand));
1215 return cast<DestinationStyleOpInterface>(*this->getOperation())
1216 .getNumDpsInputs() +
1217 operandNumber - start;
1221 LinalgOp linalgOp = cast<LinalgOp>(op);
1223 if (!linalgOp.hasPureTensorSemantics() &&
1225 return op->
emitOpError(
"expected to have pure tensor or buffer semantics");
1229 if (linalgOp.hasDynamicIndexingMaps())
1230 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1234 if (
static_cast<int64_t
>(linalgOp.getIndexingMapsArray().size()) !=
1235 linalgOp->getNumOperands())
1236 return op->
emitOpError(
"expected the number of indexing_map (")
1237 << linalgOp.getIndexingMapsArray().size()
1238 <<
") to be equal to the number of input/output operands ("
1239 << linalgOp->getNumOperands() <<
")";
1243 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1244 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1248 return op->
emitOpError(
"unexpected symbols in indexing_map #")
1252 unsigned numLoops = linalgOp.getNumLoops();
1256 <<
" dim(s) to match the number of loops";
1258 int64_t rank = linalgOp.getRank(&opOperand);
1262 << rank <<
") to match the result rank of indexing_map #"
1267 linalgOp.getReductionDims(redDims);
1269 if (!linalgOp.getShapesToLoopsMap())
1270 return op->
emitOpError(
"expected the shape-to-loops map to be non-null");
1277 if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
1278 for (int64_t &range : endLoopRangeValues)
1280 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1281 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1283 indexingMap.
compose(startLoopRangeValues);
1285 indexingMap.
compose(endLoopRangeValues);
1287 for (
auto dim : llvm::seq<int64_t>(0, shape.size())) {
1289 if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
1302 int64_t inferredDimSize =
1303 std::max(startIndices[dim], endIndices[dim]) + 1;
1304 if (
std::min(startIndices[dim], endIndices[dim]) < 0) {
1307 llvm::raw_string_ostream os(mapStr);
1311 "unexpected result less than 0 at expression #")
1312 << dim <<
" in " << mapStr;
1314 if (dyn_cast<AffineDimExpr>(indexingMap.
getResult(dim))) {
1315 if (inferredDimSize != shape[dim]) {
1316 return op->
emitOpError(
"inferred input/output operand #")
1318 << dim <<
" to be " << inferredDimSize <<
", but found "
1322 if (inferredDimSize > shape[dim]) {
1323 return op->
emitOpError(
"inferred input/output operand #")
1325 << dim <<
" to be greater than or equal to "
1326 << inferredDimSize <<
", but found " << shape[dim];
1334 if (linalgOp->getNumRegions() != 1 ||
1335 !llvm::hasSingleElement(linalgOp->getRegion(0)))
1336 return op->
emitOpError(
"expects to have 1 region with 1 block");
1344 Block &block = linalgOp->getRegion(0).
front();
1346 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.
getNumArguments())
1347 return op->
emitOpError(
"expected as many non-induction variable region "
1348 "arguments as the number of input/output operands");
1350 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1352 if (isa<MemRefType, RankedTensorType>(elementType))
1355 if (elementType != argType)
1356 return op->
emitOpError(
"expected type of bb argument #")
1358 <<
" to match element or self type of the corresponding operand ("
1359 << elementType <<
")";
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free, go up the use-def chain until the first...
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static MatchFillResult isFillInterfaceImpl(Operation *op)
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Affine binary operation expression.
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
IntegerAttr getIndexAttr(int64_t value)
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
@ NotProjectedPermutations
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that checks if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
@ NotProjectedPermutations
@ NonOutputDimNotReduction
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.broadcast.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Visitor to check if any of the given set of positions from AffineDimExprs are used within an AffineEx...
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.