22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SetOperations.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/Support/Casting.h"
27 #include "llvm/Support/raw_ostream.h"
34 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
43 for (
auto &opOperand : linalgOp->getOpOperands()) {
44 if (llvm::is_contained(droppedOperands, &opOperand))
46 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
48 if (indexingMaps.empty()) {
51 return linalgOp.getNumLoops() == 0;
54 indexingMaps, linalgOp.getContext())) !=
AffineMap();
63 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
66 auto mapRange = op.getIndexingMapsArray();
67 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
68 !mapRange.back().isIdentity()) {
72 Block *body = op.getBlock();
75 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
76 if (!yieldOp || yieldOp.getNumOperands() != 1)
78 return yieldOp->getOperand(0) == body->
getArgument(0);
88 if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
89 op.getNumDpsInputs() != 0)
93 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
96 Block *body = op.getBody();
100 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
101 if (!yieldOp || yieldOp.getNumOperands() != 1)
104 Value yieldOperand = yieldOp->getOperand(0);
116 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
117 !op.isSingleYieldOp())
121 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
122 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
125 OpOperand *value = op.getDpsInputOperand(0);
126 if (!op.isScalar(value))
140 std::optional<SmallVector<int64_t>>
143 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
144 !op.isSingleYieldOp())
147 auto srcTy = op.getDpsInputOperand(0)->get().getType();
148 auto dstTy = op.getDpsInitOperand(0)->get().getType();
149 if (!isa<MemRefType, RankedTensorType>(srcTy) ||
150 !isa<MemRefType, RankedTensorType>(dstTy))
156 auto dstMap = op.getIndexingMapsArray()[1];
157 if (!dstMap.isIdentity())
161 auto srcMap = op.getIndexingMapsArray()[0];
163 if (srcMap.getResults().size() >= dstMap.getResults().size())
167 for (
unsigned i = 0; i < srcMap.getNumResults(); ++i) {
168 auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
171 int64_t pos = expr.getPosition();
172 if (i > 0 && pos <= position[i - 1])
174 position.push_back(expr.getPosition());
178 auto numDims = srcMap.getNumDims();
180 for (
auto dim : llvm::seq<int64_t>(0, numDims)) {
181 if (!llvm::is_contained(position, dim))
182 broadcastedDims.push_back(dim);
184 return broadcastedDims;
190 std::optional<SmallVector<int64_t>>
195 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
196 !op.isSingleYieldOp())
199 auto mapRange = op.getIndexingMapsArray();
200 if (mapRange.size() != 2)
203 auto mapOfInput = mapRange.front();
204 auto mapOfResult = mapRange.back();
208 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
212 for (
unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
213 auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
214 permutation[expr.getPosition()] = i;
225 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
229 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
230 !llvm::all_of(op.getIndexingMapsArray(),
231 [](
AffineMap map) { return map.isIdentity(); }))
235 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
242 Block *body = op.getBody();
250 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
251 if (!yieldOp || yieldOp.getNumOperands() != 1 ||
252 yieldOp->getOperand(0).getDefiningOp() != oper)
263 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
273 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
274 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
275 if (!op.payloadUsesValueFromOperand(inputOpOperand0) ||
276 !op.payloadUsesValueFromOperand(inputOpOperand1))
292 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
293 if (!iface || !iface.hasNoEffect())
303 llvm::raw_ostream &errs) {
305 errs <<
"no terminator in the block";
310 errs <<
"expected block with 3 arguments";
316 errs <<
"expected terminator with 1 operand";
323 errs <<
"expected reduction op to be binary";
332 errs <<
"expected reduction to take block argument #2 as one of the "
333 "operands (modulo unary casts)";
338 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
342 errs <<
"expected elementwise op to be binary";
346 if (!isaPair(elementwiseOp, reductionOp)) {
347 errs <<
"expected reduction/elementwise op kind not satisfied";
360 errs <<
"expected elementwise op to apply to block arguments (modulo unary "
367 template <
typename AddOpTy,
typename MulOpTy,
typename... Args>
369 static_assert(
sizeof...(Args) % 2 == 0,
370 "expected an even number of template arguments");
371 if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
374 if constexpr (
sizeof...(Args) > 0)
382 template <
typename... Args>
394 static llvm::SmallDenseSet<int64_t>
397 utils::IteratorType iter) {
398 assert(iterators.size() == indexingMap.
getNumDims());
399 llvm::SmallDenseSet<int64_t> res;
401 if (
auto d = dyn_cast<AffineDimExpr>(e)) {
402 if (iterators[d.getPosition()] == iter &&
404 return e.isFunctionOfDim(d.getPosition());
406 res.insert(d.getPosition());
413 auto par = utils::IteratorType::parallel;
414 auto red = utils::IteratorType::reduction;
421 static FailureOr<SmallVector<utils::IteratorType>>
427 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
428 iterators[dim.getPosition()] = par;
443 static FailureOr<ContractionDimensions>
446 llvm::SmallDenseSet<int64_t> a =
448 llvm::SmallDenseSet<int64_t> b =
450 llvm::SmallDenseSet<int64_t> c =
454 llvm::SmallDenseSet<int64_t> ac = a;
455 llvm::set_intersect(ac, c);
456 llvm::set_subtract(ac, b);
458 llvm::SmallDenseSet<int64_t> bc = b;
459 llvm::set_intersect(bc, c);
460 llvm::set_subtract(bc, a);
462 llvm::SmallDenseSet<int64_t> batches = a;
463 llvm::set_intersect(batches, b);
464 llvm::set_intersect(batches, c);
467 llvm::SmallDenseSet<int64_t> ra =
469 llvm::SmallDenseSet<int64_t> rb =
471 llvm::set_intersect(ra, rb);
479 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
480 llvm::sort(dimensions.m.begin(), dimensions.m.end());
481 llvm::sort(dimensions.n.begin(), dimensions.n.end());
482 llvm::sort(dimensions.k.begin(), dimensions.k.end());
486 FailureOr<ContractionDimensions>
488 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
491 linalgOp.getIteratorTypesArray());
494 FailureOr<ContractionDimensions>
496 if (indexingMaps.size() != 3)
499 if (failed(iterators))
518 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
520 return MatchContractionResult::NotLinalgOp;
521 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
522 return MatchContractionResult::WrongNumOperands;
523 auto mapRange = linalgOp.getIndexingMapsArray();
524 if (linalgOp.getNumReductionLoops() == 0)
525 return MatchContractionResult::NoReduction;
526 if (llvm::any_of(mapRange,
528 return MatchContractionResult::NotProjectedPermutations;
532 arith::MulFOp, arith::AddFOp,
533 arith::MulIOp, arith::AddIOp,
534 complex::MulOp, complex::AddOp,
535 arith::AndIOp, arith::OrIOp>(
536 *linalgOp.getBlock())) {
537 return MatchContractionResult::NotAddMul;
543 assert(succeeded(res) &&
"unexpected failure to infer contraction dims");
546 return MatchContractionResult::Success;
552 case MatchContractionResult::NotLinalgOp:
553 return "expected a LinalgOp";
554 case MatchContractionResult::WrongNumOperands:
555 return "expected op with 2 inputs and 1 output";
556 case MatchContractionResult::NoReduction:
557 return "expected at least 1 reduction";
558 case MatchContractionResult::NotProjectedPermutations:
559 return "expected indexing maps to be projected permutations";
560 case MatchContractionResult::NotAddMul:
561 return "expected add/mul op in the body";
562 case MatchContractionResult::Success:
565 llvm_unreachable(
"unhandled MatchContractionResult case");
572 return isa<ContractionOpInterface>(op) ||
592 if (res != MatchContractionResult::Success)
603 template <
typename T>
605 return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) :
nullptr);
617 struct ConvAccessExprWalker
620 llvm::SmallDenseSet<int64_t> convolvedDims;
622 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
624 llvm::SmallDenseSet<int64_t> unConvolvedDims;
626 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
631 for (
int dimPos = 0, e = map.
getNumDims(); dimPos < e; ++dimPos) {
633 return e.isFunctionOfDim(dimPos);
635 convolvedDims.erase(dimPos);
636 unConvolvedDims.erase(dimPos);
639 auto it = convolvedDimMapping.find(dimPos);
640 if (it != convolvedDimMapping.end()) {
641 int64_t pairedDim = it->second;
642 convolvedDims.erase(pairedDim);
643 unConvolvedDims.erase(pairedDim);
644 strideAndDilationMapping.erase(pairedDim);
645 convolvedDimMapping.erase(dimPos);
646 convolvedDimMapping.erase(pairedDim);
654 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
657 unConvolvedDims.insert(position);
669 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getLHS());
670 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getRHS());
671 if (failed(lhsDimPos) || failed(rhsDimPos))
673 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
674 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
678 FailureOr<int64_t> getDimExprOrMulExprDimPos(
AffineExpr expr) {
679 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
681 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
684 strideAndDilationMapping[dim] =
686 convolvedDims.insert(dim);
689 if (
auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
692 auto lhsExpr = symbolMulExpr.getLHS();
693 auto rhsExpr = symbolMulExpr.getRHS();
696 getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
699 mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
701 auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
702 if (!mulExpr || !dimExpr)
705 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
707 strideAndDilationMapping[dim] = mulExpr;
708 convolvedDims.insert(dim);
718 "expected map to have projected permutations");
719 llvm::SmallDenseSet<int64_t> preservedDims;
721 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
722 return preservedDims;
728 for (
auto e : exprs) {
729 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
730 assert(constantExpr &&
"Found non-constant stride/dilation");
731 vals.push_back(constantExpr.getValue());
743 static FailureOr<ConvolutionDimensions>
745 ConvAccessExprWalker &inputExprWalker,
746 bool allowEmptyConvolvedDims) {
748 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
750 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
752 filterMap, linalgOp.getIteratorTypesArray(), par);
754 outputMap, linalgOp.getIteratorTypesArray(), par);
757 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
758 llvm::set_intersect(batch, outputDims);
759 llvm::set_subtract(batch, filterDims);
762 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
763 llvm::set_intersect(oi, outputDims);
766 llvm::SmallDenseSet<int64_t> oc = filterDims;
767 llvm::set_intersect(oc, outputDims);
768 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
771 llvm::SmallDenseSet<int64_t> depth = filterDims;
772 llvm::set_intersect(depth, outputDims);
773 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
775 llvm::SmallDenseSet<int64_t> filterReducedDims =
777 linalgOp.getIteratorTypesArray(), red);
780 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
781 llvm::set_intersect(fl, filterReducedDims);
784 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
785 llvm::set_intersect(ic, filterReducedDims);
787 if (oi.empty() && !allowEmptyConvolvedDims)
800 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
801 llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
802 llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
803 llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
804 llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
805 llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
809 if (!nativeStrides) {
811 for (
unsigned oiDim : dimensions.outputImage)
812 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
815 dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
817 auto nativeDilations =
819 if (!nativeDilations) {
821 for (
unsigned flDim : dimensions.filterLoop)
822 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
825 dimensions.dilations =
826 llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
855 FailureOr<ConvolutionDimensions>
857 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
860 auto indexingMaps = linalgOp.getIndexingMapsArray();
863 ConvAccessExprWalker inputExprWalker;
864 for (
AffineExpr expr : indexingMaps[0].getResults())
865 (void)inputExprWalker.visit(expr);
866 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
889 bool allowEmptyConvolvedDims) {
890 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
892 return MatchConvolutionResult::NotLinalgOp;
893 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
894 return MatchConvolutionResult::WrongNumOperands;
896 auto indexingMaps = linalgOp.getIndexingMapsArray();
899 ConvAccessExprWalker inputExprWalker;
900 if (llvm::any_of(indexingMaps[0].getResults(),
902 return failed(inputExprWalker.visit(expr));
904 return MatchConvolutionResult::WrongInputIndexingMap;
908 if (!indexingMaps[1].isProjectedPermutation() ||
909 !indexingMaps.back().isProjectedPermutation())
910 return MatchConvolutionResult::NotProjectedPermutations;
912 auto iteratorTypes = linalgOp.getIteratorTypesArray();
914 llvm::SmallDenseSet<int64_t> outputDims =
916 llvm::SmallDenseSet<int64_t> filterDims =
getPreservedDims(indexingMaps[1]);
930 llvm::SmallDenseSet<int64_t> allLoopDims;
931 for (
auto outputExpr : indexingMaps.back().getResults()) {
932 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
933 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
934 !filterDims.count(outputDim)) {
936 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
937 return MatchConvolutionResult::OutputDimsNotParallel;
938 allLoopDims.insert(outputDim);
941 if (inputExprWalker.convolvedDims.count(outputDim) &&
942 !filterDims.count(outputDim)) {
944 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
945 return MatchConvolutionResult::OutputDimsNotParallel;
946 allLoopDims.insert(outputDim);
949 if (!inputExprWalker.convolvedDims.count(outputDim) &&
950 !inputExprWalker.unConvolvedDims.count(outputDim) &&
951 filterDims.count(outputDim)) {
953 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
954 return MatchConvolutionResult::OutputDimsNotParallel;
955 allLoopDims.insert(outputDim);
958 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
959 filterDims.count(outputDim)) {
961 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
962 return MatchConvolutionResult::OutputDimsNotParallel;
963 allLoopDims.insert(outputDim);
966 return MatchConvolutionResult::NonConvolutionLoop;
968 for (
auto filterExpr : indexingMaps[1].getResults()) {
969 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
970 if (outputDims.count(filterDim) &&
971 !inputExprWalker.unConvolvedDims.count(filterDim) &&
972 !inputExprWalker.convolvedDims.count(filterDim)) {
976 if (inputExprWalker.convolvedDims.count(filterDim) &&
977 !outputDims.count(filterDim)) {
979 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
980 return MatchConvolutionResult::NonOutputDimNotReduction;
981 if (allLoopDims.count(filterDim))
982 return MatchConvolutionResult::NonConvolutionLoop;
983 allLoopDims.insert(filterDim);
986 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
987 !outputDims.count(filterDim)) {
989 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
990 return MatchConvolutionResult::NonOutputDimNotReduction;
991 if (allLoopDims.count(filterDim))
992 return MatchConvolutionResult::NonConvolutionLoop;
993 allLoopDims.insert(filterDim);
996 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
997 outputDims.count(filterDim)) {
1001 return MatchConvolutionResult::NonConvolutionLoop;
1004 if (allLoopDims.size() != linalgOp.getNumLoops())
1005 return MatchConvolutionResult::NonConvolutionLoop;
1007 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
1008 return MatchConvolutionResult::EmptyConvolvedDims;
1012 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
1013 assert(succeeded(res) &&
"unexpected failure to infer convolution dims");
1017 return MatchConvolutionResult::Success;
1023 case MatchConvolutionResult::NotLinalgOp:
1024 return "expected a LinalgOp";
1025 case MatchConvolutionResult::WrongNumOperands:
1026 return "expected op with 2 inputs and 1 output";
1027 case MatchConvolutionResult::WrongInputIndexingMap:
1028 return "unexpected input index map for convolutions";
1029 case MatchConvolutionResult::NotProjectedPermutations:
1030 return "expected output/filter indexing maps to be projected permutations";
1031 case MatchConvolutionResult::NonConvolutionLoop:
1032 return "unexpected loop dimension for convolution op";
1033 case MatchConvolutionResult::OutputDimsNotParallel:
1034 return "expected all iterators used to access outputs to be parallel";
1035 case MatchConvolutionResult::NonOutputDimNotReduction:
1036 return "expected all iterators not used to access outputs to be reduction";
1037 case MatchConvolutionResult::EmptyConvolvedDims:
1038 return "expected convolved dim to be non-empty";
1039 case MatchConvolutionResult::Success:
1042 llvm_unreachable(
"unhandled MatchConvolutionResult case");
1046 bool allowEmptyConvolvedDims) {
1048 linalgOp.getOperation(),
nullptr, allowEmptyConvolvedDims) ==
1054 if (res != MatchConvolutionResult::Success)
1071 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1074 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1077 OpOperand *value = linalgOp.getDpsInputOperand(0);
1078 if (!linalgOp.isScalar(value))
1087 return op->
emitError(
"expected a LinalgOp");
1089 return op->
emitError(
"expected op with 1 input and 1 output");
1091 return op->
emitError(
"expected op with scalar input");
1103 for (
OpOperand &opOperand : getOperation()->getOpOperands()) {
1104 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1112 assert(!hasDynamicShape() &&
"expected operands to have static shapes");
1113 for (
OpOperand &opOperand : getOperation()->getOpOperands())
1114 llvm::append_range(res,
getShape(&opOperand));
1121 auto viewSizes = createFlatListOfOperandDims(b, loc);
1123 for (
unsigned idx = 0; idx < numRes; ++idx) {
1125 if (
auto d = dyn_cast<AffineDimExpr>(result)) {
1126 if (res[d.getPosition()].offset)
1128 res[d.getPosition()] =
1140 : positions(std::move(positions)) {}
1155 llvm::SmallBitVector positions;
1158 static std::pair<int64_t, int64_t>
1160 int64_t inputRankSum = 0;
1161 int64_t outputRankSum = 0;
1162 for (
OpOperand *input : op.getDpsInputOperands())
1163 inputRankSum += op.getRank(input);
1164 for (
OpOperand &output : op.getDpsInitsMutable())
1165 outputRankSum += op.getRank(&output);
1166 return {inputRankSum, inputRankSum + outputRankSum};
1181 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1190 resultShapesSubMapPos.first,
1191 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1192 AffineMap resultShapesFromInputShapesMap =
1193 loopToResultsShapeMap.
compose(getShapesToLoopsMap());
1197 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.
getNumDims());
1198 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1200 Location loc = getOperation()->getLoc();
1204 rewriter, loc, resultShapesFromInputShapesMap,
1205 createFlatListOfOperandDims(b, loc));
1208 for (
OpOperand &opOperand : getDpsInitsMutable()) {
1210 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1211 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1212 if (!shapedType.isDynamicDim(dim)) {
1214 shapes.push_back(b.
getIndexAttr(shapedType.getDimSize(dim)));
1219 : allResultDimValues[pos];
1224 reifiedReturnShapes.emplace_back(std::move(shapes));
1231 int64_t LinalgOp::getIndexingMapIndex(
OpOperand *opOperand) {
1233 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1234 if (!dpsIface.isDpsInput(opOperand))
1235 return operandNumber;
1236 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1237 assert(!dpsIface.isDpsInit(opOperand));
1240 return cast<DestinationStyleOpInterface>(*this->getOperation())
1241 .getNumDpsInputs() +
1242 operandNumber - start;
1246 LinalgOp linalgOp = cast<LinalgOp>(op);
1248 if (!linalgOp.hasPureTensorSemantics() &&
1250 return op->
emitOpError(
"expected to have pure tensor or buffer semantics");
1254 if (linalgOp.hasDynamicIndexingMaps())
1255 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1259 if (failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
1264 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1265 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1267 unsigned numLoops = linalgOp.getNumLoops();
1271 <<
" dim(s) to match the number of loops";
1274 linalgOp.getReductionDims(redDims);
1276 if (!linalgOp.getShapesToLoopsMap())
1277 return op->
emitOpError(
"expected the shape-to-loops map to be non-null");
1280 if (linalgOp->getNumRegions() != 1 || !linalgOp->getRegion(0).hasOneBlock())
1281 return op->
emitOpError(
"expects to have 1 region with 1 block");
1289 Block &block = linalgOp->getRegion(0).
front();
1291 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.
getNumArguments())
1292 return op->
emitOpError(
"expected as many non-induction variable region "
1293 "arguments as the number of input/output operands");
1295 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1297 if (isa<MemRefType, RankedTensorType>(elementType))
1300 if (elementType != argType)
1301 return op->
emitOpError(
"expected type of bb argument #")
1303 <<
" to match element or self type of the corresponding operand ("
1304 << elementType <<
")";
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...
static std::optional< Value > isaExternalFillOp(GenericOp op)
Detects if a linalg.generic operation represents an external scalar input.
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free operations, go up the use-def chain until the first...
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static MatchFillResult isFillInterfaceImpl(Operation *op)
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static std::optional< Value > isaInlinedFillOp(GenericOp op)
Detects if a linalg.generic operation represents a fill with an inlined constant.
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Affine binary operation expression.
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
IntegerAttr getIndexAttr(int64_t value)
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
@ NotProjectedPermutations
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that checks if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
@ NotProjectedPermutations
@ NonOutputDimNotReduction
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.broadcast.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Visitor to check if any of the given set of positions from AffineDimExprs are used within an AffineEx...
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.