22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SetOperations.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/Support/Casting.h"
27 #include "llvm/Support/raw_ostream.h"
34 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
43 for (
auto &opOperand : linalgOp->getOpOperands()) {
44 if (llvm::is_contained(droppedOperands, &opOperand))
46 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
48 if (indexingMaps.empty()) {
51 return linalgOp.getNumLoops() == 0;
54 indexingMaps, linalgOp.getContext())) !=
AffineMap();
63 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
66 auto mapRange = op.getIndexingMapsArray();
67 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
68 !mapRange.back().isIdentity()) {
72 Block *body = op.getBlock();
75 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
76 if (!yieldOp || yieldOp.getNumOperands() != 1)
78 return yieldOp->getOperand(0) == body->
getArgument(0);
88 if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
89 op.getNumDpsInputs() != 0)
93 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
96 Block *body = op.getBody();
100 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
101 if (!yieldOp || yieldOp.getNumOperands() != 1)
104 Value yieldOperand = yieldOp->getOperand(0);
116 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
117 !op.isSingleYieldOp())
121 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
122 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
125 OpOperand *value = op.getDpsInputOperand(0);
126 if (!op.isScalar(value))
140 std::optional<SmallVector<int64_t>>
143 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
144 !op.isSingleYieldOp())
147 auto srcTy = op.getDpsInputOperand(0)->get().getType();
148 auto dstTy = op.getDpsInitOperand(0)->get().getType();
149 if (!isa<MemRefType, RankedTensorType>(srcTy) ||
150 !isa<MemRefType, RankedTensorType>(dstTy))
156 auto dstMap = op.getIndexingMapsArray()[1];
157 if (!dstMap.isIdentity())
161 auto srcMap = op.getIndexingMapsArray()[0];
163 if (srcMap.getResults().size() >= dstMap.getResults().size())
167 for (
unsigned i = 0; i < srcMap.getNumResults(); ++i) {
168 auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
171 int64_t pos = expr.getPosition();
172 if (i > 0 && pos <= position[i - 1])
174 position.push_back(expr.getPosition());
178 auto numDims = srcMap.getNumDims();
180 for (
auto dim : llvm::seq<int64_t>(0, numDims)) {
181 if (!llvm::is_contained(position, dim))
182 broadcastedDims.push_back(dim);
184 return broadcastedDims;
190 std::optional<SmallVector<int64_t>>
195 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
196 !op.isSingleYieldOp())
199 auto mapRange = op.getIndexingMapsArray();
200 if (mapRange.size() != 2)
203 auto mapOfInput = mapRange.front();
204 auto mapOfResult = mapRange.back();
208 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
212 for (
unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
213 auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
214 permutation[expr.getPosition()] = i;
225 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
229 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
230 !llvm::all_of(op.getIndexingMapsArray(),
231 [](
AffineMap map) { return map.isIdentity(); }))
235 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
242 Block *body = op.getBody();
250 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
251 return !(!yieldOp || yieldOp.getNumOperands() != 1 ||
252 yieldOp->getOperand(0).getDefiningOp() != oper);
261 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
271 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
272 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
273 return !(!op.payloadUsesValueFromOperand(inputOpOperand0) ||
274 !op.payloadUsesValueFromOperand(inputOpOperand1));
288 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
289 if (!iface || !iface.hasNoEffect())
299 llvm::raw_ostream &errs) {
301 errs <<
"no terminator in the block";
306 errs <<
"expected block with 3 arguments";
312 errs <<
"expected terminator with 1 operand";
320 errs <<
"expected reduction op to be binary";
329 errs <<
"expected reduction to take block argument #2 as one of the "
330 "operands (modulo unary casts)";
335 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
339 errs <<
"expected elementwise op to be binary";
343 if (!isaPair(elementwiseOp, reductionOp)) {
344 errs <<
"expected reduction/elementwise op kind not satisfied";
357 errs <<
"expected elementwise op to apply to block arguments (modulo unary "
364 template <
typename AddOpTy,
typename MulOpTy,
typename... Args>
366 static_assert(
sizeof...(Args) % 2 == 0,
367 "expected an even number of template arguments");
368 if (isa<AddOpTy>(
add) && isa<MulOpTy>(
mul))
371 if constexpr (
sizeof...(Args) > 0)
379 template <
typename... Args>
391 static llvm::SmallDenseSet<int64_t>
394 utils::IteratorType iter) {
395 assert(iterators.size() == indexingMap.
getNumDims());
396 llvm::SmallDenseSet<int64_t> res;
398 if (
auto d = dyn_cast<AffineDimExpr>(e)) {
399 if (iterators[d.getPosition()] == iter &&
401 return e.isFunctionOfDim(d.getPosition());
403 res.insert(d.getPosition());
410 auto par = utils::IteratorType::parallel;
411 auto red = utils::IteratorType::reduction;
418 static FailureOr<SmallVector<utils::IteratorType>>
424 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
425 iterators[dim.getPosition()] = par;
440 static FailureOr<ContractionDimensions>
443 llvm::SmallDenseSet<int64_t> a =
445 llvm::SmallDenseSet<int64_t> b =
447 llvm::SmallDenseSet<int64_t> c =
451 llvm::SmallDenseSet<int64_t> ac = a;
452 llvm::set_intersect(ac, c);
453 llvm::set_subtract(ac, b);
455 llvm::SmallDenseSet<int64_t> bc = b;
456 llvm::set_intersect(bc, c);
457 llvm::set_subtract(bc, a);
459 llvm::SmallDenseSet<int64_t> batches = a;
460 llvm::set_intersect(batches, b);
461 llvm::set_intersect(batches, c);
464 llvm::SmallDenseSet<int64_t> ra =
466 llvm::SmallDenseSet<int64_t> rb =
468 llvm::set_intersect(ra, rb);
476 llvm::sort(dimensions.batch);
477 llvm::sort(dimensions.m);
478 llvm::sort(dimensions.n);
479 llvm::sort(dimensions.k);
483 FailureOr<ContractionDimensions>
485 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
488 linalgOp.getIteratorTypesArray());
491 FailureOr<ContractionDimensions>
493 if (indexingMaps.size() != 3)
515 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
517 return MatchContractionResult::NotLinalgOp;
518 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
519 return MatchContractionResult::WrongNumOperands;
520 auto mapRange = linalgOp.getIndexingMapsArray();
521 if (linalgOp.getNumReductionLoops() == 0)
522 return MatchContractionResult::NoReduction;
523 if (llvm::any_of(mapRange,
525 return MatchContractionResult::NotProjectedPermutations;
529 arith::MulFOp, arith::AddFOp,
530 arith::MulIOp, arith::AddIOp,
531 complex::MulOp, complex::AddOp,
532 arith::AndIOp, arith::OrIOp>(
533 *linalgOp.getBlock())) {
534 return MatchContractionResult::NotAddMul;
540 assert(succeeded(res) &&
"unexpected failure to infer contraction dims");
543 return MatchContractionResult::Success;
549 case MatchContractionResult::NotLinalgOp:
550 return "expected a LinalgOp";
551 case MatchContractionResult::WrongNumOperands:
552 return "expected op with 2 inputs and 1 output";
553 case MatchContractionResult::NoReduction:
554 return "expected at least 1 reduction";
555 case MatchContractionResult::NotProjectedPermutations:
556 return "expected indexing maps to be projected permutations";
557 case MatchContractionResult::NotAddMul:
558 return "expected add/mul op in the body";
559 case MatchContractionResult::Success:
562 llvm_unreachable(
"unhandled MatchContractionResult case");
569 return isa<ContractionOpInterface>(op) ||
589 if (res != MatchContractionResult::Success)
600 template <
typename T>
602 return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) :
nullptr);
614 struct ConvAccessExprWalker
617 llvm::SmallDenseSet<int64_t> convolvedDims;
619 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
621 llvm::SmallDenseSet<int64_t> unConvolvedDims;
623 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
628 for (
int dimPos = 0, e = map.
getNumDims(); dimPos < e; ++dimPos) {
630 return e.isFunctionOfDim(dimPos);
632 convolvedDims.erase(dimPos);
633 unConvolvedDims.erase(dimPos);
636 auto it = convolvedDimMapping.find(dimPos);
637 if (it != convolvedDimMapping.end()) {
638 int64_t pairedDim = it->second;
639 convolvedDims.erase(pairedDim);
640 unConvolvedDims.erase(pairedDim);
641 strideAndDilationMapping.erase(pairedDim);
642 convolvedDimMapping.erase(dimPos);
643 convolvedDimMapping.erase(pairedDim);
651 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
654 unConvolvedDims.insert(position);
666 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getLHS());
667 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getRHS());
670 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
671 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
675 FailureOr<int64_t> getDimExprOrMulExprDimPos(
AffineExpr expr) {
676 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
678 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
681 strideAndDilationMapping[dim] =
683 convolvedDims.insert(dim);
686 if (
auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
689 auto lhsExpr = symbolMulExpr.getLHS();
690 auto rhsExpr = symbolMulExpr.getRHS();
693 getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
696 mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
698 auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
699 if (!mulExpr || !dimExpr)
702 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
704 strideAndDilationMapping[dim] = mulExpr;
705 convolvedDims.insert(dim);
715 "expected map to have projected permutations");
716 llvm::SmallDenseSet<int64_t> preservedDims;
718 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
719 return preservedDims;
725 for (
auto e : exprs) {
726 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
727 assert(constantExpr &&
"Found non-constant stride/dilation");
728 vals.push_back(constantExpr.getValue());
740 static FailureOr<ConvolutionDimensions>
742 ConvAccessExprWalker &inputExprWalker,
743 bool allowEmptyConvolvedDims) {
745 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
747 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
749 filterMap, linalgOp.getIteratorTypesArray(), par);
751 outputMap, linalgOp.getIteratorTypesArray(), par);
754 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
755 llvm::set_intersect(batch, outputDims);
756 llvm::set_subtract(batch, filterDims);
759 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
760 llvm::set_intersect(oi, outputDims);
763 llvm::SmallDenseSet<int64_t> oc = filterDims;
764 llvm::set_intersect(oc, outputDims);
765 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
768 llvm::SmallDenseSet<int64_t> depth = filterDims;
769 llvm::set_intersect(depth, outputDims);
770 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
772 llvm::SmallDenseSet<int64_t> filterReducedDims =
774 linalgOp.getIteratorTypesArray(), red);
777 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
778 llvm::set_intersect(fl, filterReducedDims);
781 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
782 llvm::set_intersect(ic, filterReducedDims);
784 if (oi.empty() && !allowEmptyConvolvedDims)
797 llvm::sort(dimensions.batch);
798 llvm::sort(dimensions.outputImage);
799 llvm::sort(dimensions.outputChannel);
800 llvm::sort(dimensions.filterLoop);
801 llvm::sort(dimensions.inputChannel);
802 llvm::sort(dimensions.depth);
806 if (!nativeStrides) {
808 for (
unsigned oiDim : dimensions.outputImage)
809 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
812 dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
814 auto nativeDilations =
816 if (!nativeDilations) {
818 for (
unsigned flDim : dimensions.filterLoop)
819 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
822 dimensions.dilations =
823 llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
852 FailureOr<ConvolutionDimensions>
854 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
857 auto indexingMaps = linalgOp.getIndexingMapsArray();
860 ConvAccessExprWalker inputExprWalker;
861 for (
AffineExpr expr : indexingMaps[0].getResults())
862 (void)inputExprWalker.visit(expr);
863 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
886 bool allowEmptyConvolvedDims) {
887 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
889 return MatchConvolutionResult::NotLinalgOp;
890 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
891 return MatchConvolutionResult::WrongNumOperands;
893 auto indexingMaps = linalgOp.getIndexingMapsArray();
896 ConvAccessExprWalker inputExprWalker;
897 if (llvm::any_of(indexingMaps[0].getResults(),
899 return failed(inputExprWalker.visit(expr));
901 return MatchConvolutionResult::WrongInputIndexingMap;
905 if (!indexingMaps[1].isProjectedPermutation() ||
906 !indexingMaps.back().isProjectedPermutation())
907 return MatchConvolutionResult::NotProjectedPermutations;
909 auto iteratorTypes = linalgOp.getIteratorTypesArray();
911 llvm::SmallDenseSet<int64_t> outputDims =
913 llvm::SmallDenseSet<int64_t> filterDims =
getPreservedDims(indexingMaps[1]);
927 llvm::SmallDenseSet<int64_t> allLoopDims;
928 for (
auto outputExpr : indexingMaps.back().getResults()) {
929 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
930 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
931 !filterDims.count(outputDim)) {
933 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
934 return MatchConvolutionResult::OutputDimsNotParallel;
935 allLoopDims.insert(outputDim);
938 if (inputExprWalker.convolvedDims.count(outputDim) &&
939 !filterDims.count(outputDim)) {
941 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
942 return MatchConvolutionResult::OutputDimsNotParallel;
943 allLoopDims.insert(outputDim);
946 if (!inputExprWalker.convolvedDims.count(outputDim) &&
947 !inputExprWalker.unConvolvedDims.count(outputDim) &&
948 filterDims.count(outputDim)) {
950 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
951 return MatchConvolutionResult::OutputDimsNotParallel;
952 allLoopDims.insert(outputDim);
955 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
956 filterDims.count(outputDim)) {
958 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
959 return MatchConvolutionResult::OutputDimsNotParallel;
960 allLoopDims.insert(outputDim);
963 return MatchConvolutionResult::NonConvolutionLoop;
965 for (
auto filterExpr : indexingMaps[1].getResults()) {
966 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
967 if (outputDims.count(filterDim) &&
968 !inputExprWalker.unConvolvedDims.count(filterDim) &&
969 !inputExprWalker.convolvedDims.count(filterDim)) {
973 if (inputExprWalker.convolvedDims.count(filterDim) &&
974 !outputDims.count(filterDim)) {
976 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
977 return MatchConvolutionResult::NonOutputDimNotReduction;
978 if (allLoopDims.count(filterDim))
979 return MatchConvolutionResult::NonConvolutionLoop;
980 allLoopDims.insert(filterDim);
983 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
984 !outputDims.count(filterDim)) {
986 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
987 return MatchConvolutionResult::NonOutputDimNotReduction;
988 if (allLoopDims.count(filterDim))
989 return MatchConvolutionResult::NonConvolutionLoop;
990 allLoopDims.insert(filterDim);
993 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
994 outputDims.count(filterDim)) {
998 return MatchConvolutionResult::NonConvolutionLoop;
1001 if (allLoopDims.size() != linalgOp.getNumLoops())
1002 return MatchConvolutionResult::NonConvolutionLoop;
1004 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
1005 return MatchConvolutionResult::EmptyConvolvedDims;
1009 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
1010 assert(succeeded(res) &&
"unexpected failure to infer convolution dims");
1014 return MatchConvolutionResult::Success;
1020 case MatchConvolutionResult::NotLinalgOp:
1021 return "expected a LinalgOp";
1022 case MatchConvolutionResult::WrongNumOperands:
1023 return "expected op with 2 inputs and 1 output";
1024 case MatchConvolutionResult::WrongInputIndexingMap:
1025 return "unexpected input index map for convolutions";
1026 case MatchConvolutionResult::NotProjectedPermutations:
1027 return "expected output/filter indexing maps to be projected permutations";
1028 case MatchConvolutionResult::NonConvolutionLoop:
1029 return "unexpected loop dimension for convolution op";
1030 case MatchConvolutionResult::OutputDimsNotParallel:
1031 return "expected all iterators used to access outputs to be parallel";
1032 case MatchConvolutionResult::NonOutputDimNotReduction:
1033 return "expected all iterators not used to access outputs to be reduction";
1034 case MatchConvolutionResult::EmptyConvolvedDims:
1035 return "expected convolved dim to be non-empty";
1036 case MatchConvolutionResult::Success:
1039 llvm_unreachable(
"unhandled MatchConvolutionResult case");
1043 bool allowEmptyConvolvedDims) {
1045 linalgOp.getOperation(),
nullptr, allowEmptyConvolvedDims) ==
1051 if (res != MatchConvolutionResult::Success)
1068 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1071 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1074 OpOperand *value = linalgOp.getDpsInputOperand(0);
1075 if (!linalgOp.isScalar(value))
1084 return op->
emitError(
"expected a LinalgOp");
1086 return op->
emitError(
"expected op with 1 input and 1 output");
1088 return op->
emitError(
"expected op with scalar input");
1100 for (
OpOperand &opOperand : getOperation()->getOpOperands()) {
1101 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1109 assert(!hasDynamicShape() &&
"expected operands to have static shapes");
1110 for (
OpOperand &opOperand : getOperation()->getOpOperands())
1111 llvm::append_range(res,
getShape(&opOperand));
1118 auto viewSizes = createFlatListOfOperandDims(b, loc);
1120 for (
unsigned idx = 0; idx < numRes; ++idx) {
1122 if (
auto d = dyn_cast<AffineDimExpr>(result)) {
1123 if (res[d.getPosition()].offset)
1125 res[d.getPosition()] =
1137 : positions(std::move(positions)) {}
1152 llvm::SmallBitVector positions;
1155 static std::pair<int64_t, int64_t>
1157 int64_t inputRankSum = 0;
1158 int64_t outputRankSum = 0;
1159 for (
OpOperand *input : op.getDpsInputOperands())
1160 inputRankSum += op.getRank(input);
1161 for (
OpOperand &output : op.getDpsInitsMutable())
1162 outputRankSum += op.getRank(&output);
1163 return {inputRankSum, inputRankSum + outputRankSum};
1178 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1187 resultShapesSubMapPos.first,
1188 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1189 AffineMap resultShapesFromInputShapesMap =
1190 loopToResultsShapeMap.
compose(getShapesToLoopsMap());
1194 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.
getNumDims());
1195 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1197 Location loc = getOperation()->getLoc();
1201 rewriter, loc, resultShapesFromInputShapesMap,
1202 createFlatListOfOperandDims(b, loc));
1205 for (
OpOperand &opOperand : getDpsInitsMutable()) {
1207 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1208 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1209 if (!shapedType.isDynamicDim(dim)) {
1211 shapes.push_back(b.
getIndexAttr(shapedType.getDimSize(dim)));
1216 : allResultDimValues[pos];
1221 reifiedReturnShapes.emplace_back(std::move(shapes));
1228 int64_t LinalgOp::getIndexingMapIndex(
OpOperand *opOperand) {
1230 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1231 if (!dpsIface.isDpsInput(opOperand))
1232 return operandNumber;
1233 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1234 assert(!dpsIface.isDpsInit(opOperand));
1237 return cast<DestinationStyleOpInterface>(*this->getOperation())
1238 .getNumDpsInputs() +
1239 operandNumber - start;
1243 LinalgOp linalgOp = cast<LinalgOp>(op);
1245 if (!linalgOp.hasPureTensorSemantics() &&
1247 return op->
emitOpError(
"expected to have pure tensor or buffer semantics");
1251 if (linalgOp.hasDynamicIndexingMaps())
1252 if (
failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1256 if (
failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
1261 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1262 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1264 unsigned numLoops = linalgOp.getNumLoops();
1268 <<
" dim(s) to match the number of loops";
1271 linalgOp.getReductionDims(redDims);
1273 if (!linalgOp.getShapesToLoopsMap())
1274 return op->
emitOpError(
"expected the shape-to-loops map to be non-null");
1277 if (linalgOp->getNumRegions() != 1 || !linalgOp->getRegion(0).hasOneBlock())
1278 return op->
emitOpError(
"expects to have 1 region with 1 block");
1286 Block &block = linalgOp->getRegion(0).
front();
1288 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.
getNumArguments())
1289 return op->
emitOpError(
"expected as many non-induction variable region "
1290 "arguments as the number of input/output operands");
1292 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1294 if (isa<MemRefType, RankedTensorType>(elementType))
1297 if (elementType != argType)
1298 return op->
emitOpError(
"expected type of bb argument #")
1300 <<
" to match element or self type of the corresponding operand ("
1301 << elementType <<
")";
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...
static std::optional< Value > isaExternalFillOp(GenericOp op)
Detects if a linalg.generic operation represents an external scalar input.
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free operations, go up the use-def chain until the first...
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static MatchFillResult isFillInterfaceImpl(Operation *op)
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static std::optional< Value > isaInlinedFillOp(GenericOp op)
Detects if a linalg.generic operation represents a fill with an inlined constant.
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Affine binary operation expression.
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
IntegerAttr getIndexAttr(int64_t value)
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
@ NotProjectedPermutations
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that checks whether the given operands can be dropped, i.e. ...
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
@ NotProjectedPermutations
@ NonOutputDimNotReduction
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.broadcast.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Visitor to check if any of the given set of positions from AffineDimExprs are used within an AffineEx...
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.