22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/SetOperations.h"
24#include "llvm/ADT/SmallBitVector.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/Support/Casting.h"
27#include "llvm/Support/raw_ostream.h"
34#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
43 for (
auto &opOperand : linalgOp->getOpOperands()) {
44 if (llvm::is_contained(droppedOperands, &opOperand))
46 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
48 if (indexingMaps.empty()) {
51 return linalgOp.getNumLoops() == 0;
54 indexingMaps, linalgOp.getContext())) !=
AffineMap();
63 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
66 auto mapRange = op.getIndexingMapsArray();
67 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
68 !mapRange.back().isIdentity()) {
72 Block *body = op.getBlock();
75 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
76 if (!yieldOp || yieldOp.getNumOperands() != 1)
78 return yieldOp->getOperand(0) == body->
getArgument(0);
88 if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
89 op.getNumDpsInputs() != 0)
93 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
96 Block *body = op.getBody();
100 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
101 if (!yieldOp || yieldOp.getNumOperands() != 1)
104 Value yieldOperand = yieldOp->getOperand(0);
116 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
117 !op.isSingleYieldOp())
121 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
122 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
125 OpOperand *value = op.getDpsInputOperand(0);
126 if (!op.isScalar(value))
140std::optional<SmallVector<int64_t>>
142 if (
auto broadcastOp = dyn_cast<BroadcastOp>(linalgOp.getOperation()))
144 broadcastOp.getDimensions().end());
146 auto op = dyn_cast<GenericOp>(linalgOp.getOperation());
151 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
152 !op.isSingleYieldOp())
155 auto srcTy = op.getDpsInputOperand(0)->get().getType();
156 auto dstTy = op.getDpsInitOperand(0)->get().getType();
157 if (!isa<MemRefType, RankedTensorType>(srcTy) ||
158 !isa<MemRefType, RankedTensorType>(dstTy))
164 auto dstMap = op.getIndexingMapsArray()[1];
165 if (!dstMap.isIdentity())
169 auto srcMap = op.getIndexingMapsArray()[0];
171 if (srcMap.getResults().size() >= dstMap.getResults().size())
175 for (
unsigned i = 0; i < srcMap.getNumResults(); ++i) {
176 auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
179 int64_t pos = expr.getPosition();
180 if (i > 0 && pos <= position[i - 1])
182 position.push_back(expr.getPosition());
186 auto numDims = srcMap.getNumDims();
188 for (
auto dim : llvm::seq<int64_t>(0, numDims)) {
189 if (!llvm::is_contained(position, dim))
190 broadcastedDims.push_back(dim);
192 return broadcastedDims;
198std::optional<SmallVector<int64_t>>
203 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
204 !op.isSingleYieldOp())
207 auto mapRange = op.getIndexingMapsArray();
208 if (mapRange.size() != 2)
211 auto mapOfInput = mapRange.front();
212 auto mapOfResult = mapRange.back();
216 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
220 for (
unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
221 auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
222 permutation[expr.getPosition()] = i;
233 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
237 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
238 !llvm::all_of(op.getIndexingMapsArray(),
239 [](
AffineMap map) { return map.isIdentity(); }))
243 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
250 Block *body = op.getBody();
262 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
263 return !(!yieldOp || yieldOp.getNumOperands() != 1 ||
264 yieldOp->getOperand(0).getDefiningOp() != oper);
273 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
283 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
284 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
285 return !(!op.payloadUsesValueFromOperand(inputOpOperand0) ||
286 !op.payloadUsesValueFromOperand(inputOpOperand1));
300 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
301 if (!iface || !iface.hasNoEffect())
311 llvm::raw_ostream &errs) {
313 errs <<
"no terminator in the block";
318 errs <<
"expected block with 3 arguments";
324 errs <<
"expected terminator with 1 operand";
332 errs <<
"expected reduction op to be binary";
341 errs <<
"expected reduction to take block argument #2 as one of the "
342 "operands (modulo unary casts)";
347 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
351 errs <<
"expected elementwise op to be binary";
355 if (!isaPair(elementwiseOp, reductionOp)) {
356 errs <<
"expected reduction/elementwise op kind not satisfied";
369 errs <<
"expected elementwise op to apply to block arguments (modulo unary "
376template <
typename AddOpTy,
typename MulOpTy,
typename... Args>
378 static_assert(
sizeof...(Args) % 2 == 0,
379 "expected an even number of template arguments");
380 if (isa<AddOpTy>(
add) && isa<MulOpTy>(
mul))
383 if constexpr (
sizeof...(Args) > 0)
391template <
typename... Args>
403static llvm::SmallDenseSet<int64_t>
406 utils::IteratorType iter) {
407 assert(iterators.size() == indexingMap.
getNumDims());
408 llvm::SmallDenseSet<int64_t> res;
410 if (
auto d = dyn_cast<AffineDimExpr>(e)) {
411 if (iterators[d.getPosition()] == iter &&
413 return e.isFunctionOfDim(d.getPosition());
415 res.insert(d.getPosition());
422auto par = utils::IteratorType::parallel;
423auto red = utils::IteratorType::reduction;
430static FailureOr<SmallVector<utils::IteratorType>>
436 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
437 iterators[dim.getPosition()] = par;
452static FailureOr<ContractionDimensions>
455 llvm::SmallDenseSet<int64_t> a =
457 llvm::SmallDenseSet<int64_t>
b =
459 llvm::SmallDenseSet<int64_t> c =
463 llvm::SmallDenseSet<int64_t> ac = a;
464 llvm::set_intersect(ac, c);
465 llvm::set_subtract(ac,
b);
467 llvm::SmallDenseSet<int64_t> bc =
b;
468 llvm::set_intersect(bc, c);
469 llvm::set_subtract(bc, a);
471 llvm::SmallDenseSet<int64_t> batches = a;
472 llvm::set_intersect(batches,
b);
473 llvm::set_intersect(batches, c);
476 llvm::SmallDenseSet<int64_t> ra =
478 llvm::SmallDenseSet<int64_t> rb =
480 llvm::set_intersect(ra, rb);
488 llvm::sort(dimensions.
batch);
489 llvm::sort(dimensions.
m);
490 llvm::sort(dimensions.
n);
491 llvm::sort(dimensions.
k);
495FailureOr<ContractionDimensions>
497 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
500 linalgOp.getIteratorTypesArray());
503FailureOr<ContractionDimensions>
505 if (indexingMaps.size() != 3)
508 if (failed(iterators))
527 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
530 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
532 auto mapRange = linalgOp.getIndexingMapsArray();
533 if (linalgOp.getNumReductionLoops() == 0)
535 if (llvm::any_of(mapRange,
541 arith::MulFOp, arith::AddFOp,
542 arith::MulIOp, arith::AddIOp,
543 complex::MulOp, complex::AddOp,
544 arith::AndIOp, arith::OrIOp>(
545 *linalgOp.getBlock())) {
552 assert(succeeded(res) &&
"unexpected failure to infer contraction dims");
562 return "expected a LinalgOp";
564 return "expected op with 2 inputs and 1 output";
566 return "expected at least 1 reduction";
568 return "expected indexing maps to be projected permutations";
570 return "expected add/mul op in the body";
574 llvm_unreachable(
"unhandled MatchContractionResult case");
581 return isa<ContractionOpInterface>(op) ||
614 return isa<T>(
lhs) ? cast<T>(
lhs) : (isa<T>(
rhs) ? cast<T>(
rhs) :
nullptr);
626struct ConvAccessExprWalker
629 llvm::SmallDenseSet<int64_t> convolvedDims;
631 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
633 llvm::SmallDenseSet<int64_t> unConvolvedDims;
635 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
639 void clearMultiUseDims(AffineMap map) {
640 for (
int dimPos = 0, e = map.
getNumDims(); dimPos < e; ++dimPos) {
641 if (llvm::count_if(map.
getResults(), [dimPos](AffineExpr e) {
642 return e.isFunctionOfDim(dimPos);
644 convolvedDims.erase(dimPos);
645 unConvolvedDims.erase(dimPos);
648 auto it = convolvedDimMapping.find(dimPos);
649 if (it != convolvedDimMapping.end()) {
650 int64_t pairedDim = it->second;
651 convolvedDims.erase(pairedDim);
652 unConvolvedDims.erase(pairedDim);
653 strideAndDilationMapping.erase(pairedDim);
654 convolvedDimMapping.erase(dimPos);
655 convolvedDimMapping.erase(pairedDim);
661 LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
663 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
666 unConvolvedDims.insert(position);
670 LogicalResult visitSymbolExpr(AffineSymbolExpr expr) {
return failure(); }
672 LogicalResult visitConstantExpr(AffineConstantExpr expr) {
return failure(); }
674 LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
676 if (binaryExpr.
getKind() != AffineExprKind::Add)
678 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getLHS());
679 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getRHS());
682 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
683 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
687 FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
688 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
690 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
693 strideAndDilationMapping[dim] =
695 convolvedDims.insert(dim);
698 if (
auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
699 if (symbolMulExpr.getKind() != AffineExprKind::Mul)
701 auto lhsExpr = symbolMulExpr.getLHS();
702 auto rhsExpr = symbolMulExpr.getRHS();
711 if (!mulExpr || !dimExpr)
714 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
716 strideAndDilationMapping[dim] = mulExpr;
717 convolvedDims.insert(dim);
727 "expected map to have projected permutations");
728 llvm::SmallDenseSet<int64_t> preservedDims;
730 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
731 return preservedDims;
737 for (
auto e : exprs) {
738 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
739 assert(constantExpr &&
"Found non-constant stride/dilation");
740 vals.push_back(constantExpr.getValue());
758static FailureOr<ConvolutionDimensions>
760 ConvAccessExprWalker &inputExprWalker,
761 bool allowEmptyConvolvedDims) {
763 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
765 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
767 filterMap, linalgOp.getIteratorTypesArray(), par);
769 outputMap, linalgOp.getIteratorTypesArray(), par);
772 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
773 llvm::set_intersect(batch, outputDims);
774 llvm::set_subtract(batch, filterDims);
777 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
778 llvm::set_intersect(oi, outputDims);
781 llvm::SmallDenseSet<int64_t> oc = filterDims;
782 llvm::set_intersect(oc, outputDims);
783 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
786 llvm::SmallDenseSet<int64_t> depth = filterDims;
787 llvm::set_intersect(depth, outputDims);
788 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
790 llvm::SmallDenseSet<int64_t> filterReducedDims =
792 linalgOp.getIteratorTypesArray(), red);
795 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
796 llvm::set_intersect(fl, filterReducedDims);
799 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
800 llvm::set_intersect(ic, filterReducedDims);
802 if (oi.empty() && !allowEmptyConvolvedDims)
816 llvm::sort(dimensions.
batch);
820 llvm::sort(dimensions.
depth);
826 dimensions.
filterLoop.push_back(inputExprWalker.convolvedDimMapping[oiDim]);
830 if (!nativeStrides) {
833 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
836 dimensions.
strides = llvm::to_vector<2>(nativeStrides.getValues<
int64_t>());
838 auto nativeDilations =
840 if (!nativeDilations) {
843 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
847 llvm::to_vector<2>(nativeDilations.getValues<
int64_t>());
880FailureOr<ConvolutionDimensions>
882 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
885 auto indexingMaps = linalgOp.getIndexingMapsArray();
888 ConvAccessExprWalker inputExprWalker;
889 for (
AffineExpr expr : indexingMaps[0].getResults())
890 (
void)inputExprWalker.visit(expr);
891 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
914 bool allowEmptyConvolvedDims) {
915 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
918 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
921 auto indexingMaps = linalgOp.getIndexingMapsArray();
924 ConvAccessExprWalker inputExprWalker;
925 if (llvm::any_of(indexingMaps[0].getResults(),
927 return failed(inputExprWalker.visit(expr));
933 if (!indexingMaps[1].isProjectedPermutation() ||
934 !indexingMaps.back().isProjectedPermutation())
937 auto iteratorTypes = linalgOp.getIteratorTypesArray();
939 llvm::SmallDenseSet<int64_t> outputDims =
941 llvm::SmallDenseSet<int64_t> filterDims =
getPreservedDims(indexingMaps[1]);
955 llvm::SmallDenseSet<int64_t> allLoopDims;
956 for (
auto outputExpr : indexingMaps.back().getResults()) {
957 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
958 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
959 !filterDims.count(outputDim)) {
961 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
963 allLoopDims.insert(outputDim);
966 if (inputExprWalker.convolvedDims.count(outputDim) &&
967 !filterDims.count(outputDim)) {
969 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
971 allLoopDims.insert(outputDim);
974 if (!inputExprWalker.convolvedDims.count(outputDim) &&
975 !inputExprWalker.unConvolvedDims.count(outputDim) &&
976 filterDims.count(outputDim)) {
978 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
980 allLoopDims.insert(outputDim);
983 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
984 filterDims.count(outputDim)) {
986 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
988 allLoopDims.insert(outputDim);
993 for (
auto filterExpr : indexingMaps[1].getResults()) {
994 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
995 if (outputDims.count(filterDim) &&
996 !inputExprWalker.unConvolvedDims.count(filterDim) &&
997 !inputExprWalker.convolvedDims.count(filterDim)) {
1001 if (inputExprWalker.convolvedDims.count(filterDim) &&
1002 !outputDims.count(filterDim)) {
1004 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1006 if (allLoopDims.count(filterDim))
1008 allLoopDims.insert(filterDim);
1011 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1012 !outputDims.count(filterDim)) {
1014 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1016 if (allLoopDims.count(filterDim))
1018 allLoopDims.insert(filterDim);
1021 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1022 outputDims.count(filterDim)) {
1029 if (allLoopDims.size() != linalgOp.getNumLoops())
1032 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
1037 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
1038 assert(succeeded(res) &&
"unexpected failure to infer convolution dims");
1049 return "expected a LinalgOp";
1051 return "expected op with 2 inputs and 1 output";
1053 return "unexpected input index map for convolutions";
1055 return "expected output/filter indexing maps to be projected permutations";
1057 return "unexpected loop dimension for convolution op";
1059 return "expected all iterators used to access outputs to be parallel";
1061 return "expected all iterators not used to access outputs to be reduction";
1063 return "expected convolved dim to be non-empty";
1067 llvm_unreachable(
"unhandled MatchConvolutionResult case");
1071 bool allowEmptyConvolvedDims) {
1073 linalgOp.getOperation(),
nullptr, allowEmptyConvolvedDims) ==
1089enum class MatchFillResult {
1099 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1101 return MatchFillResult::NotLinalgOp;
1102 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1103 return MatchFillResult::WrongNumOperands;
1105 OpOperand *value = linalgOp.getDpsInputOperand(0);
1106 if (!linalgOp.isScalar(value))
1107 return MatchFillResult::NotScalarInput;
1110 OpOperand *output = linalgOp.getDpsInitOperand(0);
1113 if (scalarType != outputElementType)
1114 return MatchFillResult::TypeMismatch;
1116 return MatchFillResult::Success;
1121 if (res == MatchFillResult::NotLinalgOp)
1122 return op->
emitError(
"expected a LinalgOp");
1123 if (res == MatchFillResult::WrongNumOperands)
1124 return op->
emitError(
"expected op with 1 input and 1 output");
1125 if (res == MatchFillResult::NotScalarInput)
1126 return op->
emitError(
"expected op with scalar input");
1127 if (res == MatchFillResult::TypeMismatch) {
1128 auto linalgOp = cast<linalg::LinalgOp>(op);
1129 Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
1130 Type outputElementType =
1132 return op->
emitOpError(
"expected fill value type (")
1133 << scalarType <<
") to match output element type ("
1134 << outputElementType <<
")";
1147 for (
OpOperand &opOperand : getOperation()->getOpOperands()) {
1148 for (
int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1156 assert(!hasDynamicShape() &&
"expected operands to have static shapes");
1157 for (
OpOperand &opOperand : getOperation()->getOpOperands())
1158 llvm::append_range(res,
getShape(&opOperand));
1163 AffineMap map = getLoopsToShapesMap();
1165 auto viewSizes = createFlatListOfOperandDims(
b, loc);
1166 SmallVector<Range, 4> res(numDims);
1167 for (
unsigned idx = 0; idx < numRes; ++idx) {
1169 if (
auto d = dyn_cast<AffineDimExpr>(
result)) {
1170 if (res[d.getPosition()].offset)
1172 res[d.getPosition()] =
1173 Range{
b.getIndexAttr(0), viewSizes[idx],
b.getIndexAttr(1)};
1184 : positions(std::move(positions)) {}
1199 llvm::SmallBitVector positions;
1202static std::pair<int64_t, int64_t>
1206 for (
OpOperand *input : op.getDpsInputOperands())
1207 inputRankSum += op.getRank(input);
1208 for (
OpOperand &output : op.getDpsInitsMutable())
1209 outputRankSum += op.getRank(&output);
1210 return {inputRankSum, inputRankSum + outputRankSum};
1225 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1233 AffineMap loopToResultsShapeMap = loopsToShapesMap.
getSliceMap(
1234 resultShapesSubMapPos.first,
1235 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1236 AffineMap resultShapesFromInputShapesMap =
1237 loopToResultsShapeMap.
compose(getShapesToLoopsMap());
1241 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.
getNumDims());
1242 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1243 HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
1244 Location loc = getOperation()->getLoc();
1245 IRRewriter rewriter(
b);
1246 SmallVector<OpFoldResult> allResultDimValues =
1248 rewriter, loc, resultShapesFromInputShapesMap,
1249 createFlatListOfOperandDims(
b, loc));
1251 ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.
getResults();
1252 for (OpOperand &opOperand : getDpsInitsMutable()) {
1253 SmallVector<OpFoldResult> shapes;
1254 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1255 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1256 if (!shapedType.isDynamicDim(dim)) {
1258 shapes.push_back(
b.getIndexAttr(shapedType.getDimSize(dim)));
1261 OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
1263 : allResultDimValues[pos];
1268 reifiedReturnShapes.emplace_back(std::move(shapes));
1277 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1278 if (!dpsIface.isDpsInput(opOperand))
1279 return operandNumber;
1280 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1281 assert(!dpsIface.isDpsInit(opOperand));
1284 return cast<DestinationStyleOpInterface>(*this->getOperation())
1285 .getNumDpsInputs() +
1286 operandNumber - start;
1290 LinalgOp linalgOp = cast<LinalgOp>(op);
1292 if (!linalgOp.hasPureTensorSemantics() &&
1294 return op->
emitOpError(
"expected to have pure tensor or buffer semantics");
1298 if (linalgOp.hasDynamicIndexingMaps())
1299 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1303 if (failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
1308 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1309 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1311 unsigned numLoops = linalgOp.getNumLoops();
1315 <<
" dim(s) to match the number of loops";
1318 linalgOp.getReductionDims(redDims);
1320 if (!linalgOp.getShapesToLoopsMap())
1321 return op->
emitOpError(
"expected the shape-to-loops map to be non-null");
1324 if (linalgOp->getNumRegions() != 1 || !linalgOp->getRegion(0).hasOneBlock())
1325 return op->
emitOpError(
"expects to have 1 region with 1 block");
1333 Block &block = linalgOp->getRegion(0).front();
1335 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.
getNumArguments())
1336 return op->
emitOpError(
"expected as many non-induction variable region "
1337 "arguments as the number of input/output operands");
1339 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1341 if (isa<MemRefType, RankedTensorType>(elementType))
1344 if (elementType != argType)
1345 return op->
emitOpError(
"expected type of bb argument #")
1347 <<
" to match element or self type of the corresponding operand ("
1348 << elementType <<
")";
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free, go up the use-def chain until the first...
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...
static MatchFillResult isFillInterfaceImpl(Operation *op)
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static std::optional< Value > isaExternalFillOp(GenericOp op)
Detects if a linalg.generic operation represents an external scalar input.
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static std::optional< Value > isaInlinedFillOp(GenericOp op)
Detects if a linalg.generic operation represents a fill with an inlined constant.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Affine binary operation expression.
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
bool visit(AffineExpr expr)
See documentation for AffineExprVisitorBase.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
@ NotProjectedPermutations
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that check if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
@ NotProjectedPermutations
@ NonOutputDimNotReduction
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalgelementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a broadcast operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::function_ref< Fn > function_ref
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > m
SmallVector< unsigned, 2 > n
SmallVector< unsigned, 2 > k
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.
SmallVector< unsigned, 2 > depth
SmallVector< unsigned, 2 > outputImage
SmallVector< unsigned, 2 > outputChannel
SmallVector< int64_t, 2 > dilations
SmallVector< int64_t, 2 > strides
SmallVector< unsigned, 2 > inputChannel
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > filterLoop