22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/SetOperations.h"
24#include "llvm/ADT/SmallBitVector.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/Support/Casting.h"
27#include "llvm/Support/raw_ostream.h"
34#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
43 for (
auto &opOperand : linalgOp->getOpOperands()) {
44 if (llvm::is_contained(droppedOperands, &opOperand))
46 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
48 if (indexingMaps.empty()) {
51 return linalgOp.getNumLoops() == 0;
54 indexingMaps, linalgOp.getContext())) !=
AffineMap();
63 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
66 auto mapRange = op.getIndexingMapsArray();
67 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
68 !mapRange.back().isIdentity()) {
72 Block *body = op.getBlock();
75 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
76 if (!yieldOp || yieldOp.getNumOperands() != 1)
78 return yieldOp->getOperand(0) == body->
getArgument(0);
88 if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
89 op.getNumDpsInputs() != 0)
93 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
96 Block *body = op.getBody();
100 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
101 if (!yieldOp || yieldOp.getNumOperands() != 1)
104 Value yieldOperand = yieldOp->getOperand(0);
116 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
117 !op.isSingleYieldOp())
121 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
122 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
125 OpOperand *value = op.getDpsInputOperand(0);
126 if (!op.isScalar(value))
140std::optional<SmallVector<int64_t>>
142 if (
auto broadcastOp = dyn_cast<BroadcastOp>(linalgOp.getOperation()))
144 broadcastOp.getDimensions().end());
146 auto op = dyn_cast<GenericOp>(linalgOp.getOperation());
151 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
152 !op.isSingleYieldOp())
155 auto srcTy = op.getDpsInputOperand(0)->get().getType();
156 auto dstTy = op.getDpsInitOperand(0)->get().getType();
157 if (!isa<MemRefType, RankedTensorType>(srcTy) ||
158 !isa<MemRefType, RankedTensorType>(dstTy))
164 auto dstMap = op.getIndexingMapsArray()[1];
165 if (!dstMap.isIdentity())
169 auto srcMap = op.getIndexingMapsArray()[0];
171 if (srcMap.getResults().size() >= dstMap.getResults().size())
175 for (
unsigned i = 0; i < srcMap.getNumResults(); ++i) {
176 auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
179 int64_t pos = expr.getPosition();
180 if (i > 0 && pos <= position[i - 1])
182 position.push_back(expr.getPosition());
186 auto numDims = srcMap.getNumDims();
188 for (
auto dim : llvm::seq<int64_t>(0, numDims)) {
189 if (!llvm::is_contained(position, dim))
190 broadcastedDims.push_back(dim);
192 return broadcastedDims;
198std::optional<SmallVector<int64_t>>
203 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
204 !op.isSingleYieldOp())
207 auto mapRange = op.getIndexingMapsArray();
208 if (mapRange.size() != 2)
211 auto mapOfInput = mapRange.front();
212 auto mapOfResult = mapRange.back();
216 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
220 for (
unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
221 auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
222 permutation[expr.getPosition()] = i;
232 bool allowNonIdentityMaps) {
234 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
239 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
240 (!allowNonIdentityMaps &&
241 !llvm::all_of(op.getIndexingMapsArray(),
242 [](
AffineMap map) { return map.isIdentity(); })))
246 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
253 Block *body = op.getBody();
265 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
266 return !(!yieldOp || yieldOp.getNumOperands() != 1 ||
267 yieldOp->getOperand(0).getDefiningOp() != oper);
271 bool allowNonIdentityMaps) {
277 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
283 bool allowNonIdentityMaps) {
289 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
290 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
291 return !(!op.payloadUsesValueFromOperand(inputOpOperand0) ||
292 !op.payloadUsesValueFromOperand(inputOpOperand1));
306 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
307 if (!iface || !iface.hasNoEffect())
317 llvm::raw_ostream &errs) {
319 errs <<
"no terminator in the block";
324 errs <<
"expected block with 3 arguments";
330 errs <<
"expected terminator with 1 operand";
338 errs <<
"expected reduction op to be binary";
347 errs <<
"expected reduction to take block argument #2 as one of the "
348 "operands (modulo unary casts)";
353 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
357 errs <<
"expected elementwise op to be binary";
361 if (!isaPair(elementwiseOp, reductionOp)) {
362 errs <<
"expected reduction/elementwise op kind not satisfied";
375 errs <<
"expected elementwise op to apply to block arguments (modulo unary "
382template <
typename AddOpTy,
typename MulOpTy,
typename... Args>
384 static_assert(
sizeof...(Args) % 2 == 0,
385 "expected an even number of template arguments");
386 if (isa<AddOpTy>(
add) && isa<MulOpTy>(
mul))
389 if constexpr (
sizeof...(Args) > 0)
397template <
typename... Args>
409static llvm::SmallDenseSet<int64_t>
412 utils::IteratorType iter) {
413 assert(iterators.size() == indexingMap.
getNumDims());
414 llvm::SmallDenseSet<int64_t> res;
416 if (
auto d = dyn_cast<AffineDimExpr>(e)) {
417 if (iterators[d.getPosition()] == iter &&
419 return e.isFunctionOfDim(d.getPosition());
421 res.insert(d.getPosition());
428auto par = utils::IteratorType::parallel;
429auto red = utils::IteratorType::reduction;
436static FailureOr<SmallVector<utils::IteratorType>>
442 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
443 iterators[dim.getPosition()] = par;
458static FailureOr<ContractionDimensions>
461 llvm::SmallDenseSet<int64_t> a =
463 llvm::SmallDenseSet<int64_t>
b =
465 llvm::SmallDenseSet<int64_t> c =
469 llvm::SmallDenseSet<int64_t> ac = a;
470 llvm::set_intersect(ac, c);
471 llvm::set_subtract(ac,
b);
473 llvm::SmallDenseSet<int64_t> bc =
b;
474 llvm::set_intersect(bc, c);
475 llvm::set_subtract(bc, a);
477 llvm::SmallDenseSet<int64_t> batches = a;
478 llvm::set_intersect(batches,
b);
479 llvm::set_intersect(batches, c);
482 llvm::SmallDenseSet<int64_t> ra =
484 llvm::SmallDenseSet<int64_t> rb =
486 llvm::set_intersect(ra, rb);
494 llvm::sort(dimensions.
batch);
495 llvm::sort(dimensions.
m);
496 llvm::sort(dimensions.
n);
497 llvm::sort(dimensions.
k);
501FailureOr<ContractionDimensions>
503 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
506 linalgOp.getIteratorTypesArray());
509FailureOr<ContractionDimensions>
511 if (indexingMaps.size() != 3)
514 if (failed(iterators))
533 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
536 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
538 auto mapRange = linalgOp.getIndexingMapsArray();
539 if (linalgOp.getNumReductionLoops() == 0)
541 if (llvm::any_of(mapRange,
547 arith::MulFOp, arith::AddFOp,
548 arith::MulIOp, arith::AddIOp,
549 complex::MulOp, complex::AddOp,
550 arith::AndIOp, arith::OrIOp>(
551 *linalgOp.getBlock())) {
558 assert(succeeded(res) &&
"unexpected failure to infer contraction dims");
568 return "expected a LinalgOp";
570 return "expected op with 2 inputs and 1 output";
572 return "expected at least 1 reduction";
574 return "expected indexing maps to be projected permutations";
576 return "expected add/mul op in the body";
580 llvm_unreachable(
"unhandled MatchContractionResult case");
587 return isa<ContractionOpInterface>(op) ||
620 return isa<T>(
lhs) ? cast<T>(
lhs) : (isa<T>(
rhs) ? cast<T>(
rhs) :
nullptr);
632struct ConvAccessExprWalker
635 llvm::SmallDenseSet<int64_t> convolvedDims;
637 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
639 llvm::SmallDenseSet<int64_t> unConvolvedDims;
641 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
645 void clearMultiUseDims(AffineMap map) {
646 for (
int dimPos = 0, e = map.
getNumDims(); dimPos < e; ++dimPos) {
647 if (llvm::count_if(map.
getResults(), [dimPos](AffineExpr e) {
648 return e.isFunctionOfDim(dimPos);
650 convolvedDims.erase(dimPos);
651 unConvolvedDims.erase(dimPos);
654 auto it = convolvedDimMapping.find(dimPos);
655 if (it != convolvedDimMapping.end()) {
656 int64_t pairedDim = it->second;
657 convolvedDims.erase(pairedDim);
658 unConvolvedDims.erase(pairedDim);
659 strideAndDilationMapping.erase(pairedDim);
660 convolvedDimMapping.erase(dimPos);
661 convolvedDimMapping.erase(pairedDim);
667 LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
669 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
672 unConvolvedDims.insert(position);
676 LogicalResult visitSymbolExpr(AffineSymbolExpr expr) {
return failure(); }
678 LogicalResult visitConstantExpr(AffineConstantExpr expr) {
return failure(); }
680 LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
682 if (binaryExpr.
getKind() != AffineExprKind::Add)
684 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getLHS());
685 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getRHS());
688 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
689 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
693 FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
694 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
696 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
699 strideAndDilationMapping[dim] =
701 convolvedDims.insert(dim);
704 if (
auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
705 if (symbolMulExpr.getKind() != AffineExprKind::Mul)
707 auto lhsExpr = symbolMulExpr.getLHS();
708 auto rhsExpr = symbolMulExpr.getRHS();
717 if (!mulExpr || !dimExpr)
720 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
722 strideAndDilationMapping[dim] = mulExpr;
723 convolvedDims.insert(dim);
733 "expected map to have projected permutations");
734 llvm::SmallDenseSet<int64_t> preservedDims;
736 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
737 return preservedDims;
743 for (
auto e : exprs) {
744 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
745 assert(constantExpr &&
"Found non-constant stride/dilation");
746 vals.push_back(constantExpr.getValue());
764static FailureOr<ConvolutionDimensions>
766 ConvAccessExprWalker &inputExprWalker,
767 bool allowEmptyConvolvedDims) {
769 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
771 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
773 filterMap, linalgOp.getIteratorTypesArray(), par);
775 outputMap, linalgOp.getIteratorTypesArray(), par);
778 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
779 llvm::set_intersect(batch, outputDims);
780 llvm::set_subtract(batch, filterDims);
783 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
784 llvm::set_intersect(oi, outputDims);
787 llvm::SmallDenseSet<int64_t> oc = filterDims;
788 llvm::set_intersect(oc, outputDims);
789 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
792 llvm::SmallDenseSet<int64_t> depth = filterDims;
793 llvm::set_intersect(depth, outputDims);
794 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
796 llvm::SmallDenseSet<int64_t> filterReducedDims =
798 linalgOp.getIteratorTypesArray(), red);
801 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
802 llvm::set_intersect(fl, filterReducedDims);
805 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
806 llvm::set_intersect(ic, filterReducedDims);
808 if (oi.empty() && !allowEmptyConvolvedDims)
822 llvm::sort(dimensions.
batch);
826 llvm::sort(dimensions.
depth);
832 dimensions.
filterLoop.push_back(inputExprWalker.convolvedDimMapping[oiDim]);
836 if (!nativeStrides) {
839 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
842 dimensions.
strides = llvm::to_vector<2>(nativeStrides.getValues<
int64_t>());
844 auto nativeDilations =
846 if (!nativeDilations) {
849 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
853 llvm::to_vector<2>(nativeDilations.getValues<
int64_t>());
886FailureOr<ConvolutionDimensions>
888 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
891 auto indexingMaps = linalgOp.getIndexingMapsArray();
894 ConvAccessExprWalker inputExprWalker;
895 for (
AffineExpr expr : indexingMaps[0].getResults())
896 (
void)inputExprWalker.visit(expr);
897 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
920 bool allowEmptyConvolvedDims) {
921 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
924 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
927 auto indexingMaps = linalgOp.getIndexingMapsArray();
930 ConvAccessExprWalker inputExprWalker;
931 if (llvm::any_of(indexingMaps[0].getResults(),
933 return failed(inputExprWalker.visit(expr));
939 if (!indexingMaps[1].isProjectedPermutation() ||
940 !indexingMaps.back().isProjectedPermutation())
943 auto iteratorTypes = linalgOp.getIteratorTypesArray();
945 llvm::SmallDenseSet<int64_t> outputDims =
947 llvm::SmallDenseSet<int64_t> filterDims =
getPreservedDims(indexingMaps[1]);
961 llvm::SmallDenseSet<int64_t> allLoopDims;
962 for (
auto outputExpr : indexingMaps.back().getResults()) {
963 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
964 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
965 !filterDims.count(outputDim)) {
967 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
969 allLoopDims.insert(outputDim);
972 if (inputExprWalker.convolvedDims.count(outputDim) &&
973 !filterDims.count(outputDim)) {
975 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
977 allLoopDims.insert(outputDim);
980 if (!inputExprWalker.convolvedDims.count(outputDim) &&
981 !inputExprWalker.unConvolvedDims.count(outputDim) &&
982 filterDims.count(outputDim)) {
984 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
986 allLoopDims.insert(outputDim);
989 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
990 filterDims.count(outputDim)) {
992 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
994 allLoopDims.insert(outputDim);
999 for (
auto filterExpr : indexingMaps[1].getResults()) {
1000 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
1001 if (outputDims.count(filterDim) &&
1002 !inputExprWalker.unConvolvedDims.count(filterDim) &&
1003 !inputExprWalker.convolvedDims.count(filterDim)) {
1007 if (inputExprWalker.convolvedDims.count(filterDim) &&
1008 !outputDims.count(filterDim)) {
1010 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1012 if (allLoopDims.count(filterDim))
1014 allLoopDims.insert(filterDim);
1017 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1018 !outputDims.count(filterDim)) {
1020 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1022 if (allLoopDims.count(filterDim))
1024 allLoopDims.insert(filterDim);
1027 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1028 outputDims.count(filterDim)) {
1035 if (allLoopDims.size() != linalgOp.getNumLoops())
1038 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
1043 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
1044 assert(succeeded(res) &&
"unexpected failure to infer convolution dims");
1055 return "expected a LinalgOp";
1057 return "expected op with 2 inputs and 1 output";
1059 return "unexpected input index map for convolutions";
1061 return "expected output/filter indexing maps to be projected permutations";
1063 return "unexpected loop dimension for convolution op";
1065 return "expected all iterators used to access outputs to be parallel";
1067 return "expected all iterators not used to access outputs to be reduction";
1069 return "expected convolved dim to be non-empty";
1073 llvm_unreachable(
"unhandled MatchConvolutionResult case");
1077 bool allowEmptyConvolvedDims) {
1079 linalgOp.getOperation(),
nullptr, allowEmptyConvolvedDims) ==
1095enum class MatchFillResult {
1105 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1107 return MatchFillResult::NotLinalgOp;
1108 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1109 return MatchFillResult::WrongNumOperands;
1111 OpOperand *value = linalgOp.getDpsInputOperand(0);
1112 if (!linalgOp.isScalar(value))
1113 return MatchFillResult::NotScalarInput;
1116 OpOperand *output = linalgOp.getDpsInitOperand(0);
1119 if (scalarType != outputElementType)
1120 return MatchFillResult::TypeMismatch;
1122 return MatchFillResult::Success;
1127 if (res == MatchFillResult::NotLinalgOp)
1128 return op->
emitError(
"expected a LinalgOp");
1129 if (res == MatchFillResult::WrongNumOperands)
1130 return op->
emitError(
"expected op with 1 input and 1 output");
1131 if (res == MatchFillResult::NotScalarInput)
1132 return op->
emitError(
"expected op with scalar input");
1133 if (res == MatchFillResult::TypeMismatch) {
1134 auto linalgOp = cast<linalg::LinalgOp>(op);
1135 Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
1136 Type outputElementType =
1138 return op->
emitOpError(
"expected fill value type (")
1139 << scalarType <<
") to match output element type ("
1140 << outputElementType <<
")";
1153 for (
OpOperand &opOperand : getOperation()->getOpOperands()) {
1154 for (
int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1162 assert(!hasDynamicShape() &&
"expected operands to have static shapes");
1163 for (
OpOperand &opOperand : getOperation()->getOpOperands())
1164 llvm::append_range(res,
getShape(&opOperand));
1169 AffineMap map = getLoopsToShapesMap();
1171 auto viewSizes = createFlatListOfOperandDims(
b, loc);
1172 SmallVector<Range, 4> res(numDims);
1173 for (
unsigned idx = 0; idx < numRes; ++idx) {
1175 if (
auto d = dyn_cast<AffineDimExpr>(
result)) {
1176 if (res[d.getPosition()].offset)
1178 res[d.getPosition()] =
1179 Range{
b.getIndexAttr(0), viewSizes[idx],
b.getIndexAttr(1)};
1190 : positions(std::move(positions)) {}
1205 llvm::SmallBitVector positions;
1208static std::pair<int64_t, int64_t>
1212 for (
OpOperand *input : op.getDpsInputOperands())
1213 inputRankSum += op.getRank(input);
1214 for (
OpOperand &output : op.getDpsInitsMutable())
1215 outputRankSum += op.getRank(&output);
1216 return {inputRankSum, inputRankSum + outputRankSum};
1231 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1239 AffineMap loopToResultsShapeMap = loopsToShapesMap.
getSliceMap(
1240 resultShapesSubMapPos.first,
1241 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1242 AffineMap resultShapesFromInputShapesMap =
1243 loopToResultsShapeMap.
compose(getShapesToLoopsMap());
1247 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.
getNumDims());
1248 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1249 HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
1250 Location loc = getOperation()->getLoc();
1251 IRRewriter rewriter(
b);
1252 SmallVector<OpFoldResult> allResultDimValues =
1254 rewriter, loc, resultShapesFromInputShapesMap,
1255 createFlatListOfOperandDims(
b, loc));
1257 ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.
getResults();
1258 for (OpOperand &opOperand : getDpsInitsMutable()) {
1259 SmallVector<OpFoldResult> shapes;
1260 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1261 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1262 if (!shapedType.isDynamicDim(dim)) {
1264 shapes.push_back(
b.getIndexAttr(shapedType.getDimSize(dim)));
1267 OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
1269 : allResultDimValues[pos];
1274 reifiedReturnShapes.emplace_back(std::move(shapes));
1283 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1284 if (!dpsIface.isDpsInput(opOperand))
1285 return operandNumber;
1286 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1287 assert(!dpsIface.isDpsInit(opOperand));
1290 return cast<DestinationStyleOpInterface>(*this->getOperation())
1291 .getNumDpsInputs() +
1292 operandNumber - start;
1296 LinalgOp linalgOp = cast<LinalgOp>(op);
1298 if (!linalgOp.hasPureTensorSemantics() &&
1300 return op->
emitOpError(
"expected to have pure tensor or buffer semantics");
1304 if (linalgOp.hasDynamicIndexingMaps())
1305 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1309 if (failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
1314 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1315 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1317 unsigned numLoops = linalgOp.getNumLoops();
1321 <<
" dim(s) to match the number of loops";
1324 linalgOp.getReductionDims(redDims);
1326 if (!linalgOp.getShapesToLoopsMap())
1327 return op->
emitOpError(
"expected the shape-to-loops map to be non-null");
1330 if (linalgOp->getNumRegions() != 1 || !linalgOp->getRegion(0).hasOneBlock())
1331 return op->
emitOpError(
"expects to have 1 region with 1 block");
1339 Block &block = linalgOp->getRegion(0).front();
1341 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.
getNumArguments())
1342 return op->
emitOpError(
"expected as many non-induction variable region "
1343 "arguments as the number of input/output operands");
1345 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1347 if (isa<MemRefType, RankedTensorType>(elementType))
1350 if (elementType != argType)
1351 return op->
emitOpError(
"expected type of bb argument #")
1353 <<
" to match element or self type of the corresponding operand ("
1354 << elementType <<
")";
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free, go up the use-def chain until the first...
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...
static MatchFillResult isFillInterfaceImpl(Operation *op)
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static std::optional< Value > isaExternalFillOp(GenericOp op)
Detects if a linalg.generic operation represents an external scalar input.
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity, bool allowNonIdentityMaps)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static std::optional< Value > isaInlinedFillOp(GenericOp op)
Detects if a linalg.generic operation represents a fill with an inlined constant.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Affine binary operation expression.
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
bool visit(AffineExpr expr)
See documentation for AffineExprVisitorBase.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
@ NotProjectedPermutations
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that check if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
@ NotProjectedPermutations
@ NonOutputDimNotReduction
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp, bool allowNonIdentityMaps=false)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a broadcast operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp, bool allowNonIdentityMaps=false)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op,...
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::function_ref< Fn > function_ref
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > m
SmallVector< unsigned, 2 > n
SmallVector< unsigned, 2 > k
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.
SmallVector< unsigned, 2 > depth
SmallVector< unsigned, 2 > outputImage
SmallVector< unsigned, 2 > outputChannel
SmallVector< int64_t, 2 > dilations
SmallVector< int64_t, 2 > strides
SmallVector< unsigned, 2 > inputChannel
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > filterLoop