21 #include "llvm/ADT/SetOperations.h"
22 #include "llvm/ADT/SmallBitVector.h"
23 #include "llvm/ADT/SmallVector.h"
31 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
40 for (
auto &opOperand : linalgOp->getOpOperands()) {
41 if (llvm::is_contained(droppedOperands, &opOperand))
43 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
45 if (indexingMaps.empty()) {
48 return linalgOp.getNumLoops() == 0;
59 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
62 auto mapRange = op.getIndexingMapsArray();
63 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
64 !mapRange.back().isIdentity()) {
76 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
77 !op.isSingleYieldOp())
81 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
82 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
85 OpOperand *value = op.getDpsInputOperand(0);
86 if (!op.isScalar(value))
94 std::optional<SmallVector<int64_t>>
97 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
98 !op.isSingleYieldOp())
101 auto srcTy = op.getDpsInputOperand(0)->get().getType();
102 auto dstTy = op.getDpsInitOperand(0)->get().getType();
103 if (!isa<MemRefType, RankedTensorType>(srcTy) ||
104 !isa<MemRefType, RankedTensorType>(dstTy))
110 auto dstMap = op.getIndexingMapsArray()[1];
111 if (!dstMap.isIdentity())
115 auto srcMap = op.getIndexingMapsArray()[0];
117 if (srcMap.getResults().size() >= dstMap.getResults().size())
121 for (
unsigned i = 0; i < srcMap.getNumResults(); ++i) {
122 auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
125 int64_t pos = expr.getPosition();
126 if (i > 0 && pos <= position[i - 1])
128 position.push_back(expr.getPosition());
132 auto numDims = srcMap.getNumDims();
134 for (
auto dim : llvm::seq<int64_t>(0, numDims)) {
135 if (!llvm::is_contained(position, dim))
136 broadcastedDims.push_back(dim);
138 return broadcastedDims;
144 std::optional<SmallVector<int64_t>>
149 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
150 !op.isSingleYieldOp())
153 auto mapRange = op.getIndexingMapsArray();
154 if (mapRange.size() != 2)
157 auto mapOfInput = mapRange.front();
158 auto mapOfResult = mapRange.back();
162 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
166 for (
unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
167 auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
168 permutation[expr.getPosition()] = i;
179 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
183 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
184 !llvm::all_of(op.getIndexingMapsArray(),
185 [](
AffineMap map) { return map.isIdentity(); }))
189 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
196 Block *body = op.getBody();
204 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
205 if (!yieldOp || yieldOp.getNumOperands() != 1 ||
206 yieldOp->getOperand(0).getDefiningOp() != oper)
217 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
227 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
228 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
229 if (!op.payloadUsesValueFromOperand(inputOpOperand0) ||
230 !op.payloadUsesValueFromOperand(inputOpOperand1))
246 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
247 if (!iface || !iface.hasNoEffect())
257 llvm::raw_ostream &errs) {
259 errs <<
"no terminator in the block";
264 errs <<
"expected block with 3 arguments";
270 errs <<
"expected terminator with 1 operand";
277 errs <<
"expected reduction op to be binary";
286 errs <<
"expected reduction to take block argument #2 as one of the "
287 "operands (modulo unary casts)";
292 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
296 errs <<
"expected elementwise op to be binary";
300 if (!isaPair(elementwiseOp, reductionOp)) {
301 errs <<
"expected reduction/elementwise op kind not satisfied";
314 errs <<
"expected elementwise op to apply to block arguments (modulo unary "
321 template <
typename AddOpTy,
typename MulOpTy,
typename... Args>
323 static_assert(
sizeof...(Args) % 2 == 0,
324 "expected an even number of template arguments");
325 if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
328 if constexpr (
sizeof...(Args) > 0)
336 template <
typename... Args>
348 static llvm::SmallDenseSet<int64_t>
351 utils::IteratorType iter) {
352 assert(iterators.size() == indexingMap.
getNumDims());
353 llvm::SmallDenseSet<int64_t> res;
355 if (
auto d = dyn_cast<AffineDimExpr>(e)) {
356 if (iterators[d.getPosition()] == iter &&
358 return e.isFunctionOfDim(d.getPosition());
360 res.insert(d.getPosition());
367 auto par = utils::IteratorType::parallel;
368 auto red = utils::IteratorType::reduction;
375 static FailureOr<SmallVector<utils::IteratorType>>
381 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
382 iterators[dim.getPosition()] = par;
397 static FailureOr<ContractionDimensions>
400 llvm::SmallDenseSet<int64_t> a =
402 llvm::SmallDenseSet<int64_t> b =
404 llvm::SmallDenseSet<int64_t> c =
408 llvm::SmallDenseSet<int64_t> ac = a;
409 llvm::set_intersect(ac, c);
410 llvm::set_subtract(ac, b);
412 llvm::SmallDenseSet<int64_t> bc = b;
413 llvm::set_intersect(bc, c);
414 llvm::set_subtract(bc, a);
416 llvm::SmallDenseSet<int64_t> batches = a;
417 llvm::set_intersect(batches, b);
418 llvm::set_intersect(batches, c);
421 llvm::SmallDenseSet<int64_t> ra =
423 llvm::SmallDenseSet<int64_t> rb =
425 llvm::set_intersect(ra, rb);
433 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
434 llvm::sort(dimensions.m.begin(), dimensions.m.end());
435 llvm::sort(dimensions.n.begin(), dimensions.n.end());
436 llvm::sort(dimensions.k.begin(), dimensions.k.end());
440 FailureOr<ContractionDimensions>
442 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
445 linalgOp.getIteratorTypesArray());
448 FailureOr<ContractionDimensions>
450 if (indexingMaps.size() != 3)
453 if (failed(iterators))
472 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
474 return MatchContractionResult::NotLinalgOp;
475 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
476 return MatchContractionResult::WrongNumOperands;
477 auto mapRange = linalgOp.getIndexingMapsArray();
478 if (linalgOp.getNumReductionLoops() == 0)
479 return MatchContractionResult::NoReduction;
480 if (llvm::any_of(mapRange,
481 [](
AffineMap m) {
return !m.isProjectedPermutation(); }))
482 return MatchContractionResult::NotProjectedPermutations;
486 arith::MulFOp, arith::AddFOp,
487 arith::MulIOp, arith::AddIOp,
488 complex::MulOp, complex::AddOp,
489 arith::AndIOp, arith::OrIOp>(
490 *linalgOp.getBlock())) {
491 return MatchContractionResult::NotAddMul;
497 assert(succeeded(res) &&
"unexpected failure to infer contraction dims");
500 return MatchContractionResult::Success;
506 case MatchContractionResult::NotLinalgOp:
507 return "expected a LinalgOp";
508 case MatchContractionResult::WrongNumOperands:
509 return "expected op with 2 inputs and 1 output";
510 case MatchContractionResult::NoReduction:
511 return "expected at least 1 reduction";
512 case MatchContractionResult::NotProjectedPermutations:
513 return "expected indexing maps to be projected permutations";
514 case MatchContractionResult::NotAddMul:
515 return "expected add/mul op in the body";
516 case MatchContractionResult::Success:
519 llvm_unreachable(
"unhandled MatchContractionResult case");
526 return isa<ContractionOpInterface>(op) ||
546 if (res != MatchContractionResult::Success)
557 template <
typename T>
559 return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) :
nullptr);
571 struct ConvAccessExprWalker
574 llvm::SmallDenseSet<int64_t> convolvedDims;
576 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
578 llvm::SmallDenseSet<int64_t> unConvolvedDims;
580 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
585 for (
int dimPos = 0, e = map.
getNumDims(); dimPos < e; ++dimPos) {
587 return e.isFunctionOfDim(dimPos);
589 convolvedDims.erase(dimPos);
590 unConvolvedDims.erase(dimPos);
593 auto it = convolvedDimMapping.find(dimPos);
594 if (it != convolvedDimMapping.end()) {
595 int64_t pairedDim = it->second;
596 convolvedDims.erase(pairedDim);
597 unConvolvedDims.erase(pairedDim);
598 strideAndDilationMapping.erase(pairedDim);
599 convolvedDimMapping.erase(dimPos);
600 convolvedDimMapping.erase(pairedDim);
608 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
611 unConvolvedDims.insert(position);
623 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getLHS());
624 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getRHS());
625 if (failed(lhsDimPos) || failed(rhsDimPos))
627 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
628 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
632 FailureOr<int64_t> getDimExprOrMulExprDimPos(
AffineExpr expr) {
633 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
635 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
638 strideAndDilationMapping[dim] =
640 convolvedDims.insert(dim);
643 if (
auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
646 auto lhsExpr = symbolMulExpr.getLHS();
647 auto rhsExpr = symbolMulExpr.getRHS();
650 getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
653 mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
655 auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
656 if (!mulExpr || !dimExpr)
659 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
661 strideAndDilationMapping[dim] = mulExpr;
662 convolvedDims.insert(dim);
672 "expected map to have projected permutations");
673 llvm::SmallDenseSet<int64_t> preservedDims;
675 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
676 return preservedDims;
682 for (
auto e : exprs) {
683 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
684 assert(constantExpr &&
"Found non-constant stride/dilation");
685 vals.push_back(constantExpr.getValue());
697 static FailureOr<ConvolutionDimensions>
699 ConvAccessExprWalker &inputExprWalker,
700 bool allowEmptyConvolvedDims) {
702 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
704 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
706 filterMap, linalgOp.getIteratorTypesArray(), par);
708 outputMap, linalgOp.getIteratorTypesArray(), par);
711 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
712 llvm::set_intersect(batch, outputDims);
713 llvm::set_subtract(batch, filterDims);
716 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
717 llvm::set_intersect(oi, outputDims);
720 llvm::SmallDenseSet<int64_t> oc = filterDims;
721 llvm::set_intersect(oc, outputDims);
722 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
725 llvm::SmallDenseSet<int64_t> depth = filterDims;
726 llvm::set_intersect(depth, outputDims);
727 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
729 llvm::SmallDenseSet<int64_t> filterReducedDims =
731 linalgOp.getIteratorTypesArray(), red);
734 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
735 llvm::set_intersect(fl, filterReducedDims);
738 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
739 llvm::set_intersect(ic, filterReducedDims);
741 if (oi.empty() && !allowEmptyConvolvedDims)
754 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
755 llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
756 llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
757 llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
758 llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
759 llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
763 if (!nativeStrides) {
765 for (
unsigned oiDim : dimensions.outputImage)
766 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
769 dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
771 auto nativeDilations =
773 if (!nativeDilations) {
775 for (
unsigned flDim : dimensions.filterLoop)
776 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
779 dimensions.dilations =
780 llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
809 FailureOr<ConvolutionDimensions>
811 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
814 auto indexingMaps = linalgOp.getIndexingMapsArray();
817 ConvAccessExprWalker inputExprWalker;
818 for (
AffineExpr expr : indexingMaps[0].getResults())
819 (void)inputExprWalker.visit(expr);
820 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
843 bool allowEmptyConvolvedDims) {
844 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
846 return MatchConvolutionResult::NotLinalgOp;
847 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
848 return MatchConvolutionResult::WrongNumOperands;
850 auto indexingMaps = linalgOp.getIndexingMapsArray();
853 ConvAccessExprWalker inputExprWalker;
854 if (llvm::any_of(indexingMaps[0].getResults(),
856 return failed(inputExprWalker.visit(expr));
858 return MatchConvolutionResult::WrongInputIndexingMap;
862 if (!indexingMaps[1].isProjectedPermutation() ||
863 !indexingMaps.back().isProjectedPermutation())
864 return MatchConvolutionResult::NotProjectedPermutations;
866 auto iteratorTypes = linalgOp.getIteratorTypesArray();
868 llvm::SmallDenseSet<int64_t> outputDims =
870 llvm::SmallDenseSet<int64_t> filterDims =
getPreservedDims(indexingMaps[1]);
884 llvm::SmallDenseSet<int64_t> allLoopDims;
885 for (
auto outputExpr : indexingMaps.back().getResults()) {
886 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
887 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
888 !filterDims.count(outputDim)) {
890 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
891 return MatchConvolutionResult::OutputDimsNotParallel;
892 allLoopDims.insert(outputDim);
895 if (inputExprWalker.convolvedDims.count(outputDim) &&
896 !filterDims.count(outputDim)) {
898 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
899 return MatchConvolutionResult::OutputDimsNotParallel;
900 allLoopDims.insert(outputDim);
903 if (!inputExprWalker.convolvedDims.count(outputDim) &&
904 !inputExprWalker.unConvolvedDims.count(outputDim) &&
905 filterDims.count(outputDim)) {
907 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
908 return MatchConvolutionResult::OutputDimsNotParallel;
909 allLoopDims.insert(outputDim);
912 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
913 filterDims.count(outputDim)) {
915 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
916 return MatchConvolutionResult::OutputDimsNotParallel;
917 allLoopDims.insert(outputDim);
920 return MatchConvolutionResult::NonConvolutionLoop;
922 for (
auto filterExpr : indexingMaps[1].getResults()) {
923 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
924 if (outputDims.count(filterDim) &&
925 !inputExprWalker.unConvolvedDims.count(filterDim) &&
926 !inputExprWalker.convolvedDims.count(filterDim)) {
930 if (inputExprWalker.convolvedDims.count(filterDim) &&
931 !outputDims.count(filterDim)) {
933 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
934 return MatchConvolutionResult::NonOutputDimNotReduction;
935 if (allLoopDims.count(filterDim))
936 return MatchConvolutionResult::NonConvolutionLoop;
937 allLoopDims.insert(filterDim);
940 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
941 !outputDims.count(filterDim)) {
943 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
944 return MatchConvolutionResult::NonOutputDimNotReduction;
945 if (allLoopDims.count(filterDim))
946 return MatchConvolutionResult::NonConvolutionLoop;
947 allLoopDims.insert(filterDim);
950 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
951 outputDims.count(filterDim)) {
955 return MatchConvolutionResult::NonConvolutionLoop;
958 if (allLoopDims.size() != linalgOp.getNumLoops())
959 return MatchConvolutionResult::NonConvolutionLoop;
961 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
962 return MatchConvolutionResult::EmptyConvolvedDims;
966 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
967 assert(succeeded(res) &&
"unexpected failure to infer convolution dims");
971 return MatchConvolutionResult::Success;
977 case MatchConvolutionResult::NotLinalgOp:
978 return "expected a LinalgOp";
979 case MatchConvolutionResult::WrongNumOperands:
980 return "expected op with 2 inputs and 1 output";
981 case MatchConvolutionResult::WrongInputIndexingMap:
982 return "unexpected input index map for convolutions";
983 case MatchConvolutionResult::NotProjectedPermutations:
984 return "expected output/filter indexing maps to be projected permutations";
985 case MatchConvolutionResult::NonConvolutionLoop:
986 return "unexpected loop dimension for convolution op";
987 case MatchConvolutionResult::OutputDimsNotParallel:
988 return "expected all iterators used to access outputs to be parallel";
989 case MatchConvolutionResult::NonOutputDimNotReduction:
990 return "expected all iterators not used to access outputs to be reduction";
991 case MatchConvolutionResult::EmptyConvolvedDims:
992 return "expected convolved dim to be non-empty";
993 case MatchConvolutionResult::Success:
996 llvm_unreachable(
"unhandled MatchConvolutionResult case");
1000 bool allowEmptyConvolvedDims) {
1002 linalgOp.getOperation(),
nullptr, allowEmptyConvolvedDims) ==
1008 if (res != MatchConvolutionResult::Success)
1025 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1028 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1031 OpOperand *value = linalgOp.getDpsInputOperand(0);
1032 if (!linalgOp.isScalar(value))
1041 return op->
emitError(
"expected a LinalgOp");
1043 return op->
emitError(
"expected op with 1 input and 1 output");
1045 return op->
emitError(
"expected op with scalar input");
1057 for (
OpOperand &opOperand : getOperation()->getOpOperands()) {
1058 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1066 assert(!hasDynamicShape() &&
"expected operands to have static shapes");
1067 for (
OpOperand &opOperand : getOperation()->getOpOperands())
1068 llvm::append_range(res,
getShape(&opOperand));
1075 auto viewSizes = createFlatListOfOperandDims(b, loc);
1077 for (
unsigned idx = 0; idx < numRes; ++idx) {
1079 if (
auto d = dyn_cast<AffineDimExpr>(result)) {
1080 if (res[d.getPosition()].offset)
1082 res[d.getPosition()] =
1094 for (
unsigned idx = 0; idx < numRes; ++idx) {
1096 if (
auto d = dyn_cast<AffineDimExpr>(result))
1097 res[d.getPosition()] = allShapeSizes[idx];
1107 : positions(std::move(positions)) {}
1122 llvm::SmallBitVector positions;
1125 static std::pair<int64_t, int64_t>
1127 int64_t inputRankSum = 0;
1128 int64_t outputRankSum = 0;
1129 for (
OpOperand *input : op.getDpsInputOperands())
1130 inputRankSum += op.getRank(input);
1131 for (
OpOperand &output : op.getDpsInitsMutable())
1132 outputRankSum += op.getRank(&output);
1133 return {inputRankSum, inputRankSum + outputRankSum};
1148 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1157 resultShapesSubMapPos.first,
1158 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1159 AffineMap resultShapesFromInputShapesMap =
1160 loopToResultsShapeMap.
compose(getShapesToLoopsMap());
1164 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.
getNumDims());
1165 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1167 Location loc = getOperation()->getLoc();
1171 rewriter, loc, resultShapesFromInputShapesMap,
1172 createFlatListOfOperandDims(b, loc));
1175 for (
OpOperand &opOperand : getDpsInitsMutable()) {
1177 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1178 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1179 if (!shapedType.isDynamicDim(dim)) {
1181 shapes.push_back(b.
getIndexAttr(shapedType.getDimSize(dim)));
1186 : allResultDimValues[pos];
1191 reifiedReturnShapes.emplace_back(std::move(shapes));
1198 int64_t LinalgOp::getIndexingMapIndex(
OpOperand *opOperand) {
1200 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1201 if (!dpsIface.isDpsInput(opOperand))
1202 return operandNumber;
1203 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1204 assert(!dpsIface.isDpsInit(opOperand));
1207 return cast<DestinationStyleOpInterface>(*this->getOperation())
1208 .getNumDpsInputs() +
1209 operandNumber - start;
1213 LinalgOp linalgOp = cast<LinalgOp>(op);
1216 if (!linalgOp.hasPureTensorSemantics() &&
1218 return op->
emitOpError(
"expected to have pure tensor or buffer semantics");
1222 if (linalgOp.hasDynamicIndexingMaps())
1223 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1227 if (
static_cast<int64_t
>(linalgOp.getIndexingMapsArray().size()) !=
1228 linalgOp->getNumOperands())
1229 return op->
emitOpError(
"expected the number of indexing_map (")
1230 << linalgOp.getIndexingMapsArray().size()
1231 <<
") to be equal to the number of input/output operands ("
1232 << linalgOp->getNumOperands() <<
")";
1234 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1235 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1239 return op->
emitOpError(
"unexpected symbols in indexing_map #")
1243 unsigned numLoops = linalgOp.getNumLoops();
1247 <<
" dim(s) to match the number of loops";
1249 int64_t rank = linalgOp.getRank(&opOperand);
1252 << rank <<
") to match the result rank of indexing_map #"
1258 linalgOp.getReductionDims(redDims);
1260 if (!linalgOp.getShapesToLoopsMap())
1261 return op->
emitOpError(
"expected the shape-to-loops map to be non-null");
1269 if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
1270 for (int64_t &range : endLoopRangeValues)
1272 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1273 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1275 indexingMap.
compose(startLoopRangeValues);
1277 indexingMap.
compose(endLoopRangeValues);
1279 for (
auto dim : llvm::seq<int64_t>(0, shape.size())) {
1281 if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
1294 int64_t inferredDimSize =
1295 std::max(startIndices[dim], endIndices[dim]) + 1;
1296 if (
std::min(startIndices[dim], endIndices[dim]) < 0) {
1299 llvm::raw_string_ostream os(mapStr);
1303 "unexpected result less than 0 at expression #")
1304 << dim <<
" in " << mapStr;
1306 if (dyn_cast<AffineDimExpr>(indexingMap.
getResult(dim))) {
1307 if (inferredDimSize != shape[dim]) {
1308 return op->
emitOpError(
"inferred input/output operand #")
1310 << dim <<
" to be " << inferredDimSize <<
", but found "
1314 if (inferredDimSize > shape[dim]) {
1315 return op->
emitOpError(
"inferred input/output operand #")
1317 << dim <<
" to be greater than or equal to "
1318 << inferredDimSize <<
", but found " << shape[dim];
1326 if (linalgOp->getNumRegions() != 1 ||
1327 !llvm::hasSingleElement(linalgOp->getRegion(0)))
1328 return op->
emitOpError(
"expects to have 1 region with 1 block");
1336 Block &block = linalgOp->getRegion(0).
front();
1338 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.
getNumArguments())
1339 return op->
emitOpError(
"expected as many non-induction variable region "
1340 "arguments as the number of input/output operands");
1342 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1344 if (isa<MemRefType, RankedTensorType>(elementType))
1347 if (elementType != argType)
1348 return op->
emitOpError(
"expected type of bb argument #")
1350 <<
" to match element or self type of the corresponding operand ("
1351 << elementType <<
")";
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side-effect-free operations, go up the use-def chain until the first...
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static MatchFillResult isFillInterfaceImpl(Operation *op)
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Affine binary operation expression.
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
IntegerAttr getIndexAttr(int64_t value)
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
@ NotProjectedPermutations
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that check if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
@ NotProjectedPermutations
@ NonOutputDimNotReduction
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.broadcast.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Visitor to check if any of the given set of positions from AffineDimExprs are used within an AffineEx...
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.