21 #include "llvm/ADT/SetOperations.h"
22 #include "llvm/ADT/SmallBitVector.h"
23 #include "llvm/ADT/SmallVector.h"
30 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
39 for (
auto &opOperand : linalgOp->getOpOperands()) {
40 if (llvm::is_contained(droppedOperands, &opOperand))
42 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
44 if (indexingMaps.empty()) {
47 return linalgOp.getNumLoops() == 0;
58 if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
62 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
64 auto mapRange = linalgOp.getIndexingMapsArray();
65 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
66 !mapRange.back().isIdentity()) {
70 return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
78 if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
79 genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
83 if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
84 genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
87 OpOperand *value = genericOp.getDpsInputOperand(0);
88 if (!genericOp.isScalar(value))
91 Block *body = genericOp.getBody();
95 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
96 if (!yieldOp || yieldOp.getNumOperands() != 1 ||
109 if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
110 genericOp.getNumLoops() < 1)
114 if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
115 !llvm::all_of(genericOp.getIndexingMapsArray(),
116 [](
AffineMap map) { return map.isIdentity(); }))
120 if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
127 Block *body = genericOp.getBody();
135 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
136 if (!yieldOp || yieldOp.getNumOperands() != 1 ||
137 yieldOp->getOperand(0).getDefiningOp() != op)
148 if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)))
158 OpOperand *inputOpOperand0 = genericOp.getDpsInputOperand(0);
159 OpOperand *inputOpOperand1 = genericOp.getDpsInputOperand(1);
160 if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
161 !genericOp.payloadUsesValueFromOperand(inputOpOperand1))
177 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
178 if (!iface || !iface.hasNoEffect())
188 llvm::raw_ostream &errs) {
190 errs <<
"no terminator in the block";
195 errs <<
"expected block with 3 arguments";
201 errs <<
"expected terminator with 1 operand";
208 errs <<
"expected reduction op to be binary";
217 errs <<
"expected reduction to take block argument #2 as one of the "
218 "operands (modulo unary casts)";
223 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
227 errs <<
"expected elementwise op to be binary";
231 if (!isaPair(elementwiseOp, reductionOp)) {
232 errs <<
"expected reduction/elementwise op kind not satisfied";
245 errs <<
"expected elementwise op to apply to block arguments (modulo unary "
252 template <
typename AddOpTy,
typename MulOpTy,
typename... Args>
254 static_assert(
sizeof...(Args) % 2 == 0,
255 "expected an even number of template arguments");
256 if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
259 if constexpr (
sizeof...(Args) > 0)
267 template <
typename... Args>
279 static llvm::SmallDenseSet<int64_t>
282 utils::IteratorType iter) {
283 assert(iterators.size() == indexingMap.
getNumDims());
284 llvm::SmallDenseSet<int64_t> res;
286 if (
auto d = dyn_cast<AffineDimExpr>(e)) {
287 if (iterators[d.getPosition()] == iter &&
289 return e.isFunctionOfDim(d.getPosition());
291 res.insert(d.getPosition());
298 auto par = utils::IteratorType::parallel;
299 auto red = utils::IteratorType::reduction;
306 static FailureOr<SmallVector<utils::IteratorType>>
312 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
313 iterators[dim.getPosition()] = par;
328 static FailureOr<ContractionDimensions>
331 llvm::SmallDenseSet<int64_t> a =
333 llvm::SmallDenseSet<int64_t> b =
335 llvm::SmallDenseSet<int64_t> c =
339 llvm::SmallDenseSet<int64_t> ac = a;
340 llvm::set_intersect(ac, c);
341 llvm::set_subtract(ac, b);
343 llvm::SmallDenseSet<int64_t> bc = b;
344 llvm::set_intersect(bc, c);
345 llvm::set_subtract(bc, a);
347 llvm::SmallDenseSet<int64_t> batches = a;
348 llvm::set_intersect(batches, b);
349 llvm::set_intersect(batches, c);
352 llvm::SmallDenseSet<int64_t> ra =
354 llvm::SmallDenseSet<int64_t> rb =
356 llvm::set_intersect(ra, rb);
364 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
365 llvm::sort(dimensions.m.begin(), dimensions.m.end());
366 llvm::sort(dimensions.n.begin(), dimensions.n.end());
367 llvm::sort(dimensions.k.begin(), dimensions.k.end());
371 FailureOr<ContractionDimensions>
373 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
376 linalgOp.getIteratorTypesArray());
379 FailureOr<ContractionDimensions>
381 if (indexingMaps.size() != 3)
384 if (failed(iterators))
403 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
405 return MatchContractionResult::NotLinalgOp;
406 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
407 return MatchContractionResult::WrongNumOperands;
408 auto mapRange = linalgOp.getIndexingMapsArray();
409 if (linalgOp.getNumReductionLoops() == 0)
410 return MatchContractionResult::NoReduction;
411 if (llvm::any_of(mapRange,
412 [](
AffineMap m) {
return !m.isProjectedPermutation(); }))
413 return MatchContractionResult::NotProjectedPermutations;
417 arith::MulFOp, arith::AddFOp,
418 arith::MulIOp, arith::AddIOp,
419 complex::MulOp, complex::AddOp,
420 arith::AndIOp, arith::OrIOp>(
421 *linalgOp.getBlock())) {
422 return MatchContractionResult::NotAddMul;
428 assert(succeeded(res) &&
"unexpected failure to infer contraction dims");
431 return MatchContractionResult::Success;
437 case MatchContractionResult::NotLinalgOp:
438 return "expected a LinalgOp";
439 case MatchContractionResult::WrongNumOperands:
440 return "expected op with 2 inputs and 1 output";
441 case MatchContractionResult::NoReduction:
442 return "expected at least 1 reduction";
443 case MatchContractionResult::NotProjectedPermutations:
444 return "expected indexing maps to be projected permutations";
445 case MatchContractionResult::NotAddMul:
446 return "expected add/mul op in the body";
447 case MatchContractionResult::Success:
450 llvm_unreachable(
"unhandled MatchContractionResult case");
457 return isa<ContractionOpInterface>(op) ||
477 if (res != MatchContractionResult::Success)
488 template <
typename T>
490 return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) :
nullptr);
502 struct ConvAccessExprWalker
505 llvm::SmallDenseSet<int64_t> convolvedDims;
507 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
509 llvm::SmallDenseSet<int64_t> unConvolvedDims;
511 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
516 for (
int dimPos = 0, e = map.
getNumDims(); dimPos < e; ++dimPos) {
518 return e.isFunctionOfDim(dimPos);
520 convolvedDims.erase(dimPos);
521 unConvolvedDims.erase(dimPos);
524 auto it = convolvedDimMapping.find(dimPos);
525 if (it != convolvedDimMapping.end()) {
526 int64_t pairedDim = it->second;
527 convolvedDims.erase(pairedDim);
528 unConvolvedDims.erase(pairedDim);
529 strideAndDilationMapping.erase(pairedDim);
530 convolvedDimMapping.erase(dimPos);
531 convolvedDimMapping.erase(pairedDim);
539 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
542 unConvolvedDims.insert(position);
554 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getLHS());
555 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getRHS());
556 if (failed(lhsDimPos) || failed(rhsDimPos))
558 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
559 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
563 FailureOr<int64_t> getDimExprOrMulExprDimPos(
AffineExpr expr) {
564 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
566 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
569 strideAndDilationMapping[dim] =
571 convolvedDims.insert(dim);
574 if (
auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
577 auto lhsExpr = symbolMulExpr.getLHS();
578 auto rhsExpr = symbolMulExpr.getRHS();
581 getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
584 mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
586 auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
587 if (!mulExpr || !dimExpr)
590 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
592 strideAndDilationMapping[dim] = mulExpr;
593 convolvedDims.insert(dim);
603 "expected map to have projected permutations");
604 llvm::SmallDenseSet<int64_t> preservedDims;
606 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
607 return preservedDims;
613 for (
auto e : exprs) {
614 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
615 assert(constantExpr &&
"Found non-constant stride/dilation");
616 vals.push_back(constantExpr.getValue());
628 static FailureOr<ConvolutionDimensions>
630 ConvAccessExprWalker &inputExprWalker,
631 bool allowEmptyConvolvedDims) {
633 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
635 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
637 filterMap, linalgOp.getIteratorTypesArray(), par);
639 outputMap, linalgOp.getIteratorTypesArray(), par);
642 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
643 llvm::set_intersect(batch, outputDims);
644 llvm::set_subtract(batch, filterDims);
647 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
648 llvm::set_intersect(oi, outputDims);
651 llvm::SmallDenseSet<int64_t> oc = filterDims;
652 llvm::set_intersect(oc, outputDims);
653 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
656 llvm::SmallDenseSet<int64_t> depth = filterDims;
657 llvm::set_intersect(depth, outputDims);
658 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
660 llvm::SmallDenseSet<int64_t> filterReducedDims =
662 linalgOp.getIteratorTypesArray(), red);
665 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
666 llvm::set_intersect(fl, filterReducedDims);
669 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
670 llvm::set_intersect(ic, filterReducedDims);
672 if (oi.empty() && !allowEmptyConvolvedDims)
685 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
686 llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
687 llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
688 llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
689 llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
690 llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
694 if (!nativeStrides) {
696 for (
unsigned oiDim : dimensions.outputImage)
697 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
700 dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
702 auto nativeDilations =
704 if (!nativeDilations) {
706 for (
unsigned flDim : dimensions.filterLoop)
707 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
710 dimensions.dilations =
711 llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
740 FailureOr<ConvolutionDimensions>
742 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
745 auto indexingMaps = linalgOp.getIndexingMapsArray();
748 ConvAccessExprWalker inputExprWalker;
749 for (
AffineExpr expr : indexingMaps[0].getResults())
750 (void)inputExprWalker.visit(expr);
751 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
774 bool allowEmptyConvolvedDims) {
775 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
777 return MatchConvolutionResult::NotLinalgOp;
778 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
779 return MatchConvolutionResult::WrongNumOperands;
781 auto indexingMaps = linalgOp.getIndexingMapsArray();
784 ConvAccessExprWalker inputExprWalker;
785 if (llvm::any_of(indexingMaps[0].getResults(),
787 return failed(inputExprWalker.visit(expr));
789 return MatchConvolutionResult::WrongInputIndexingMap;
793 if (!indexingMaps[1].isProjectedPermutation() ||
794 !indexingMaps.back().isProjectedPermutation())
795 return MatchConvolutionResult::NotProjectedPermutations;
797 auto iteratorTypes = linalgOp.getIteratorTypesArray();
799 llvm::SmallDenseSet<int64_t> outputDims =
801 llvm::SmallDenseSet<int64_t> filterDims =
getPreservedDims(indexingMaps[1]);
815 llvm::SmallDenseSet<int64_t> allLoopDims;
816 for (
auto outputExpr : indexingMaps.back().getResults()) {
817 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
818 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
819 !filterDims.count(outputDim)) {
821 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
822 return MatchConvolutionResult::OutputDimsNotParallel;
823 allLoopDims.insert(outputDim);
826 if (inputExprWalker.convolvedDims.count(outputDim) &&
827 !filterDims.count(outputDim)) {
829 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
830 return MatchConvolutionResult::OutputDimsNotParallel;
831 allLoopDims.insert(outputDim);
834 if (!inputExprWalker.convolvedDims.count(outputDim) &&
835 !inputExprWalker.unConvolvedDims.count(outputDim) &&
836 filterDims.count(outputDim)) {
838 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
839 return MatchConvolutionResult::OutputDimsNotParallel;
840 allLoopDims.insert(outputDim);
843 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
844 filterDims.count(outputDim)) {
846 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
847 return MatchConvolutionResult::OutputDimsNotParallel;
848 allLoopDims.insert(outputDim);
851 return MatchConvolutionResult::NonConvolutionLoop;
853 for (
auto filterExpr : indexingMaps[1].getResults()) {
854 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
855 if (outputDims.count(filterDim) &&
856 !inputExprWalker.unConvolvedDims.count(filterDim) &&
857 !inputExprWalker.convolvedDims.count(filterDim)) {
861 if (inputExprWalker.convolvedDims.count(filterDim) &&
862 !outputDims.count(filterDim)) {
864 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
865 return MatchConvolutionResult::NonOutputDimNotReduction;
866 if (allLoopDims.count(filterDim))
867 return MatchConvolutionResult::NonConvolutionLoop;
868 allLoopDims.insert(filterDim);
871 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
872 !outputDims.count(filterDim)) {
874 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
875 return MatchConvolutionResult::NonOutputDimNotReduction;
876 if (allLoopDims.count(filterDim))
877 return MatchConvolutionResult::NonConvolutionLoop;
878 allLoopDims.insert(filterDim);
881 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
882 outputDims.count(filterDim)) {
886 return MatchConvolutionResult::NonConvolutionLoop;
889 if (allLoopDims.size() != linalgOp.getNumLoops())
890 return MatchConvolutionResult::NonConvolutionLoop;
892 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
893 return MatchConvolutionResult::EmptyConvolvedDims;
897 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
898 assert(succeeded(res) &&
"unexpected failure to infer convolution dims");
902 return MatchConvolutionResult::Success;
908 case MatchConvolutionResult::NotLinalgOp:
909 return "expected a LinalgOp";
910 case MatchConvolutionResult::WrongNumOperands:
911 return "expected op with 2 inputs and 1 output";
912 case MatchConvolutionResult::WrongInputIndexingMap:
913 return "unexpected input index map for convolutions";
914 case MatchConvolutionResult::NotProjectedPermutations:
915 return "expected output/filter indexing maps to be projected permutations";
916 case MatchConvolutionResult::NonConvolutionLoop:
917 return "unexpected loop dimension for convolution op";
918 case MatchConvolutionResult::OutputDimsNotParallel:
919 return "expected all iterators used to access outputs to be parallel";
920 case MatchConvolutionResult::NonOutputDimNotReduction:
921 return "expected all iterators not used to access outputs to be reduction";
922 case MatchConvolutionResult::EmptyConvolvedDims:
923 return "expected convolved dim to be non-empty";
924 case MatchConvolutionResult::Success:
927 llvm_unreachable(
"unhandled MatchConvolutionResult case");
931 bool allowEmptyConvolvedDims) {
933 linalgOp.getOperation(),
nullptr, allowEmptyConvolvedDims) ==
939 if (res != MatchConvolutionResult::Success)
956 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
959 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
962 OpOperand *value = linalgOp.getDpsInputOperand(0);
963 if (!linalgOp.isScalar(value))
972 return op->
emitError(
"expected a LinalgOp");
974 return op->
emitError(
"expected op with 1 input and 1 output");
976 return op->
emitError(
"expected op with scalar input");
988 for (
OpOperand &opOperand : getOperation()->getOpOperands()) {
989 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
997 assert(!hasDynamicShape() &&
"expected operands to have static shapes");
998 for (
OpOperand &opOperand : getOperation()->getOpOperands())
999 llvm::append_range(res,
getShape(&opOperand));
1006 auto viewSizes = createFlatListOfOperandDims(b, loc);
1008 for (
unsigned idx = 0; idx < numRes; ++idx) {
1010 if (
auto d = dyn_cast<AffineDimExpr>(result)) {
1011 if (res[d.getPosition()].offset)
1013 res[d.getPosition()] =
1025 for (
unsigned idx = 0; idx < numRes; ++idx) {
1027 if (
auto d = dyn_cast<AffineDimExpr>(result))
1028 res[d.getPosition()] = allShapeSizes[idx];
1038 : positions(std::move(positions)) {}
1053 llvm::SmallBitVector positions;
1056 static std::pair<int64_t, int64_t>
1058 int64_t inputRankSum = 0;
1059 int64_t outputRankSum = 0;
1060 for (
OpOperand *input : op.getDpsInputOperands())
1061 inputRankSum += op.getRank(input);
1062 for (
OpOperand &output : op.getDpsInitsMutable())
1063 outputRankSum += op.getRank(&output);
1064 return {inputRankSum, inputRankSum + outputRankSum};
1079 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1088 resultShapesSubMapPos.first,
1089 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1090 AffineMap resultShapesFromInputShapesMap =
1091 loopToResultsShapeMap.
compose(getShapesToLoopsMap());
1095 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.
getNumDims());
1096 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1098 Location loc = getOperation()->getLoc();
1102 rewriter, loc, resultShapesFromInputShapesMap,
1103 createFlatListOfOperandDims(b, loc));
1106 for (
OpOperand &opOperand : getDpsInitsMutable()) {
1108 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1109 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1110 if (!shapedType.isDynamicDim(dim)) {
1112 shapes.push_back(b.
getIndexAttr(shapedType.getDimSize(dim)));
1117 : allResultDimValues[pos];
1122 reifiedReturnShapes.emplace_back(std::move(shapes));
1129 int64_t LinalgOp::getIndexingMapIndex(
OpOperand *opOperand) {
1131 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1132 if (!dpsIface.isDpsInput(opOperand))
1133 return operandNumber;
1134 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1135 assert(!dpsIface.isDpsInit(opOperand));
1138 return cast<DestinationStyleOpInterface>(*this->getOperation())
1139 .getNumDpsInputs() +
1140 operandNumber - start;
1144 LinalgOp linalgOp = cast<LinalgOp>(op);
1147 if (!linalgOp.hasPureTensorSemantics() &&
1149 return op->
emitOpError(
"expected to have pure tensor or buffer semantics");
1153 if (linalgOp.hasDynamicIndexingMaps())
1154 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1158 if (
static_cast<int64_t
>(linalgOp.getIndexingMapsArray().size()) !=
1159 linalgOp->getNumOperands())
1160 return op->
emitOpError(
"expected the number of indexing_map (")
1161 << linalgOp.getIndexingMapsArray().size()
1162 <<
") to be equal to the number of input/output operands ("
1163 << linalgOp->getNumOperands() <<
")";
1165 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1166 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1170 return op->
emitOpError(
"unexpected symbols in indexing_map #")
1174 unsigned numLoops = linalgOp.getNumLoops();
1178 <<
" dim(s) to match the number of loops";
1180 int64_t rank = linalgOp.getRank(&opOperand);
1183 << rank <<
") to match the result rank of indexing_map #"
1189 linalgOp.getReductionDims(redDims);
1191 if (!linalgOp.getShapesToLoopsMap())
1192 return op->
emitOpError(
"expected the shape-to-loops map to be non-null");
1200 if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
1201 for (int64_t &range : endLoopRangeValues)
1203 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1204 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1206 indexingMap.
compose(startLoopRangeValues);
1208 indexingMap.
compose(endLoopRangeValues);
1210 for (
auto dim : llvm::seq<int64_t>(0, shape.size())) {
1212 if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
1225 int64_t inferredDimSize =
1226 std::max(startIndices[dim], endIndices[dim]) + 1;
1227 if (
std::min(startIndices[dim], endIndices[dim]) < 0) {
1230 llvm::raw_string_ostream os(mapStr);
1234 "unexpected result less than 0 at expression #")
1235 << dim <<
" in " << mapStr;
1237 if (dyn_cast<AffineDimExpr>(indexingMap.
getResult(dim))) {
1238 if (inferredDimSize != shape[dim]) {
1239 return op->
emitOpError(
"inferred input/output operand #")
1241 << dim <<
" to be " << inferredDimSize <<
", but found "
1245 if (inferredDimSize > shape[dim]) {
1246 return op->
emitOpError(
"inferred input/output operand #")
1248 << dim <<
" to be greater than or equal to "
1249 << inferredDimSize <<
", but found " << shape[dim];
1257 if (linalgOp->getNumRegions() != 1 ||
1258 !llvm::hasSingleElement(linalgOp->getRegion(0)))
1259 return op->
emitOpError(
"expects to have 1 region with 1 block");
1267 Block &block = linalgOp->getRegion(0).
front();
1269 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.
getNumArguments())
1270 return op->
emitOpError(
"expected as many non-induction variable region "
1271 "arguments as the number of input/output operands");
1273 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1275 if (isa<MemRefType, RankedTensorType>(elementType))
1278 if (elementType != argType)
1279 return op->
emitOpError(
"expected type of bb argument #")
1281 <<
" to match element or self type of the corresponding operand ("
1282 << elementType <<
")";
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp, unsigned arity)
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprWalker.
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary, side effect-free operations, go up the use-def chain until the first value that is not defined by such an op.
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template arguments.
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static MatchFillResult isFillInterfaceImpl(Operation *op)
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation within linalgOp.
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by the template arguments.
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type iter that are indexed by the indexingMap in a projected manner.
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Affine binary operation expression.
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
IntegerAttr getIndexAttr(int64_t value)
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the different kinds of dimensions when dimensions is non-null.
@ NotProjectedPermutations
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that checks if given operands can be dropped, i.e. the remaining operands can still compute the loop bounds of the op.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the different kinds of dimensions when dimensions is non-null.
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
@ NotProjectedPermutations
@ NonOutputDimNotReduction
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a convolution subcomputation within linalgOp.
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation within linalgOp.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op, e.g. linalg.add.
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Visitor to check if any of the given set of positions from AffineDimExprs are used within an AffineExpr.
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.