22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SetOperations.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/Support/Casting.h"
27 #include "llvm/Support/raw_ostream.h"
34 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
43 for (
auto &opOperand : linalgOp->getOpOperands()) {
44 if (llvm::is_contained(droppedOperands, &opOperand))
46 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
48 if (indexingMaps.empty()) {
51 return linalgOp.getNumLoops() == 0;
54 indexingMaps, linalgOp.getContext())) !=
AffineMap();
63 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
66 auto mapRange = op.getIndexingMapsArray();
67 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
68 !mapRange.back().isIdentity()) {
72 Block *body = op.getBlock();
75 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
76 if (!yieldOp || yieldOp.getNumOperands() != 1)
78 return yieldOp->getOperand(0) == body->
getArgument(0);
88 if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
89 op.getNumDpsInputs() != 0)
93 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
96 Block *body = op.getBody();
100 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
101 if (!yieldOp || yieldOp.getNumOperands() != 1)
104 Value yieldOperand = yieldOp->getOperand(0);
116 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
117 !op.isSingleYieldOp())
121 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
122 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
125 OpOperand *value = op.getDpsInputOperand(0);
126 if (!op.isScalar(value))
140 std::optional<SmallVector<int64_t>>
143 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
144 !op.isSingleYieldOp())
147 auto srcTy = op.getDpsInputOperand(0)->get().getType();
148 auto dstTy = op.getDpsInitOperand(0)->get().getType();
149 if (!isa<MemRefType, RankedTensorType>(srcTy) ||
150 !isa<MemRefType, RankedTensorType>(dstTy))
156 auto dstMap = op.getIndexingMapsArray()[1];
157 if (!dstMap.isIdentity())
161 auto srcMap = op.getIndexingMapsArray()[0];
163 if (srcMap.getResults().size() >= dstMap.getResults().size())
167 for (
unsigned i = 0; i < srcMap.getNumResults(); ++i) {
168 auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
171 int64_t pos = expr.getPosition();
172 if (i > 0 && pos <= position[i - 1])
174 position.push_back(expr.getPosition());
178 auto numDims = srcMap.getNumDims();
180 for (
auto dim : llvm::seq<int64_t>(0, numDims)) {
181 if (!llvm::is_contained(position, dim))
182 broadcastedDims.push_back(dim);
184 return broadcastedDims;
190 std::optional<SmallVector<int64_t>>
195 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
196 !op.isSingleYieldOp())
199 auto mapRange = op.getIndexingMapsArray();
200 if (mapRange.size() != 2)
203 auto mapOfInput = mapRange.front();
204 auto mapOfResult = mapRange.back();
208 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
212 for (
unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
213 auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
214 permutation[expr.getPosition()] = i;
225 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
229 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
230 !llvm::all_of(op.getIndexingMapsArray(),
231 [](
AffineMap map) { return map.isIdentity(); }))
235 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
242 Block *body = op.getBody();
250 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
251 return !(!yieldOp || yieldOp.getNumOperands() != 1 ||
252 yieldOp->getOperand(0).getDefiningOp() != oper);
261 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
271 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
272 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
273 return !(!op.payloadUsesValueFromOperand(inputOpOperand0) ||
274 !op.payloadUsesValueFromOperand(inputOpOperand1));
288 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
289 if (!iface || !iface.hasNoEffect())
299 llvm::raw_ostream &errs) {
301 errs <<
"no terminator in the block";
306 errs <<
"expected block with 3 arguments";
312 errs <<
"expected terminator with 1 operand";
319 errs <<
"expected reduction op to be binary";
328 errs <<
"expected reduction to take block argument #2 as one of the "
329 "operands (modulo unary casts)";
334 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
338 errs <<
"expected elementwise op to be binary";
342 if (!isaPair(elementwiseOp, reductionOp)) {
343 errs <<
"expected reduction/elementwise op kind not satisfied";
356 errs <<
"expected elementwise op to apply to block arguments (modulo unary "
363 template <
typename AddOpTy,
typename MulOpTy,
typename... Args>
365 static_assert(
sizeof...(Args) % 2 == 0,
366 "expected an even number of template arguments");
367 if (isa<AddOpTy>(
add) && isa<MulOpTy>(mul))
370 if constexpr (
sizeof...(Args) > 0)
378 template <
typename... Args>
390 static llvm::SmallDenseSet<int64_t>
393 utils::IteratorType iter) {
394 assert(iterators.size() == indexingMap.
getNumDims());
395 llvm::SmallDenseSet<int64_t> res;
397 if (
auto d = dyn_cast<AffineDimExpr>(e)) {
398 if (iterators[d.getPosition()] == iter &&
400 return e.isFunctionOfDim(d.getPosition());
402 res.insert(d.getPosition());
409 auto par = utils::IteratorType::parallel;
410 auto red = utils::IteratorType::reduction;
417 static FailureOr<SmallVector<utils::IteratorType>>
423 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
424 iterators[dim.getPosition()] = par;
439 static FailureOr<ContractionDimensions>
442 llvm::SmallDenseSet<int64_t> a =
444 llvm::SmallDenseSet<int64_t> b =
446 llvm::SmallDenseSet<int64_t> c =
450 llvm::SmallDenseSet<int64_t> ac = a;
451 llvm::set_intersect(ac, c);
452 llvm::set_subtract(ac, b);
454 llvm::SmallDenseSet<int64_t> bc = b;
455 llvm::set_intersect(bc, c);
456 llvm::set_subtract(bc, a);
458 llvm::SmallDenseSet<int64_t> batches = a;
459 llvm::set_intersect(batches, b);
460 llvm::set_intersect(batches, c);
463 llvm::SmallDenseSet<int64_t> ra =
465 llvm::SmallDenseSet<int64_t> rb =
467 llvm::set_intersect(ra, rb);
475 llvm::sort(dimensions.batch);
476 llvm::sort(dimensions.m);
477 llvm::sort(dimensions.n);
478 llvm::sort(dimensions.k);
482 FailureOr<ContractionDimensions>
484 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
487 linalgOp.getIteratorTypesArray());
490 FailureOr<ContractionDimensions>
492 if (indexingMaps.size() != 3)
514 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
516 return MatchContractionResult::NotLinalgOp;
517 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
518 return MatchContractionResult::WrongNumOperands;
519 auto mapRange = linalgOp.getIndexingMapsArray();
520 if (linalgOp.getNumReductionLoops() == 0)
521 return MatchContractionResult::NoReduction;
522 if (llvm::any_of(mapRange,
524 return MatchContractionResult::NotProjectedPermutations;
528 arith::MulFOp, arith::AddFOp,
529 arith::MulIOp, arith::AddIOp,
530 complex::MulOp, complex::AddOp,
531 arith::AndIOp, arith::OrIOp>(
532 *linalgOp.getBlock())) {
533 return MatchContractionResult::NotAddMul;
539 assert(succeeded(res) &&
"unexpected failure to infer contraction dims");
542 return MatchContractionResult::Success;
548 case MatchContractionResult::NotLinalgOp:
549 return "expected a LinalgOp";
550 case MatchContractionResult::WrongNumOperands:
551 return "expected op with 2 inputs and 1 output";
552 case MatchContractionResult::NoReduction:
553 return "expected at least 1 reduction";
554 case MatchContractionResult::NotProjectedPermutations:
555 return "expected indexing maps to be projected permutations";
556 case MatchContractionResult::NotAddMul:
557 return "expected add/mul op in the body";
558 case MatchContractionResult::Success:
561 llvm_unreachable(
"unhandled MatchContractionResult case");
568 return isa<ContractionOpInterface>(op) ||
588 if (res != MatchContractionResult::Success)
599 template <
typename T>
601 return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) :
nullptr);
613 struct ConvAccessExprWalker
616 llvm::SmallDenseSet<int64_t> convolvedDims;
618 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
620 llvm::SmallDenseSet<int64_t> unConvolvedDims;
622 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
627 for (
int dimPos = 0, e = map.
getNumDims(); dimPos < e; ++dimPos) {
629 return e.isFunctionOfDim(dimPos);
631 convolvedDims.erase(dimPos);
632 unConvolvedDims.erase(dimPos);
635 auto it = convolvedDimMapping.find(dimPos);
636 if (it != convolvedDimMapping.end()) {
637 int64_t pairedDim = it->second;
638 convolvedDims.erase(pairedDim);
639 unConvolvedDims.erase(pairedDim);
640 strideAndDilationMapping.erase(pairedDim);
641 convolvedDimMapping.erase(dimPos);
642 convolvedDimMapping.erase(pairedDim);
650 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
653 unConvolvedDims.insert(position);
665 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getLHS());
666 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getRHS());
669 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
670 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
674 FailureOr<int64_t> getDimExprOrMulExprDimPos(
AffineExpr expr) {
675 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
677 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
680 strideAndDilationMapping[dim] =
682 convolvedDims.insert(dim);
685 if (
auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
688 auto lhsExpr = symbolMulExpr.getLHS();
689 auto rhsExpr = symbolMulExpr.getRHS();
692 getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
695 mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
697 auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
698 if (!mulExpr || !dimExpr)
701 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
703 strideAndDilationMapping[dim] = mulExpr;
704 convolvedDims.insert(dim);
714 "expected map to have projected permutations");
715 llvm::SmallDenseSet<int64_t> preservedDims;
717 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
718 return preservedDims;
724 for (
auto e : exprs) {
725 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
726 assert(constantExpr &&
"Found non-constant stride/dilation");
727 vals.push_back(constantExpr.getValue());
739 static FailureOr<ConvolutionDimensions>
741 ConvAccessExprWalker &inputExprWalker,
742 bool allowEmptyConvolvedDims) {
744 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
746 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
748 filterMap, linalgOp.getIteratorTypesArray(), par);
750 outputMap, linalgOp.getIteratorTypesArray(), par);
753 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
754 llvm::set_intersect(batch, outputDims);
755 llvm::set_subtract(batch, filterDims);
758 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
759 llvm::set_intersect(oi, outputDims);
762 llvm::SmallDenseSet<int64_t> oc = filterDims;
763 llvm::set_intersect(oc, outputDims);
764 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
767 llvm::SmallDenseSet<int64_t> depth = filterDims;
768 llvm::set_intersect(depth, outputDims);
769 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
771 llvm::SmallDenseSet<int64_t> filterReducedDims =
773 linalgOp.getIteratorTypesArray(), red);
776 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
777 llvm::set_intersect(fl, filterReducedDims);
780 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
781 llvm::set_intersect(ic, filterReducedDims);
783 if (oi.empty() && !allowEmptyConvolvedDims)
796 llvm::sort(dimensions.batch);
797 llvm::sort(dimensions.outputImage);
798 llvm::sort(dimensions.outputChannel);
799 llvm::sort(dimensions.filterLoop);
800 llvm::sort(dimensions.inputChannel);
801 llvm::sort(dimensions.depth);
805 if (!nativeStrides) {
807 for (
unsigned oiDim : dimensions.outputImage)
808 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
811 dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
813 auto nativeDilations =
815 if (!nativeDilations) {
817 for (
unsigned flDim : dimensions.filterLoop)
818 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
821 dimensions.dilations =
822 llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
851 FailureOr<ConvolutionDimensions>
853 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
856 auto indexingMaps = linalgOp.getIndexingMapsArray();
859 ConvAccessExprWalker inputExprWalker;
860 for (
AffineExpr expr : indexingMaps[0].getResults())
861 (void)inputExprWalker.visit(expr);
862 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
885 bool allowEmptyConvolvedDims) {
886 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
888 return MatchConvolutionResult::NotLinalgOp;
889 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
890 return MatchConvolutionResult::WrongNumOperands;
892 auto indexingMaps = linalgOp.getIndexingMapsArray();
895 ConvAccessExprWalker inputExprWalker;
896 if (llvm::any_of(indexingMaps[0].getResults(),
898 return failed(inputExprWalker.visit(expr));
900 return MatchConvolutionResult::WrongInputIndexingMap;
904 if (!indexingMaps[1].isProjectedPermutation() ||
905 !indexingMaps.back().isProjectedPermutation())
906 return MatchConvolutionResult::NotProjectedPermutations;
908 auto iteratorTypes = linalgOp.getIteratorTypesArray();
910 llvm::SmallDenseSet<int64_t> outputDims =
912 llvm::SmallDenseSet<int64_t> filterDims =
getPreservedDims(indexingMaps[1]);
926 llvm::SmallDenseSet<int64_t> allLoopDims;
927 for (
auto outputExpr : indexingMaps.back().getResults()) {
928 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
929 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
930 !filterDims.count(outputDim)) {
932 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
933 return MatchConvolutionResult::OutputDimsNotParallel;
934 allLoopDims.insert(outputDim);
937 if (inputExprWalker.convolvedDims.count(outputDim) &&
938 !filterDims.count(outputDim)) {
940 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
941 return MatchConvolutionResult::OutputDimsNotParallel;
942 allLoopDims.insert(outputDim);
945 if (!inputExprWalker.convolvedDims.count(outputDim) &&
946 !inputExprWalker.unConvolvedDims.count(outputDim) &&
947 filterDims.count(outputDim)) {
949 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
950 return MatchConvolutionResult::OutputDimsNotParallel;
951 allLoopDims.insert(outputDim);
954 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
955 filterDims.count(outputDim)) {
957 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
958 return MatchConvolutionResult::OutputDimsNotParallel;
959 allLoopDims.insert(outputDim);
962 return MatchConvolutionResult::NonConvolutionLoop;
964 for (
auto filterExpr : indexingMaps[1].getResults()) {
965 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
966 if (outputDims.count(filterDim) &&
967 !inputExprWalker.unConvolvedDims.count(filterDim) &&
968 !inputExprWalker.convolvedDims.count(filterDim)) {
972 if (inputExprWalker.convolvedDims.count(filterDim) &&
973 !outputDims.count(filterDim)) {
975 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
976 return MatchConvolutionResult::NonOutputDimNotReduction;
977 if (allLoopDims.count(filterDim))
978 return MatchConvolutionResult::NonConvolutionLoop;
979 allLoopDims.insert(filterDim);
982 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
983 !outputDims.count(filterDim)) {
985 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
986 return MatchConvolutionResult::NonOutputDimNotReduction;
987 if (allLoopDims.count(filterDim))
988 return MatchConvolutionResult::NonConvolutionLoop;
989 allLoopDims.insert(filterDim);
992 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
993 outputDims.count(filterDim)) {
997 return MatchConvolutionResult::NonConvolutionLoop;
1000 if (allLoopDims.size() != linalgOp.getNumLoops())
1001 return MatchConvolutionResult::NonConvolutionLoop;
1003 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
1004 return MatchConvolutionResult::EmptyConvolvedDims;
1008 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
1009 assert(succeeded(res) &&
"unexpected failure to infer convolution dims");
1013 return MatchConvolutionResult::Success;
1019 case MatchConvolutionResult::NotLinalgOp:
1020 return "expected a LinalgOp";
1021 case MatchConvolutionResult::WrongNumOperands:
1022 return "expected op with 2 inputs and 1 output";
1023 case MatchConvolutionResult::WrongInputIndexingMap:
1024 return "unexpected input index map for convolutions";
1025 case MatchConvolutionResult::NotProjectedPermutations:
1026 return "expected output/filter indexing maps to be projected permutations";
1027 case MatchConvolutionResult::NonConvolutionLoop:
1028 return "unexpected loop dimension for convolution op";
1029 case MatchConvolutionResult::OutputDimsNotParallel:
1030 return "expected all iterators used to access outputs to be parallel";
1031 case MatchConvolutionResult::NonOutputDimNotReduction:
1032 return "expected all iterators not used to access outputs to be reduction";
1033 case MatchConvolutionResult::EmptyConvolvedDims:
1034 return "expected convolved dim to be non-empty";
1035 case MatchConvolutionResult::Success:
1038 llvm_unreachable(
"unhandled MatchConvolutionResult case");
1042 bool allowEmptyConvolvedDims) {
1044 linalgOp.getOperation(),
nullptr, allowEmptyConvolvedDims) ==
1050 if (res != MatchConvolutionResult::Success)
1067 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1070 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1073 OpOperand *value = linalgOp.getDpsInputOperand(0);
1074 if (!linalgOp.isScalar(value))
1083 return op->
emitError(
"expected a LinalgOp");
1085 return op->
emitError(
"expected op with 1 input and 1 output");
1087 return op->
emitError(
"expected op with scalar input");
1099 for (
OpOperand &opOperand : getOperation()->getOpOperands()) {
1100 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1108 assert(!hasDynamicShape() &&
"expected operands to have static shapes");
1109 for (
OpOperand &opOperand : getOperation()->getOpOperands())
1110 llvm::append_range(res,
getShape(&opOperand));
1117 auto viewSizes = createFlatListOfOperandDims(b, loc);
1119 for (
unsigned idx = 0; idx < numRes; ++idx) {
1121 if (
auto d = dyn_cast<AffineDimExpr>(result)) {
1122 if (res[d.getPosition()].offset)
1124 res[d.getPosition()] =
1136 : positions(std::move(positions)) {}
1151 llvm::SmallBitVector positions;
1154 static std::pair<int64_t, int64_t>
1156 int64_t inputRankSum = 0;
1157 int64_t outputRankSum = 0;
1158 for (
OpOperand *input : op.getDpsInputOperands())
1159 inputRankSum += op.getRank(input);
1160 for (
OpOperand &output : op.getDpsInitsMutable())
1161 outputRankSum += op.getRank(&output);
1162 return {inputRankSum, inputRankSum + outputRankSum};
1177 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1186 resultShapesSubMapPos.first,
1187 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1188 AffineMap resultShapesFromInputShapesMap =
1189 loopToResultsShapeMap.
compose(getShapesToLoopsMap());
1193 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.
getNumDims());
1194 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1196 Location loc = getOperation()->getLoc();
1200 rewriter, loc, resultShapesFromInputShapesMap,
1201 createFlatListOfOperandDims(b, loc));
1204 for (
OpOperand &opOperand : getDpsInitsMutable()) {
1206 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1207 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1208 if (!shapedType.isDynamicDim(dim)) {
1210 shapes.push_back(b.
getIndexAttr(shapedType.getDimSize(dim)));
1215 : allResultDimValues[pos];
1220 reifiedReturnShapes.emplace_back(std::move(shapes));
1227 int64_t LinalgOp::getIndexingMapIndex(
OpOperand *opOperand) {
1229 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1230 if (!dpsIface.isDpsInput(opOperand))
1231 return operandNumber;
1232 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1233 assert(!dpsIface.isDpsInit(opOperand));
1236 return cast<DestinationStyleOpInterface>(*this->getOperation())
1237 .getNumDpsInputs() +
1238 operandNumber - start;
1242 LinalgOp linalgOp = cast<LinalgOp>(op);
1244 if (!linalgOp.hasPureTensorSemantics() &&
1246 return op->
emitOpError(
"expected to have pure tensor or buffer semantics");
1250 if (linalgOp.hasDynamicIndexingMaps())
1251 if (
failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1255 if (
failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
1260 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1261 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1263 unsigned numLoops = linalgOp.getNumLoops();
1267 <<
" dim(s) to match the number of loops";
1270 linalgOp.getReductionDims(redDims);
1272 if (!linalgOp.getShapesToLoopsMap())
1273 return op->
emitOpError(
"expected the shape-to-loops map to be non-null");
1276 if (linalgOp->getNumRegions() != 1 || !linalgOp->getRegion(0).hasOneBlock())
1277 return op->
emitOpError(
"expects to have 1 region with 1 block");
1285 Block &block = linalgOp->getRegion(0).
front();
1287 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.
getNumArguments())
1288 return op->
emitOpError(
"expected as many non-induction variable region "
1289 "arguments as the number of input/output operands");
1291 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1293 if (isa<MemRefType, RankedTensorType>(elementType))
1296 if (elementType != argType)
1297 return op->
emitOpError(
"expected type of bb argument #")
1299 <<
" to match element or self type of the corresponding operand ("
1300 << elementType <<
")";
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprWalker.
static std::optional< Value > isaExternalFillOp(GenericOp op)
Detects if a linalg.generic operation represents an external scalar input.
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free, go up the use-def chain until the first...
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template arguments.
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static MatchFillResult isFillInterfaceImpl(Operation *op)
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by the template arguments.
static std::optional< Value > isaInlinedFillOp(GenericOp op)
Detects if a linalg.generic operation represents a fill with an inlined constant.
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Affine binary operation expression.
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
IntegerAttr getIndexAttr(int64_t value)
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
@ NotProjectedPermutations
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that checks if the given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
@ NotProjectedPermutations
@ NonOutputDimNotReduction
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.broadcast.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Visitor to check if any of the given set of positions from AffineDimExprs are used within an AffineEx...
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.