21 #include "llvm/ADT/SetOperations.h"
22 #include "llvm/ADT/SmallBitVector.h"
23 #include "llvm/ADT/SmallVector.h"
30 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
39 for (auto &opOperand : linalgOp->getOpOperands()) {
40 if (llvm::is_contained(droppedOperands, &opOperand))
42 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
44 if (indexingMaps.empty()) {
47 return linalgOp.getNumLoops() == 0;
58 if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
62 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
64 auto mapRange = linalgOp.getIndexingMapsArray();
65 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
66 !mapRange.back().isIdentity()) {
70 return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
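// For example, the checks above accept the generic form that linalg.copy
// generalizes to (identity maps, all-parallel iterators, and a lone
// linalg.yield as the only op in the body); the shapes are illustrative only:
//
//   #id = affine_map<(d0, d1) -> (d0, d1)>
//   %0 = linalg.generic
//       {indexing_maps = [#id, #id], iterator_types = ["parallel", "parallel"]}
//       ins(%src : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
//     ^bb0(%in: f32, %out: f32):
//       linalg.yield %in : f32
//   } -> tensor<?x?xf32>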
84 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
85 if (!iface || !iface.hasNoEffect())
95 llvm::raw_ostream &errs) {
97 errs << "no terminator in the block";
102 errs << "expected block with 3 arguments";
108 errs << "expected terminator with 1 operand";
115 errs << "expected reduction op to be binary";
124 errs << "expected reduction to take block argument #2 as one of the "
125 "operands (modulo unary casts)";
130 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
134 errs << "expected elementwise op to be binary";
138 if (!isaPair(elementwiseOp, reductionOp)) {
139 errs << "expected reduction/elementwise op kind not satisfied";
152 errs << "expected elementwise op to apply to block arguments (modulo unary "
159 template <typename AddOpTy, typename MulOpTy, typename... Args>
161 static_assert(sizeof...(Args) % 2 == 0,
162 "expected an even number of template arguments");
163 if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
166 if constexpr (sizeof...(Args) > 0)
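// The template arguments are consumed two at a time: if the ops do not match
// the first pair, the constexpr recursion retries with the remaining pairs.
// For instance, instantiating with <arith::MulFOp, arith::AddFOp,
// arith::MulIOp, arith::AddIOp>, as the contraction matcher below does, tries
// the floating-point pair first and then the integer pair.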
174 template <typename... Args>
186 static llvm::SmallDenseSet<int64_t>
189 utils::IteratorType iter) {
190 assert(iterators.size() == indexingMap.getNumDims());
191 llvm::SmallDenseSet<int64_t> res;
193 if (auto d = dyn_cast<AffineDimExpr>(e)) {
194 if (iterators[d.getPosition()] == iter &&
196 return e.isFunctionOfDim(d.getPosition());
198 res.insert(d.getPosition());
205 auto par = utils::IteratorType::parallel;
206 auto red = utils::IteratorType::reduction;
219 if (auto dim = dyn_cast<AffineDimExpr>(expr))
220 iterators[dim.getPosition()] = par;
238 llvm::SmallDenseSet<int64_t> a =
240 llvm::SmallDenseSet<int64_t> b =
242 llvm::SmallDenseSet<int64_t> c =
246 llvm::SmallDenseSet<int64_t> ac = a;
247 llvm::set_intersect(ac, c);
248 llvm::set_subtract(ac, b);
250 llvm::SmallDenseSet<int64_t> bc = b;
251 llvm::set_intersect(bc, c);
252 llvm::set_subtract(bc, a);
254 llvm::SmallDenseSet<int64_t> batches = a;
255 llvm::set_intersect(batches, b);
256 llvm::set_intersect(batches, c);
259 llvm::SmallDenseSet<int64_t> ra =
261 llvm::SmallDenseSet<int64_t> rb =
263 llvm::set_intersect(ra, rb);
271 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
272 llvm::sort(dimensions.m.begin(), dimensions.m.end());
273 llvm::sort(dimensions.n.begin(), dimensions.n.end());
274 llvm::sort(dimensions.k.begin(), dimensions.k.end());
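// Worked example: for a plain matmul the indexing maps are
//   A: (d0, d2), B: (d2, d1), C: (d0, d1)
// with iterators [parallel, parallel, reduction]. The parallel dims seen per
// operand are a = {0}, b = {1}, c = {0, 1}, so
//   m = (a & c) - b = {0},  n = (b & c) - a = {1},  batch = a & b & c = {},
// and the reduction dims common to A and B give k = {2}.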
280 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
283 linalgOp.getIteratorTypesArray());
288 if (indexingMaps.size() != 3)
310 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
312 return MatchContractionResult::NotLinalgOp;
313 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
314 return MatchContractionResult::WrongNumOperands;
315 auto mapRange = linalgOp.getIndexingMapsArray();
316 if (linalgOp.getNumReductionLoops() == 0)
317 return MatchContractionResult::NoReduction;
318 if (llvm::any_of(mapRange,
320 return MatchContractionResult::NotProjectedPermutations;
324 arith::MulFOp, arith::AddFOp,
325 arith::MulIOp, arith::AddIOp,
326 complex::MulOp, complex::AddOp,
327 arith::AndIOp, arith::OrIOp>(
328 *linalgOp.getBlock())) {
329 return MatchContractionResult::NotAddMul;
335 assert(succeeded(res) && "unexpected failure to infer contraction dims");
338 return MatchContractionResult::Success;
344 case MatchContractionResult::NotLinalgOp:
345 return "expected a LinalgOp";
346 case MatchContractionResult::WrongNumOperands:
347 return "expected op with 2 inputs and 1 output";
348 case MatchContractionResult::NoReduction:
349 return "expected at least 1 reduction";
350 case MatchContractionResult::NotProjectedPermutations:
351 return "expected indexing maps to be projected permutations";
352 case MatchContractionResult::NotAddMul:
353 return "expected add/mul op in the body";
354 case MatchContractionResult::Success:
357 llvm_unreachable("unhandled MatchContractionResult case");
364 return isa<ContractionOpInterface>(op) ||
384 if (res != MatchContractionResult::Success)
395 template <typename T>
397 return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
409 struct ConvAccessExprWalker
412 llvm::SmallDenseSet<int64_t> convolvedDims;
414 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
416 llvm::SmallDenseSet<int64_t> unConvolvedDims;
418 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
423 for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
425 return e.isFunctionOfDim(dimPos);
427 convolvedDims.erase(dimPos);
428 unConvolvedDims.erase(dimPos);
431 if (convolvedDimMapping.contains(dimPos)) {
432 int64_t pairedDim = convolvedDimMapping[dimPos];
433 convolvedDims.erase(pairedDim);
434 unConvolvedDims.erase(pairedDim);
435 strideAndDilationMapping.erase(pairedDim);
436 convolvedDimMapping.erase(dimPos);
437 convolvedDimMapping.erase(pairedDim);
445 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
448 unConvolvedDims.insert(position);
460 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
461 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
464 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
465 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
470 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
472 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
475 strideAndDilationMapping[dim] =
477 convolvedDims.insert(dim);
480 if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
483 auto lhsExpr = symbolMulExpr.getLHS();
484 auto rhsExpr = symbolMulExpr.getRHS();
487 getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
490 mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
492 auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
493 if (!mulExpr || !dimExpr)
496 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
498 strideAndDilationMapping[dim] = mulExpr;
499 convolvedDims.insert(dim);
509 "expected map to have projected permutations");
510 llvm::SmallDenseSet<int64_t> preservedDims;
512 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
513 return preservedDims;
519 for (auto e : exprs) {
520 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
521 assert(constantExpr && "Found non-constant stride/dilation");
522 vals.push_back(constantExpr.getValue());
536 ConvAccessExprWalker &inputExprWalker,
537 bool allowEmptyConvolvedDims) {
539 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
541 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
543 filterMap, linalgOp.getIteratorTypesArray(), par);
545 outputMap, linalgOp.getIteratorTypesArray(), par);
548 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
549 llvm::set_intersect(batch, outputDims);
550 llvm::set_subtract(batch, filterDims);
553 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
554 llvm::set_intersect(oi, outputDims);
557 llvm::SmallDenseSet<int64_t> oc = filterDims;
558 llvm::set_intersect(oc, outputDims);
559 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
562 llvm::SmallDenseSet<int64_t> depth = filterDims;
563 llvm::set_intersect(depth, outputDims);
564 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
566 llvm::SmallDenseSet<int64_t> filterReducedDims =
568 linalgOp.getIteratorTypesArray(), red);
571 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
572 llvm::set_intersect(fl, filterReducedDims);
575 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
576 llvm::set_intersect(ic, filterReducedDims);
578 if (oi.empty() && !allowEmptyConvolvedDims)
591 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
592 llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
593 llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
594 llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
595 llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
596 llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
600 if (!nativeStrides) {
602 for (unsigned oiDim : dimensions.outputImage)
603 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
606 dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
608 auto nativeDilations =
610 if (!nativeDilations) {
612 for (unsigned flDim : dimensions.filterLoop)
613 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
616 dimensions.dilations =
617 llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
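// Worked example: for a plain linalg.conv_1d the maps are
//   input: (d0 + d1), filter: (d1), output: (d0)
// with iterators [parallel, reduction]. The walker records d0 and d1 as
// convolved dims with unit stride/dilation, outputDims = {0} and filterDims
// = {}, so the inference yields output_image = {0}, filter_loop = {1},
// strides = [1], dilations = [1], and empty batch/channel/depth sets.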
648 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
651 auto indexingMaps = linalgOp.getIndexingMapsArray();
654 ConvAccessExprWalker inputExprWalker;
655 for (AffineExpr expr : indexingMaps[0].getResults())
656 (void)inputExprWalker.visit(expr);
657 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
679 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
681 return MatchConvolutionResult::NotLinalgOp;
682 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
683 return MatchConvolutionResult::WrongNumOperands;
685 auto indexingMaps = linalgOp.getIndexingMapsArray();
688 ConvAccessExprWalker inputExprWalker;
689 if (llvm::any_of(indexingMaps[0].getResults(),
691 return failed(inputExprWalker.visit(expr));
693 return MatchConvolutionResult::WrongInputIndexingMap;
697 if (!indexingMaps[1].isProjectedPermutation() ||
698 !indexingMaps.back().isProjectedPermutation())
699 return MatchConvolutionResult::NotProjectedPermutations;
701 auto iteratorTypes = linalgOp.getIteratorTypesArray();
703 llvm::SmallDenseSet<int64_t> outputDims =
705 llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
719 llvm::SmallDenseSet<int64_t> allLoopDims;
720 for (auto outputExpr : indexingMaps.back().getResults()) {
721 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
722 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
723 !filterDims.count(outputDim)) {
725 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
726 return MatchConvolutionResult::OutputDimsNotParallel;
727 allLoopDims.insert(outputDim);
730 if (inputExprWalker.convolvedDims.count(outputDim) &&
731 !filterDims.count(outputDim)) {
733 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
734 return MatchConvolutionResult::OutputDimsNotParallel;
735 allLoopDims.insert(outputDim);
738 if (!inputExprWalker.convolvedDims.count(outputDim) &&
739 !inputExprWalker.unConvolvedDims.count(outputDim) &&
740 filterDims.count(outputDim)) {
742 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
743 return MatchConvolutionResult::OutputDimsNotParallel;
744 allLoopDims.insert(outputDim);
747 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
748 filterDims.count(outputDim)) {
750 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
751 return MatchConvolutionResult::OutputDimsNotParallel;
752 allLoopDims.insert(outputDim);
755 return MatchConvolutionResult::NonConvolutionLoop;
757 for (auto filterExpr : indexingMaps[1].getResults()) {
758 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
759 if (outputDims.count(filterDim) &&
760 !inputExprWalker.unConvolvedDims.count(filterDim) &&
761 !inputExprWalker.convolvedDims.count(filterDim)) {
765 if (inputExprWalker.convolvedDims.count(filterDim) &&
766 !outputDims.count(filterDim)) {
768 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
769 return MatchConvolutionResult::NonOutputDimNotReduction;
770 if (allLoopDims.count(filterDim))
771 return MatchConvolutionResult::NonConvolutionLoop;
772 allLoopDims.insert(filterDim);
775 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
776 !outputDims.count(filterDim)) {
778 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
779 return MatchConvolutionResult::NonOutputDimNotReduction;
780 if (allLoopDims.count(filterDim))
781 return MatchConvolutionResult::NonConvolutionLoop;
782 allLoopDims.insert(filterDim);
785 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
786 outputDims.count(filterDim)) {
790 return MatchConvolutionResult::NonConvolutionLoop;
793 if (allLoopDims.size() != linalgOp.getNumLoops())
794 return MatchConvolutionResult::NonConvolutionLoop;
800 assert(succeeded(res) && "unexpected failure to infer convolution dims");
804 return MatchConvolutionResult::Success;
810 case MatchConvolutionResult::NotLinalgOp:
811 return "expected a LinalgOp";
812 case MatchConvolutionResult::WrongNumOperands:
813 return "expected op with 2 inputs and 1 output";
814 case MatchConvolutionResult::WrongInputIndexingMap:
815 return "unexpected input index map for convolutions";
816 case MatchConvolutionResult::NotProjectedPermutations:
817 return "expected output/filter indexing maps to be projected permutations";
818 case MatchConvolutionResult::NonConvolutionLoop:
819 return "unexpected loop dimension for convolution op";
820 case MatchConvolutionResult::OutputDimsNotParallel:
821 return "expected all iterators used to access outputs to be parallel";
822 case MatchConvolutionResult::NonOutputDimNotReduction:
823 return "expected all iterators not used to access outputs to be reduction";
824 case MatchConvolutionResult::Success:
827 llvm_unreachable("unhandled MatchConvolutionResult case");
837 if (res != MatchConvolutionResult::Success)
854 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
857 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
860 OpOperand *value = linalgOp.getDpsInputOperand(0);
861 if (!linalgOp.isScalar(value))
870 return op->emitError("expected a LinalgOp");
872 return op->emitError("expected op with 1 input and 1 output");
874 return op->emitError("expected op with scalar input");
886 for (OpOperand &opOperand : getOperation()->getOpOperands()) {
887 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
895 assert(!hasDynamicShape() && "expected operands to have static shapes");
896 for (OpOperand &opOperand : getOperation()->getOpOperands())
897 llvm::append_range(res, getShape(&opOperand));
904 auto viewSizes = createFlatListOfOperandDims(b, loc);
906 for (unsigned idx = 0; idx < numRes; ++idx) {
908 if (auto d = dyn_cast<AffineDimExpr>(result)) {
909 if (res[d.getPosition()].offset)
911 res[d.getPosition()] =
923 for (unsigned idx = 0; idx < numRes; ++idx) {
925 if (auto d = dyn_cast<AffineDimExpr>(result))
926 res[d.getPosition()] = allShapeSizes[idx];
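// Worked example: for a matmul with static operand shapes A = 4x8, B = 8x16
// and C = 4x16, the flat list of operand dims is [4, 8, 8, 16, 4, 16] and the
// loops-to-shapes results are (d0, d2, d2, d1, d0, d1), so the loop ranges
// resolve to d0 = 4, d1 = 16, d2 = 8.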
936 : positions(std::move(positions)) {}
951 llvm::SmallBitVector positions;
954 static std::pair<int64_t, int64_t>
956 int64_t inputRankSum = 0;
957 int64_t outputRankSum = 0;
958 for (OpOperand *input : op.getDpsInputOperands())
959 inputRankSum += op.getRank(input);
960 for (OpOperand &output : op.getDpsInitsMutable())
961 outputRankSum += op.getRank(&output);
962 return {inputRankSum, inputRankSum + outputRankSum};
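// For example, with two rank-2 inputs and one rank-2 init this returns
// {4, 6}: the init's dims occupy positions [4, 6) of the flattened shapes
// list produced by the loops-to-shapes map.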
977 AffineMap loopsToShapesMap = getLoopsToShapesMap();
986 resultShapesSubMapPos.first,
987 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
988 AffineMap resultShapesFromInputShapesMap =
989 loopToResultsShapeMap.compose(getShapesToLoopsMap());
993 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
994 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
996 Location loc = getOperation()->getLoc();
1000 rewriter, loc, resultShapesFromInputShapesMap,
1001 createFlatListOfOperandDims(b, loc));
1004 for (OpOperand &opOperand : getDpsInitsMutable()) {
1006 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1007 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1008 if (!shapedType.isDynamicDim(dim)) {
1010 shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
1015 : allResultDimValues[pos];
1020 reifiedReturnShapes.emplace_back(std::move(shapes));
1027 int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
1029 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1030 if (!dpsIface.isDpsInput(opOperand))
1031 return operandNumber;
1032 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1033 assert(!dpsIface.isDpsInit(opOperand));
1036 return cast<DestinationStyleOpInterface>(*this->getOperation())
1037 .getNumDpsInputs() +
1038 operandNumber - start;
1042 LinalgOp linalgOp = cast<LinalgOp>(op);
1045 if (!linalgOp.hasPureTensorSemantics() &&
1047 return op->emitOpError("expected to have pure tensor or buffer semantics");
1051 if (linalgOp.hasDynamicIndexingMaps())
1052 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1056 if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
1057 linalgOp->getNumOperands())
1058 return op->emitOpError("expected the number of indexing_map (")
1059 << linalgOp.getIndexingMapsArray().size()
1060 << ") to be equal to the number of input/output operands ("
1061 << linalgOp->getNumOperands() << ")";
1063 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1064 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1068 return op->emitOpError("unexpected symbols in indexing_map #")
1072 unsigned numLoops = linalgOp.getNumLoops();
1076 << " dim(s) to match the number of loops";
1078 int64_t rank = linalgOp.getRank(&opOperand);
1081 << rank << ") to match the result rank of indexing_map #"
1087 linalgOp.getReductionDims(redDims);
1089 if (!linalgOp.getShapesToLoopsMap())
1090 return op->emitOpError("expected the shape-to-loops map to be non-null");
1098 if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
1099 for (int64_t &range : endLoopRangeValues)
1101 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1102 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1104 indexingMap.compose(startLoopRangeValues);
1106 indexingMap.compose(endLoopRangeValues);
1108 for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
1110 if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
1123 int64_t inferredDimSize =
1124 std::max(startIndices[dim], endIndices[dim]) + 1;
1125 if (std::min(startIndices[dim], endIndices[dim]) < 0) {
1128 llvm::raw_string_ostream os(mapStr);
1132 "unexpected result less than 0 at expression #")
1133 << dim << " in " << mapStr;
1135 if (dyn_cast<AffineDimExpr>(indexingMap.getResult(dim))) {
1136 if (inferredDimSize != shape[dim]) {
1137 return op->emitOpError("inferred input/output operand #")
1139 << dim << " to be " << inferredDimSize << ", but found "
1143 if (inferredDimSize > shape[dim]) {
1144 return op->emitOpError("inferred input/output operand #")
1146 << dim << " to be greater than or equal to "
1147 << inferredDimSize << ", but found " << shape[dim];
1155 if (linalgOp->getNumRegions() != 1 ||
1156 !llvm::hasSingleElement(linalgOp->getRegion(0)))
1157 return op->emitOpError("expects to have 1 region with 1 block");
1165 Block &block = linalgOp->getRegion(0).front();
1167 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
1168 return op->emitOpError("expected as many non-induction variable region "
1169 "arguments as the number of input/output operands");
1171 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1173 if (isa<MemRefType, RankedTensorType>(elementType))
1176 if (elementType != argType)
1177 return op->emitOpError("expected type of bb argument #")
1179 << " to match element or self type of the corresponding operand ("
1180 << elementType << ")";