22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SetOperations.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/Support/Casting.h"
27 #include "llvm/Support/raw_ostream.h"
34 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
43 for (
auto &opOperand : linalgOp->getOpOperands()) {
44 if (llvm::is_contained(droppedOperands, &opOperand))
46 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
48 if (indexingMaps.empty()) {
51 return linalgOp.getNumLoops() == 0;
54 indexingMaps, linalgOp.getContext())) !=
AffineMap();
63 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
66 auto mapRange = op.getIndexingMapsArray();
67 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
68 !mapRange.back().isIdentity()) {
72 return llvm::hasSingleElement(op.getBlock()->getOperations());
82 if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
83 op.getNumDpsInputs() != 0)
87 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
90 Block *body = op.getBody();
94 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
95 if (!yieldOp || yieldOp.getNumOperands() != 1)
98 Value yieldOperand = yieldOp->getOperand(0);
110 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
111 !op.isSingleYieldOp())
115 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
116 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
119 OpOperand *value = op.getDpsInputOperand(0);
120 if (!op.isScalar(value))
134 std::optional<SmallVector<int64_t>>
137 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
138 !op.isSingleYieldOp())
141 auto srcTy = op.getDpsInputOperand(0)->get().getType();
142 auto dstTy = op.getDpsInitOperand(0)->get().getType();
143 if (!isa<MemRefType, RankedTensorType>(srcTy) ||
144 !isa<MemRefType, RankedTensorType>(dstTy))
150 auto dstMap = op.getIndexingMapsArray()[1];
151 if (!dstMap.isIdentity())
155 auto srcMap = op.getIndexingMapsArray()[0];
157 if (srcMap.getResults().size() >= dstMap.getResults().size())
161 for (
unsigned i = 0; i < srcMap.getNumResults(); ++i) {
162 auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
165 int64_t pos = expr.getPosition();
166 if (i > 0 && pos <= position[i - 1])
168 position.push_back(expr.getPosition());
172 auto numDims = srcMap.getNumDims();
174 for (
auto dim : llvm::seq<int64_t>(0, numDims)) {
175 if (!llvm::is_contained(position, dim))
176 broadcastedDims.push_back(dim);
178 return broadcastedDims;
184 std::optional<SmallVector<int64_t>>
189 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
190 !op.isSingleYieldOp())
193 auto mapRange = op.getIndexingMapsArray();
194 if (mapRange.size() != 2)
197 auto mapOfInput = mapRange.front();
198 auto mapOfResult = mapRange.back();
202 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
206 for (
unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
207 auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
208 permutation[expr.getPosition()] = i;
219 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
223 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
224 !llvm::all_of(op.getIndexingMapsArray(),
225 [](
AffineMap map) { return map.isIdentity(); }))
229 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
236 Block *body = op.getBody();
244 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
245 if (!yieldOp || yieldOp.getNumOperands() != 1 ||
246 yieldOp->getOperand(0).getDefiningOp() != oper)
257 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
267 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
268 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
269 if (!op.payloadUsesValueFromOperand(inputOpOperand0) ||
270 !op.payloadUsesValueFromOperand(inputOpOperand1))
286 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
287 if (!iface || !iface.hasNoEffect())
297 llvm::raw_ostream &errs) {
299 errs <<
"no terminator in the block";
304 errs <<
"expected block with 3 arguments";
310 errs <<
"expected terminator with 1 operand";
317 errs <<
"expected reduction op to be binary";
326 errs <<
"expected reduction to take block argument #2 as one of the "
327 "operands (modulo unary casts)";
332 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
336 errs <<
"expected elementwise op to be binary";
340 if (!isaPair(elementwiseOp, reductionOp)) {
341 errs <<
"expected reduction/elementwise op kind not satisfied";
354 errs <<
"expected elementwise op to apply to block arguments (modulo unary "
361 template <
typename AddOpTy,
typename MulOpTy,
typename... Args>
363 static_assert(
sizeof...(Args) % 2 == 0,
364 "expected an even number of template arguments");
365 if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
368 if constexpr (
sizeof...(Args) > 0)
376 template <
typename... Args>
388 static llvm::SmallDenseSet<int64_t>
391 utils::IteratorType iter) {
392 assert(iterators.size() == indexingMap.
getNumDims());
393 llvm::SmallDenseSet<int64_t> res;
395 if (
auto d = dyn_cast<AffineDimExpr>(e)) {
396 if (iterators[d.getPosition()] == iter &&
398 return e.isFunctionOfDim(d.getPosition());
400 res.insert(d.getPosition());
407 auto par = utils::IteratorType::parallel;
408 auto red = utils::IteratorType::reduction;
415 static FailureOr<SmallVector<utils::IteratorType>>
421 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
422 iterators[dim.getPosition()] = par;
437 static FailureOr<ContractionDimensions>
440 llvm::SmallDenseSet<int64_t> a =
442 llvm::SmallDenseSet<int64_t> b =
444 llvm::SmallDenseSet<int64_t> c =
448 llvm::SmallDenseSet<int64_t> ac = a;
449 llvm::set_intersect(ac, c);
450 llvm::set_subtract(ac, b);
452 llvm::SmallDenseSet<int64_t> bc = b;
453 llvm::set_intersect(bc, c);
454 llvm::set_subtract(bc, a);
456 llvm::SmallDenseSet<int64_t> batches = a;
457 llvm::set_intersect(batches, b);
458 llvm::set_intersect(batches, c);
461 llvm::SmallDenseSet<int64_t> ra =
463 llvm::SmallDenseSet<int64_t> rb =
465 llvm::set_intersect(ra, rb);
473 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
474 llvm::sort(dimensions.m.begin(), dimensions.m.end());
475 llvm::sort(dimensions.n.begin(), dimensions.n.end());
476 llvm::sort(dimensions.k.begin(), dimensions.k.end());
480 FailureOr<ContractionDimensions>
482 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
485 linalgOp.getIteratorTypesArray());
488 FailureOr<ContractionDimensions>
490 if (indexingMaps.size() != 3)
493 if (failed(iterators))
512 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
514 return MatchContractionResult::NotLinalgOp;
515 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
516 return MatchContractionResult::WrongNumOperands;
517 auto mapRange = linalgOp.getIndexingMapsArray();
518 if (linalgOp.getNumReductionLoops() == 0)
519 return MatchContractionResult::NoReduction;
520 if (llvm::any_of(mapRange,
522 return MatchContractionResult::NotProjectedPermutations;
526 arith::MulFOp, arith::AddFOp,
527 arith::MulIOp, arith::AddIOp,
528 complex::MulOp, complex::AddOp,
529 arith::AndIOp, arith::OrIOp>(
530 *linalgOp.getBlock())) {
531 return MatchContractionResult::NotAddMul;
537 assert(succeeded(res) &&
"unexpected failure to infer contraction dims");
540 return MatchContractionResult::Success;
546 case MatchContractionResult::NotLinalgOp:
547 return "expected a LinalgOp";
548 case MatchContractionResult::WrongNumOperands:
549 return "expected op with 2 inputs and 1 output";
550 case MatchContractionResult::NoReduction:
551 return "expected at least 1 reduction";
552 case MatchContractionResult::NotProjectedPermutations:
553 return "expected indexing maps to be projected permutations";
554 case MatchContractionResult::NotAddMul:
555 return "expected add/mul op in the body";
556 case MatchContractionResult::Success:
559 llvm_unreachable(
"unhandled MatchContractionResult case");
566 return isa<ContractionOpInterface>(op) ||
586 if (res != MatchContractionResult::Success)
597 template <
typename T>
599 return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) :
nullptr);
611 struct ConvAccessExprWalker
614 llvm::SmallDenseSet<int64_t> convolvedDims;
616 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
618 llvm::SmallDenseSet<int64_t> unConvolvedDims;
620 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
625 for (
int dimPos = 0, e = map.
getNumDims(); dimPos < e; ++dimPos) {
627 return e.isFunctionOfDim(dimPos);
629 convolvedDims.erase(dimPos);
630 unConvolvedDims.erase(dimPos);
633 auto it = convolvedDimMapping.find(dimPos);
634 if (it != convolvedDimMapping.end()) {
635 int64_t pairedDim = it->second;
636 convolvedDims.erase(pairedDim);
637 unConvolvedDims.erase(pairedDim);
638 strideAndDilationMapping.erase(pairedDim);
639 convolvedDimMapping.erase(dimPos);
640 convolvedDimMapping.erase(pairedDim);
648 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
651 unConvolvedDims.insert(position);
663 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getLHS());
664 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getRHS());
665 if (failed(lhsDimPos) || failed(rhsDimPos))
667 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
668 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
672 FailureOr<int64_t> getDimExprOrMulExprDimPos(
AffineExpr expr) {
673 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
675 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
678 strideAndDilationMapping[dim] =
680 convolvedDims.insert(dim);
683 if (
auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
686 auto lhsExpr = symbolMulExpr.getLHS();
687 auto rhsExpr = symbolMulExpr.getRHS();
690 getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
693 mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
695 auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
696 if (!mulExpr || !dimExpr)
699 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
701 strideAndDilationMapping[dim] = mulExpr;
702 convolvedDims.insert(dim);
712 "expected map to have projected permutations");
713 llvm::SmallDenseSet<int64_t> preservedDims;
715 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
716 return preservedDims;
722 for (
auto e : exprs) {
723 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
724 assert(constantExpr &&
"Found non-constant stride/dilation");
725 vals.push_back(constantExpr.getValue());
737 static FailureOr<ConvolutionDimensions>
739 ConvAccessExprWalker &inputExprWalker,
740 bool allowEmptyConvolvedDims) {
742 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
744 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
746 filterMap, linalgOp.getIteratorTypesArray(), par);
748 outputMap, linalgOp.getIteratorTypesArray(), par);
751 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
752 llvm::set_intersect(batch, outputDims);
753 llvm::set_subtract(batch, filterDims);
756 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
757 llvm::set_intersect(oi, outputDims);
760 llvm::SmallDenseSet<int64_t> oc = filterDims;
761 llvm::set_intersect(oc, outputDims);
762 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
765 llvm::SmallDenseSet<int64_t> depth = filterDims;
766 llvm::set_intersect(depth, outputDims);
767 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
769 llvm::SmallDenseSet<int64_t> filterReducedDims =
771 linalgOp.getIteratorTypesArray(), red);
774 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
775 llvm::set_intersect(fl, filterReducedDims);
778 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
779 llvm::set_intersect(ic, filterReducedDims);
781 if (oi.empty() && !allowEmptyConvolvedDims)
794 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
795 llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
796 llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
797 llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
798 llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
799 llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
803 if (!nativeStrides) {
805 for (
unsigned oiDim : dimensions.outputImage)
806 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
809 dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
811 auto nativeDilations =
813 if (!nativeDilations) {
815 for (
unsigned flDim : dimensions.filterLoop)
816 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
819 dimensions.dilations =
820 llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
849 FailureOr<ConvolutionDimensions>
851 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
854 auto indexingMaps = linalgOp.getIndexingMapsArray();
857 ConvAccessExprWalker inputExprWalker;
858 for (
AffineExpr expr : indexingMaps[0].getResults())
859 (void)inputExprWalker.visit(expr);
860 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
883 bool allowEmptyConvolvedDims) {
884 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
886 return MatchConvolutionResult::NotLinalgOp;
887 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
888 return MatchConvolutionResult::WrongNumOperands;
890 auto indexingMaps = linalgOp.getIndexingMapsArray();
893 ConvAccessExprWalker inputExprWalker;
894 if (llvm::any_of(indexingMaps[0].getResults(),
896 return failed(inputExprWalker.visit(expr));
898 return MatchConvolutionResult::WrongInputIndexingMap;
902 if (!indexingMaps[1].isProjectedPermutation() ||
903 !indexingMaps.back().isProjectedPermutation())
904 return MatchConvolutionResult::NotProjectedPermutations;
906 auto iteratorTypes = linalgOp.getIteratorTypesArray();
908 llvm::SmallDenseSet<int64_t> outputDims =
910 llvm::SmallDenseSet<int64_t> filterDims =
getPreservedDims(indexingMaps[1]);
924 llvm::SmallDenseSet<int64_t> allLoopDims;
925 for (
auto outputExpr : indexingMaps.back().getResults()) {
926 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
927 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
928 !filterDims.count(outputDim)) {
930 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
931 return MatchConvolutionResult::OutputDimsNotParallel;
932 allLoopDims.insert(outputDim);
935 if (inputExprWalker.convolvedDims.count(outputDim) &&
936 !filterDims.count(outputDim)) {
938 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
939 return MatchConvolutionResult::OutputDimsNotParallel;
940 allLoopDims.insert(outputDim);
943 if (!inputExprWalker.convolvedDims.count(outputDim) &&
944 !inputExprWalker.unConvolvedDims.count(outputDim) &&
945 filterDims.count(outputDim)) {
947 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
948 return MatchConvolutionResult::OutputDimsNotParallel;
949 allLoopDims.insert(outputDim);
952 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
953 filterDims.count(outputDim)) {
955 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
956 return MatchConvolutionResult::OutputDimsNotParallel;
957 allLoopDims.insert(outputDim);
960 return MatchConvolutionResult::NonConvolutionLoop;
962 for (
auto filterExpr : indexingMaps[1].getResults()) {
963 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
964 if (outputDims.count(filterDim) &&
965 !inputExprWalker.unConvolvedDims.count(filterDim) &&
966 !inputExprWalker.convolvedDims.count(filterDim)) {
970 if (inputExprWalker.convolvedDims.count(filterDim) &&
971 !outputDims.count(filterDim)) {
973 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
974 return MatchConvolutionResult::NonOutputDimNotReduction;
975 if (allLoopDims.count(filterDim))
976 return MatchConvolutionResult::NonConvolutionLoop;
977 allLoopDims.insert(filterDim);
980 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
981 !outputDims.count(filterDim)) {
983 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
984 return MatchConvolutionResult::NonOutputDimNotReduction;
985 if (allLoopDims.count(filterDim))
986 return MatchConvolutionResult::NonConvolutionLoop;
987 allLoopDims.insert(filterDim);
990 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
991 outputDims.count(filterDim)) {
995 return MatchConvolutionResult::NonConvolutionLoop;
998 if (allLoopDims.size() != linalgOp.getNumLoops())
999 return MatchConvolutionResult::NonConvolutionLoop;
1001 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
1002 return MatchConvolutionResult::EmptyConvolvedDims;
1006 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
1007 assert(succeeded(res) &&
"unexpected failure to infer convolution dims");
1011 return MatchConvolutionResult::Success;
1017 case MatchConvolutionResult::NotLinalgOp:
1018 return "expected a LinalgOp";
1019 case MatchConvolutionResult::WrongNumOperands:
1020 return "expected op with 2 inputs and 1 output";
1021 case MatchConvolutionResult::WrongInputIndexingMap:
1022 return "unexpected input index map for convolutions";
1023 case MatchConvolutionResult::NotProjectedPermutations:
1024 return "expected output/filter indexing maps to be projected permutations";
1025 case MatchConvolutionResult::NonConvolutionLoop:
1026 return "unexpected loop dimension for convolution op";
1027 case MatchConvolutionResult::OutputDimsNotParallel:
1028 return "expected all iterators used to access outputs to be parallel";
1029 case MatchConvolutionResult::NonOutputDimNotReduction:
1030 return "expected all iterators not used to access outputs to be reduction";
1031 case MatchConvolutionResult::EmptyConvolvedDims:
1032 return "expected convolved dim to be non-empty";
1033 case MatchConvolutionResult::Success:
1036 llvm_unreachable(
"unhandled MatchConvolutionResult case");
1040 bool allowEmptyConvolvedDims) {
1042 linalgOp.getOperation(),
nullptr, allowEmptyConvolvedDims) ==
1048 if (res != MatchConvolutionResult::Success)
1065 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1068 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1071 OpOperand *value = linalgOp.getDpsInputOperand(0);
1072 if (!linalgOp.isScalar(value))
1081 return op->
emitError(
"expected a LinalgOp");
1083 return op->
emitError(
"expected op with 1 input and 1 output");
1085 return op->
emitError(
"expected op with scalar input");
1097 for (
OpOperand &opOperand : getOperation()->getOpOperands()) {
1098 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1106 assert(!hasDynamicShape() &&
"expected operands to have static shapes");
1107 for (
OpOperand &opOperand : getOperation()->getOpOperands())
1108 llvm::append_range(res,
getShape(&opOperand));
1115 auto viewSizes = createFlatListOfOperandDims(b, loc);
1117 for (
unsigned idx = 0; idx < numRes; ++idx) {
1119 if (
auto d = dyn_cast<AffineDimExpr>(result)) {
1120 if (res[d.getPosition()].offset)
1122 res[d.getPosition()] =
1134 : positions(std::move(positions)) {}
1149 llvm::SmallBitVector positions;
1152 static std::pair<int64_t, int64_t>
1154 int64_t inputRankSum = 0;
1155 int64_t outputRankSum = 0;
1156 for (
OpOperand *input : op.getDpsInputOperands())
1157 inputRankSum += op.getRank(input);
1158 for (
OpOperand &output : op.getDpsInitsMutable())
1159 outputRankSum += op.getRank(&output);
1160 return {inputRankSum, inputRankSum + outputRankSum};
1175 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1184 resultShapesSubMapPos.first,
1185 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1186 AffineMap resultShapesFromInputShapesMap =
1187 loopToResultsShapeMap.
compose(getShapesToLoopsMap());
1191 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.
getNumDims());
1192 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1194 Location loc = getOperation()->getLoc();
1198 rewriter, loc, resultShapesFromInputShapesMap,
1199 createFlatListOfOperandDims(b, loc));
1202 for (
OpOperand &opOperand : getDpsInitsMutable()) {
1204 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1205 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1206 if (!shapedType.isDynamicDim(dim)) {
1208 shapes.push_back(b.
getIndexAttr(shapedType.getDimSize(dim)));
1213 : allResultDimValues[pos];
1218 reifiedReturnShapes.emplace_back(std::move(shapes));
1225 int64_t LinalgOp::getIndexingMapIndex(
OpOperand *opOperand) {
1227 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1228 if (!dpsIface.isDpsInput(opOperand))
1229 return operandNumber;
1230 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1231 assert(!dpsIface.isDpsInit(opOperand));
1234 return cast<DestinationStyleOpInterface>(*this->getOperation())
1235 .getNumDpsInputs() +
1236 operandNumber - start;
1240 LinalgOp linalgOp = cast<LinalgOp>(op);
1242 if (!linalgOp.hasPureTensorSemantics() &&
1244 return op->
emitOpError(
"expected to have pure tensor or buffer semantics");
1248 if (linalgOp.hasDynamicIndexingMaps())
1249 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1253 if (failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
1258 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1259 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1261 unsigned numLoops = linalgOp.getNumLoops();
1265 <<
" dim(s) to match the number of loops";
1268 linalgOp.getReductionDims(redDims);
1270 if (!linalgOp.getShapesToLoopsMap())
1271 return op->
emitOpError(
"expected the shape-to-loops map to be non-null");
1274 if (linalgOp->getNumRegions() != 1 ||
1275 !llvm::hasSingleElement(linalgOp->getRegion(0)))
1276 return op->
emitOpError(
"expects to have 1 region with 1 block");
1284 Block &block = linalgOp->getRegion(0).
front();
1286 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.
getNumArguments())
1287 return op->
emitOpError(
"expected as many non-induction variable region "
1288 "arguments as the number of input/output operands");
1290 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1292 if (isa<MemRefType, RankedTensorType>(elementType))
1295 if (elementType != argType)
1296 return op->
emitOpError(
"expected type of bb argument #")
1298 <<
" to match element or self type of the corresponding operand ("
1299 << elementType <<
")";
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprWalker.
static std::optional< Value > isaExternalFillOp(GenericOp op)
Detects if a linalg.generic operation represents an external scalar input.
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free operations, go up the use-def chain until the first value that is not produced by such an operation.
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template arguments.
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static MatchFillResult isFillInterfaceImpl(Operation *op)
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation within linalgOp.
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by the template arguments.
static std::optional< Value > isaInlinedFillOp(GenericOp op)
Detects if a linalg.generic operation represents a fill with an inlined constant.
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type iter that are indexed by the map.
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Affine binary operation expression.
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
IntegerAttr getIndexAttr(int64_t value)
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the different kinds of dimensions when present.
@ NotProjectedPermutations
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that checks if the given operands can be dropped, i.e. the remaining operands can compute the loop bounds of the op.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the different kinds of dimensions when present.
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
@ NotProjectedPermutations
@ NonOutputDimNotReduction
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a convolution subcomputation within linalgOp.
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.broadcast.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation within linalgOp.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op, e.g. linalg.add.
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Visitor to check if any of the given set of positions from AffineDimExprs are used within an AffineExpr.
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.