21 #include "llvm/ADT/SetOperations.h"
22 #include "llvm/ADT/SmallBitVector.h"
23 #include "llvm/ADT/SmallVector.h"
30 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
39 for (
auto &opOperand : linalgOp->getOpOperands()) {
40 if (llvm::is_contained(droppedOperands, &opOperand))
42 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
44 if (indexingMaps.empty()) {
47 return linalgOp.getNumLoops() == 0;
58 if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
62 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
64 auto mapRange = linalgOp.getIndexingMapsArray();
65 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
66 !mapRange.back().isIdentity()) {
70 return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
78 if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
79 genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
83 if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
84 genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
87 OpOperand *value = genericOp.getDpsInputOperand(0);
88 if (!genericOp.isScalar(value))
91 Block *body = genericOp.getBody();
95 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
96 if (!yieldOp || yieldOp.getNumOperands() != 1 ||
109 if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
110 genericOp.getNumLoops() < 1)
114 if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
115 !llvm::all_of(genericOp.getIndexingMapsArray(),
116 [](
AffineMap map) { return map.isIdentity(); }))
120 if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
127 Block *body = genericOp.getBody();
135 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
136 if (!yieldOp || yieldOp.getNumOperands() != 1 ||
137 yieldOp->getOperand(0).getDefiningOp() != op)
148 if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)))
158 OpOperand *inputOpOperand0 = genericOp.getDpsInputOperand(0);
159 OpOperand *inputOpOperand1 = genericOp.getDpsInputOperand(1);
160 if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
161 !genericOp.payloadUsesValueFromOperand(inputOpOperand1))
177 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
178 if (!iface || !iface.hasNoEffect())
188 llvm::raw_ostream &errs) {
190 errs <<
"no terminator in the block";
195 errs <<
"expected block with 3 arguments";
201 errs <<
"expected terminator with 1 operand";
208 errs <<
"expected reduction op to be binary";
217 errs <<
"expected reduction to take block argument #2 as one of the "
218 "operands (modulo unary casts)";
223 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
227 errs <<
"expected elementwise op to be binary";
231 if (!isaPair(elementwiseOp, reductionOp)) {
232 errs <<
"expected reduction/elementwise op kind not satisfied";
245 errs <<
"expected elementwise op to apply to block arguments (modulo unary "
252 template <
typename AddOpTy,
typename MulOpTy,
typename... Args>
254 static_assert(
sizeof...(Args) % 2 == 0,
255 "expected an even number of template arguments");
256 if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
259 if constexpr (
sizeof...(Args) > 0)
267 template <
typename... Args>
279 static llvm::SmallDenseSet<int64_t>
282 utils::IteratorType iter) {
283 assert(iterators.size() == indexingMap.
getNumDims());
284 llvm::SmallDenseSet<int64_t> res;
286 if (
auto d = dyn_cast<AffineDimExpr>(e)) {
287 if (iterators[d.getPosition()] == iter &&
289 return e.isFunctionOfDim(d.getPosition());
291 res.insert(d.getPosition());
298 auto par = utils::IteratorType::parallel;
299 auto red = utils::IteratorType::reduction;
306 static FailureOr<SmallVector<utils::IteratorType>>
312 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
313 iterators[dim.getPosition()] = par;
328 static FailureOr<ContractionDimensions>
331 llvm::SmallDenseSet<int64_t> a =
333 llvm::SmallDenseSet<int64_t> b =
335 llvm::SmallDenseSet<int64_t> c =
339 llvm::SmallDenseSet<int64_t> ac = a;
340 llvm::set_intersect(ac, c);
341 llvm::set_subtract(ac, b);
343 llvm::SmallDenseSet<int64_t> bc = b;
344 llvm::set_intersect(bc, c);
345 llvm::set_subtract(bc, a);
347 llvm::SmallDenseSet<int64_t> batches = a;
348 llvm::set_intersect(batches, b);
349 llvm::set_intersect(batches, c);
352 llvm::SmallDenseSet<int64_t> ra =
354 llvm::SmallDenseSet<int64_t> rb =
356 llvm::set_intersect(ra, rb);
364 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
365 llvm::sort(dimensions.m.begin(), dimensions.m.end());
366 llvm::sort(dimensions.n.begin(), dimensions.n.end());
367 llvm::sort(dimensions.k.begin(), dimensions.k.end());
371 FailureOr<ContractionDimensions>
373 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
376 linalgOp.getIteratorTypesArray());
379 FailureOr<ContractionDimensions>
381 if (indexingMaps.size() != 3)
384 if (failed(iterators))
403 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
405 return MatchContractionResult::NotLinalgOp;
406 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
407 return MatchContractionResult::WrongNumOperands;
408 auto mapRange = linalgOp.getIndexingMapsArray();
409 if (linalgOp.getNumReductionLoops() == 0)
410 return MatchContractionResult::NoReduction;
411 if (llvm::any_of(mapRange,
412 [](
AffineMap m) {
return !m.isProjectedPermutation(); }))
413 return MatchContractionResult::NotProjectedPermutations;
417 arith::MulFOp, arith::AddFOp,
418 arith::MulIOp, arith::AddIOp,
419 complex::MulOp, complex::AddOp,
420 arith::AndIOp, arith::OrIOp>(
421 *linalgOp.getBlock())) {
422 return MatchContractionResult::NotAddMul;
428 assert(succeeded(res) &&
"unexpected failure to infer contraction dims");
431 return MatchContractionResult::Success;
437 case MatchContractionResult::NotLinalgOp:
438 return "expected a LinalgOp";
439 case MatchContractionResult::WrongNumOperands:
440 return "expected op with 2 inputs and 1 output";
441 case MatchContractionResult::NoReduction:
442 return "expected at least 1 reduction";
443 case MatchContractionResult::NotProjectedPermutations:
444 return "expected indexing maps to be projected permutations";
445 case MatchContractionResult::NotAddMul:
446 return "expected add/mul op in the body";
447 case MatchContractionResult::Success:
450 llvm_unreachable(
"unhandled MatchContractionResult case");
457 return isa<ContractionOpInterface>(op) ||
477 if (res != MatchContractionResult::Success)
488 template <
typename T>
490 return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) :
nullptr);
502 struct ConvAccessExprWalker
505 llvm::SmallDenseSet<int64_t> convolvedDims;
507 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
509 llvm::SmallDenseSet<int64_t> unConvolvedDims;
511 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
516 for (
int dimPos = 0, e = map.
getNumDims(); dimPos < e; ++dimPos) {
518 return e.isFunctionOfDim(dimPos);
520 convolvedDims.erase(dimPos);
521 unConvolvedDims.erase(dimPos);
524 if (convolvedDimMapping.contains(dimPos)) {
525 int64_t pairedDim = convolvedDimMapping[dimPos];
526 convolvedDims.erase(pairedDim);
527 unConvolvedDims.erase(pairedDim);
528 strideAndDilationMapping.erase(pairedDim);
529 convolvedDimMapping.erase(dimPos);
530 convolvedDimMapping.erase(pairedDim);
538 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
541 unConvolvedDims.insert(position);
553 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getLHS());
554 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getRHS());
555 if (failed(lhsDimPos) || failed(rhsDimPos))
557 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
558 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
562 FailureOr<int64_t> getDimExprOrMulExprDimPos(
AffineExpr expr) {
563 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
565 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
568 strideAndDilationMapping[dim] =
570 convolvedDims.insert(dim);
573 if (
auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
576 auto lhsExpr = symbolMulExpr.getLHS();
577 auto rhsExpr = symbolMulExpr.getRHS();
580 getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
583 mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
585 auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
586 if (!mulExpr || !dimExpr)
589 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
591 strideAndDilationMapping[dim] = mulExpr;
592 convolvedDims.insert(dim);
602 "expected map to have projected permutations");
603 llvm::SmallDenseSet<int64_t> preservedDims;
605 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
606 return preservedDims;
612 for (
auto e : exprs) {
613 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
614 assert(constantExpr &&
"Found non-constant stride/dilation");
615 vals.push_back(constantExpr.getValue());
627 static FailureOr<ConvolutionDimensions>
629 ConvAccessExprWalker &inputExprWalker,
630 bool allowEmptyConvolvedDims) {
632 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
634 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
636 filterMap, linalgOp.getIteratorTypesArray(), par);
638 outputMap, linalgOp.getIteratorTypesArray(), par);
641 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
642 llvm::set_intersect(batch, outputDims);
643 llvm::set_subtract(batch, filterDims);
646 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
647 llvm::set_intersect(oi, outputDims);
650 llvm::SmallDenseSet<int64_t> oc = filterDims;
651 llvm::set_intersect(oc, outputDims);
652 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
655 llvm::SmallDenseSet<int64_t> depth = filterDims;
656 llvm::set_intersect(depth, outputDims);
657 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
659 llvm::SmallDenseSet<int64_t> filterReducedDims =
661 linalgOp.getIteratorTypesArray(), red);
664 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
665 llvm::set_intersect(fl, filterReducedDims);
668 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
669 llvm::set_intersect(ic, filterReducedDims);
671 if (oi.empty() && !allowEmptyConvolvedDims)
684 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
685 llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
686 llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
687 llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
688 llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
689 llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
693 if (!nativeStrides) {
695 for (
unsigned oiDim : dimensions.outputImage)
696 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
699 dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
701 auto nativeDilations =
703 if (!nativeDilations) {
705 for (
unsigned flDim : dimensions.filterLoop)
706 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
709 dimensions.dilations =
710 llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
739 FailureOr<ConvolutionDimensions>
741 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
744 auto indexingMaps = linalgOp.getIndexingMapsArray();
747 ConvAccessExprWalker inputExprWalker;
748 for (
AffineExpr expr : indexingMaps[0].getResults())
749 (void)inputExprWalker.visit(expr);
750 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
772 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
774 return MatchConvolutionResult::NotLinalgOp;
775 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
776 return MatchConvolutionResult::WrongNumOperands;
778 auto indexingMaps = linalgOp.getIndexingMapsArray();
781 ConvAccessExprWalker inputExprWalker;
782 if (llvm::any_of(indexingMaps[0].getResults(),
784 return failed(inputExprWalker.visit(expr));
786 return MatchConvolutionResult::WrongInputIndexingMap;
790 if (!indexingMaps[1].isProjectedPermutation() ||
791 !indexingMaps.back().isProjectedPermutation())
792 return MatchConvolutionResult::NotProjectedPermutations;
794 auto iteratorTypes = linalgOp.getIteratorTypesArray();
796 llvm::SmallDenseSet<int64_t> outputDims =
798 llvm::SmallDenseSet<int64_t> filterDims =
getPreservedDims(indexingMaps[1]);
812 llvm::SmallDenseSet<int64_t> allLoopDims;
813 for (
auto outputExpr : indexingMaps.back().getResults()) {
814 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
815 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
816 !filterDims.count(outputDim)) {
818 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
819 return MatchConvolutionResult::OutputDimsNotParallel;
820 allLoopDims.insert(outputDim);
823 if (inputExprWalker.convolvedDims.count(outputDim) &&
824 !filterDims.count(outputDim)) {
826 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
827 return MatchConvolutionResult::OutputDimsNotParallel;
828 allLoopDims.insert(outputDim);
831 if (!inputExprWalker.convolvedDims.count(outputDim) &&
832 !inputExprWalker.unConvolvedDims.count(outputDim) &&
833 filterDims.count(outputDim)) {
835 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
836 return MatchConvolutionResult::OutputDimsNotParallel;
837 allLoopDims.insert(outputDim);
840 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
841 filterDims.count(outputDim)) {
843 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
844 return MatchConvolutionResult::OutputDimsNotParallel;
845 allLoopDims.insert(outputDim);
848 return MatchConvolutionResult::NonConvolutionLoop;
850 for (
auto filterExpr : indexingMaps[1].getResults()) {
851 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
852 if (outputDims.count(filterDim) &&
853 !inputExprWalker.unConvolvedDims.count(filterDim) &&
854 !inputExprWalker.convolvedDims.count(filterDim)) {
858 if (inputExprWalker.convolvedDims.count(filterDim) &&
859 !outputDims.count(filterDim)) {
861 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
862 return MatchConvolutionResult::NonOutputDimNotReduction;
863 if (allLoopDims.count(filterDim))
864 return MatchConvolutionResult::NonConvolutionLoop;
865 allLoopDims.insert(filterDim);
868 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
869 !outputDims.count(filterDim)) {
871 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
872 return MatchConvolutionResult::NonOutputDimNotReduction;
873 if (allLoopDims.count(filterDim))
874 return MatchConvolutionResult::NonConvolutionLoop;
875 allLoopDims.insert(filterDim);
878 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
879 outputDims.count(filterDim)) {
883 return MatchConvolutionResult::NonConvolutionLoop;
886 if (allLoopDims.size() != linalgOp.getNumLoops())
887 return MatchConvolutionResult::NonConvolutionLoop;
890 FailureOr<ConvolutionDimensions> res =
893 assert(succeeded(res) &&
"unexpected failure to infer convolution dims");
897 return MatchConvolutionResult::Success;
903 case MatchConvolutionResult::NotLinalgOp:
904 return "expected a LinalgOp";
905 case MatchConvolutionResult::WrongNumOperands:
906 return "expected op with 2 inputs and 1 output";
907 case MatchConvolutionResult::WrongInputIndexingMap:
908 return "unexpected input index map for convolutions";
909 case MatchConvolutionResult::NotProjectedPermutations:
910 return "expected output/filter indexing maps to be projected permutations";
911 case MatchConvolutionResult::NonConvolutionLoop:
912 return "unexpected loop dimension for convolution op";
913 case MatchConvolutionResult::OutputDimsNotParallel:
914 return "expected all iterators used to access outputs to be parallel";
915 case MatchConvolutionResult::NonOutputDimNotReduction:
916 return "expected all iterators not used to access outputs to be reduction";
917 case MatchConvolutionResult::Success:
920 llvm_unreachable(
"unhandled MatchConvolutionResult case");
930 if (res != MatchConvolutionResult::Success)
947 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
950 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
953 OpOperand *value = linalgOp.getDpsInputOperand(0);
954 if (!linalgOp.isScalar(value))
963 return op->
emitError(
"expected a LinalgOp");
965 return op->
emitError(
"expected op with 1 input and 1 output");
967 return op->
emitError(
"expected op with scalar input");
979 for (
OpOperand &opOperand : getOperation()->getOpOperands()) {
980 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
988 assert(!hasDynamicShape() &&
"expected operands to have static shapes");
989 for (
OpOperand &opOperand : getOperation()->getOpOperands())
990 llvm::append_range(res,
getShape(&opOperand));
997 auto viewSizes = createFlatListOfOperandDims(b, loc);
999 for (
unsigned idx = 0; idx < numRes; ++idx) {
1001 if (
auto d = dyn_cast<AffineDimExpr>(result)) {
1002 if (res[d.getPosition()].offset)
1004 res[d.getPosition()] =
1016 for (
unsigned idx = 0; idx < numRes; ++idx) {
1018 if (
auto d = dyn_cast<AffineDimExpr>(result))
1019 res[d.getPosition()] = allShapeSizes[idx];
1029 : positions(std::move(positions)) {}
1044 llvm::SmallBitVector positions;
1047 static std::pair<int64_t, int64_t>
1049 int64_t inputRankSum = 0;
1050 int64_t outputRankSum = 0;
1051 for (
OpOperand *input : op.getDpsInputOperands())
1052 inputRankSum += op.getRank(input);
1053 for (
OpOperand &output : op.getDpsInitsMutable())
1054 outputRankSum += op.getRank(&output);
1055 return {inputRankSum, inputRankSum + outputRankSum};
1070 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1079 resultShapesSubMapPos.first,
1080 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1081 AffineMap resultShapesFromInputShapesMap =
1082 loopToResultsShapeMap.
compose(getShapesToLoopsMap());
1086 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.
getNumDims());
1087 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1089 Location loc = getOperation()->getLoc();
1093 rewriter, loc, resultShapesFromInputShapesMap,
1094 createFlatListOfOperandDims(b, loc));
1097 for (
OpOperand &opOperand : getDpsInitsMutable()) {
1099 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1100 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1101 if (!shapedType.isDynamicDim(dim)) {
1103 shapes.push_back(b.
getIndexAttr(shapedType.getDimSize(dim)));
1108 : allResultDimValues[pos];
1113 reifiedReturnShapes.emplace_back(std::move(shapes));
1120 int64_t LinalgOp::getIndexingMapIndex(
OpOperand *opOperand) {
1122 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1123 if (!dpsIface.isDpsInput(opOperand))
1124 return operandNumber;
1125 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1126 assert(!dpsIface.isDpsInit(opOperand));
1129 return cast<DestinationStyleOpInterface>(*this->getOperation())
1130 .getNumDpsInputs() +
1131 operandNumber - start;
1135 LinalgOp linalgOp = cast<LinalgOp>(op);
1138 if (!linalgOp.hasPureTensorSemantics() &&
1140 return op->
emitOpError(
"expected to have pure tensor or buffer semantics");
1144 if (linalgOp.hasDynamicIndexingMaps())
1145 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1149 if (
static_cast<int64_t
>(linalgOp.getIndexingMapsArray().size()) !=
1150 linalgOp->getNumOperands())
1151 return op->
emitOpError(
"expected the number of indexing_map (")
1152 << linalgOp.getIndexingMapsArray().size()
1153 <<
") to be equal to the number of input/output operands ("
1154 << linalgOp->getNumOperands() <<
")";
1156 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1157 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1161 return op->
emitOpError(
"unexpected symbols in indexing_map #")
1165 unsigned numLoops = linalgOp.getNumLoops();
1169 <<
" dim(s) to match the number of loops";
1171 int64_t rank = linalgOp.getRank(&opOperand);
1174 << rank <<
") to match the result rank of indexing_map #"
1180 linalgOp.getReductionDims(redDims);
1182 if (!linalgOp.getShapesToLoopsMap())
1183 return op->
emitOpError(
"expected the shape-to-loops map to be non-null");
1191 if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
1192 for (int64_t &range : endLoopRangeValues)
1194 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1195 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1197 indexingMap.
compose(startLoopRangeValues);
1199 indexingMap.
compose(endLoopRangeValues);
1201 for (
auto dim : llvm::seq<int64_t>(0, shape.size())) {
1203 if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
1216 int64_t inferredDimSize =
1217 std::max(startIndices[dim], endIndices[dim]) + 1;
1218 if (
std::min(startIndices[dim], endIndices[dim]) < 0) {
1221 llvm::raw_string_ostream os(mapStr);
1225 "unexpected result less than 0 at expression #")
1226 << dim <<
" in " << mapStr;
1228 if (dyn_cast<AffineDimExpr>(indexingMap.
getResult(dim))) {
1229 if (inferredDimSize != shape[dim]) {
1230 return op->
emitOpError(
"inferred input/output operand #")
1232 << dim <<
" to be " << inferredDimSize <<
", but found "
1236 if (inferredDimSize > shape[dim]) {
1237 return op->
emitOpError(
"inferred input/output operand #")
1239 << dim <<
" to be greater than or equal to "
1240 << inferredDimSize <<
", but found " << shape[dim];
1248 if (linalgOp->getNumRegions() != 1 ||
1249 !llvm::hasSingleElement(linalgOp->getRegion(0)))
1250 return op->
emitOpError(
"expects to have 1 region with 1 block");
1258 Block &block = linalgOp->getRegion(0).
front();
1260 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.
getNumArguments())
1261 return op->
emitOpError(
"expected as many non-induction variable region "
1262 "arguments as the number of input/output operands");
1264 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1266 if (isa<MemRefType, RankedTensorType>(elementType))
1269 if (elementType != argType)
1270 return op->
emitOpError(
"expected type of bb argument #")
1272 <<
" to match element or self type of the corresponding operand ("
1273 << elementType <<
")";
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp, unsigned arity)
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprWalker.
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side-effect-free operations, go up the use-def chain until the first value that is not defined by such an op.
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template arguments.
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static MatchFillResult isFillInterfaceImpl(Operation *op)
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Affine binary operation expression.
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
IntegerAttr getIndexAttr(int64_t value)
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
@ NotProjectedPermutations
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that check if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
@ NotProjectedPermutations
@ NonOutputDimNotReduction
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copy op.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a convolution subcomputation.
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation.
bool isaConvolutionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ConvolutionOpInterface.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Visitor to check if any of the given set of positions from AffineDimExprs are used within an AffineEx...
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Positions of a Linalg op's loops that correspond to different kinds of a contraction dimension.
Positions of a Linalg op's loops that correspond to different kinds of a convolution dimension.