22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/SetOperations.h"
24#include "llvm/ADT/SmallBitVector.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/Support/Casting.h"
27#include "llvm/Support/raw_ostream.h"
34#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
43 for (
auto &opOperand : linalgOp->getOpOperands()) {
44 if (llvm::is_contained(droppedOperands, &opOperand))
46 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
48 if (indexingMaps.empty()) {
51 return linalgOp.getNumLoops() == 0;
54 indexingMaps, linalgOp.getContext())) !=
AffineMap();
63 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
66 auto mapRange = op.getIndexingMapsArray();
67 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
68 !mapRange.back().isIdentity()) {
72 Block *body = op.getBlock();
75 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
76 if (!yieldOp || yieldOp.getNumOperands() != 1)
78 return yieldOp->getOperand(0) == body->
getArgument(0);
88 if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
89 op.getNumDpsInputs() != 0)
93 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
96 Block *body = op.getBody();
100 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
101 if (!yieldOp || yieldOp.getNumOperands() != 1)
104 Value yieldOperand = yieldOp->getOperand(0);
116 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
117 !op.isSingleYieldOp())
121 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
122 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
125 OpOperand *value = op.getDpsInputOperand(0);
126 if (!op.isScalar(value))
140std::optional<SmallVector<int64_t>>
142 if (
auto broadcastOp = dyn_cast<BroadcastOp>(linalgOp.getOperation()))
144 broadcastOp.getDimensions().end());
146 auto op = dyn_cast<GenericOp>(linalgOp.getOperation());
151 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
152 !op.isSingleYieldOp())
155 auto srcTy = op.getDpsInputOperand(0)->get().getType();
156 auto dstTy = op.getDpsInitOperand(0)->get().getType();
157 if (!isa<MemRefType, RankedTensorType>(srcTy) ||
158 !isa<MemRefType, RankedTensorType>(dstTy))
164 auto dstMap = op.getIndexingMapsArray()[1];
165 if (!dstMap.isIdentity())
169 auto srcMap = op.getIndexingMapsArray()[0];
171 if (srcMap.getResults().size() >= dstMap.getResults().size())
175 for (
unsigned i = 0; i < srcMap.getNumResults(); ++i) {
176 auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
179 int64_t pos = expr.getPosition();
180 if (i > 0 && pos <= position[i - 1])
182 position.push_back(expr.getPosition());
186 auto numDims = srcMap.getNumDims();
188 for (
auto dim : llvm::seq<int64_t>(0, numDims)) {
189 if (!llvm::is_contained(position, dim))
190 broadcastedDims.push_back(dim);
192 return broadcastedDims;
198std::optional<SmallVector<int64_t>>
203 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
204 !op.isSingleYieldOp())
207 auto mapRange = op.getIndexingMapsArray();
208 if (mapRange.size() != 2)
211 auto mapOfInput = mapRange.front();
212 auto mapOfResult = mapRange.back();
216 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
220 for (
unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
221 auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
222 permutation[expr.getPosition()] = i;
233 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
237 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
238 !llvm::all_of(op.getIndexingMapsArray(),
239 [](
AffineMap map) { return map.isIdentity(); }))
243 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
250 Block *body = op.getBody();
258 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
259 return !(!yieldOp || yieldOp.getNumOperands() != 1 ||
260 yieldOp->getOperand(0).getDefiningOp() != oper);
269 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
279 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
280 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
281 return !(!op.payloadUsesValueFromOperand(inputOpOperand0) ||
282 !op.payloadUsesValueFromOperand(inputOpOperand1));
296 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
297 if (!iface || !iface.hasNoEffect())
307 llvm::raw_ostream &errs) {
309 errs <<
"no terminator in the block";
314 errs <<
"expected block with 3 arguments";
320 errs <<
"expected terminator with 1 operand";
328 errs <<
"expected reduction op to be binary";
337 errs <<
"expected reduction to take block argument #2 as one of the "
338 "operands (modulo unary casts)";
343 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
347 errs <<
"expected elementwise op to be binary";
351 if (!isaPair(elementwiseOp, reductionOp)) {
352 errs <<
"expected reduction/elementwise op kind not satisfied";
365 errs <<
"expected elementwise op to apply to block arguments (modulo unary "
372template <
typename AddOpTy,
typename MulOpTy,
typename... Args>
374 static_assert(
sizeof...(Args) % 2 == 0,
375 "expected an even number of template arguments");
376 if (isa<AddOpTy>(
add) && isa<MulOpTy>(
mul))
379 if constexpr (
sizeof...(Args) > 0)
387template <
typename... Args>
399static llvm::SmallDenseSet<int64_t>
402 utils::IteratorType iter) {
403 assert(iterators.size() == indexingMap.
getNumDims());
404 llvm::SmallDenseSet<int64_t> res;
406 if (
auto d = dyn_cast<AffineDimExpr>(e)) {
407 if (iterators[d.getPosition()] == iter &&
409 return e.isFunctionOfDim(d.getPosition());
411 res.insert(d.getPosition());
418auto par = utils::IteratorType::parallel;
419auto red = utils::IteratorType::reduction;
426static FailureOr<SmallVector<utils::IteratorType>>
432 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
433 iterators[dim.getPosition()] = par;
448static FailureOr<ContractionDimensions>
451 llvm::SmallDenseSet<int64_t> a =
453 llvm::SmallDenseSet<int64_t>
b =
455 llvm::SmallDenseSet<int64_t> c =
459 llvm::SmallDenseSet<int64_t> ac = a;
460 llvm::set_intersect(ac, c);
461 llvm::set_subtract(ac,
b);
463 llvm::SmallDenseSet<int64_t> bc =
b;
464 llvm::set_intersect(bc, c);
465 llvm::set_subtract(bc, a);
467 llvm::SmallDenseSet<int64_t> batches = a;
468 llvm::set_intersect(batches,
b);
469 llvm::set_intersect(batches, c);
472 llvm::SmallDenseSet<int64_t> ra =
474 llvm::SmallDenseSet<int64_t> rb =
476 llvm::set_intersect(ra, rb);
484 llvm::sort(dimensions.
batch);
485 llvm::sort(dimensions.
m);
486 llvm::sort(dimensions.
n);
487 llvm::sort(dimensions.
k);
491FailureOr<ContractionDimensions>
493 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
496 linalgOp.getIteratorTypesArray());
499FailureOr<ContractionDimensions>
501 if (indexingMaps.size() != 3)
504 if (failed(iterators))
523 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
526 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
528 auto mapRange = linalgOp.getIndexingMapsArray();
529 if (linalgOp.getNumReductionLoops() == 0)
531 if (llvm::any_of(mapRange,
537 arith::MulFOp, arith::AddFOp,
538 arith::MulIOp, arith::AddIOp,
539 complex::MulOp, complex::AddOp,
540 arith::AndIOp, arith::OrIOp>(
541 *linalgOp.getBlock())) {
548 assert(succeeded(res) &&
"unexpected failure to infer contraction dims");
558 return "expected a LinalgOp";
560 return "expected op with 2 inputs and 1 output";
562 return "expected at least 1 reduction";
564 return "expected indexing maps to be projected permutations";
566 return "expected add/mul op in the body";
570 llvm_unreachable(
"unhandled MatchContractionResult case");
577 return isa<ContractionOpInterface>(op) ||
610 return isa<T>(
lhs) ? cast<T>(
lhs) : (isa<T>(
rhs) ? cast<T>(
rhs) :
nullptr);
622struct ConvAccessExprWalker
625 llvm::SmallDenseSet<int64_t> convolvedDims;
627 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
629 llvm::SmallDenseSet<int64_t> unConvolvedDims;
631 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
635 void clearMultiUseDims(AffineMap map) {
636 for (
int dimPos = 0, e = map.
getNumDims(); dimPos < e; ++dimPos) {
637 if (llvm::count_if(map.
getResults(), [dimPos](AffineExpr e) {
638 return e.isFunctionOfDim(dimPos);
640 convolvedDims.erase(dimPos);
641 unConvolvedDims.erase(dimPos);
644 auto it = convolvedDimMapping.find(dimPos);
645 if (it != convolvedDimMapping.end()) {
646 int64_t pairedDim = it->second;
647 convolvedDims.erase(pairedDim);
648 unConvolvedDims.erase(pairedDim);
649 strideAndDilationMapping.erase(pairedDim);
650 convolvedDimMapping.erase(dimPos);
651 convolvedDimMapping.erase(pairedDim);
657 LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
659 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
662 unConvolvedDims.insert(position);
666 LogicalResult visitSymbolExpr(AffineSymbolExpr expr) {
return failure(); }
668 LogicalResult visitConstantExpr(AffineConstantExpr expr) {
return failure(); }
670 LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
672 if (binaryExpr.
getKind() != AffineExprKind::Add)
674 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getLHS());
675 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getRHS());
678 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
679 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
683 FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
684 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
686 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
689 strideAndDilationMapping[dim] =
691 convolvedDims.insert(dim);
694 if (
auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
695 if (symbolMulExpr.getKind() != AffineExprKind::Mul)
697 auto lhsExpr = symbolMulExpr.getLHS();
698 auto rhsExpr = symbolMulExpr.getRHS();
707 if (!mulExpr || !dimExpr)
710 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
712 strideAndDilationMapping[dim] = mulExpr;
713 convolvedDims.insert(dim);
723 "expected map to have projected permutations");
724 llvm::SmallDenseSet<int64_t> preservedDims;
726 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
727 return preservedDims;
733 for (
auto e : exprs) {
734 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
735 assert(constantExpr &&
"Found non-constant stride/dilation");
736 vals.push_back(constantExpr.getValue());
754static FailureOr<ConvolutionDimensions>
756 ConvAccessExprWalker &inputExprWalker,
757 bool allowEmptyConvolvedDims) {
759 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
761 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
763 filterMap, linalgOp.getIteratorTypesArray(), par);
765 outputMap, linalgOp.getIteratorTypesArray(), par);
768 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
769 llvm::set_intersect(batch, outputDims);
770 llvm::set_subtract(batch, filterDims);
773 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
774 llvm::set_intersect(oi, outputDims);
777 llvm::SmallDenseSet<int64_t> oc = filterDims;
778 llvm::set_intersect(oc, outputDims);
779 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
782 llvm::SmallDenseSet<int64_t> depth = filterDims;
783 llvm::set_intersect(depth, outputDims);
784 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
786 llvm::SmallDenseSet<int64_t> filterReducedDims =
788 linalgOp.getIteratorTypesArray(), red);
791 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
792 llvm::set_intersect(fl, filterReducedDims);
795 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
796 llvm::set_intersect(ic, filterReducedDims);
798 if (oi.empty() && !allowEmptyConvolvedDims)
812 llvm::sort(dimensions.
batch);
816 llvm::sort(dimensions.
depth);
822 dimensions.
filterLoop.push_back(inputExprWalker.convolvedDimMapping[oiDim]);
826 if (!nativeStrides) {
829 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
832 dimensions.
strides = llvm::to_vector<2>(nativeStrides.getValues<
int64_t>());
834 auto nativeDilations =
836 if (!nativeDilations) {
839 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
843 llvm::to_vector<2>(nativeDilations.getValues<
int64_t>());
876FailureOr<ConvolutionDimensions>
878 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
881 auto indexingMaps = linalgOp.getIndexingMapsArray();
884 ConvAccessExprWalker inputExprWalker;
885 for (
AffineExpr expr : indexingMaps[0].getResults())
886 (
void)inputExprWalker.visit(expr);
887 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
910 bool allowEmptyConvolvedDims) {
911 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
914 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
917 auto indexingMaps = linalgOp.getIndexingMapsArray();
920 ConvAccessExprWalker inputExprWalker;
921 if (llvm::any_of(indexingMaps[0].getResults(),
923 return failed(inputExprWalker.visit(expr));
929 if (!indexingMaps[1].isProjectedPermutation() ||
930 !indexingMaps.back().isProjectedPermutation())
933 auto iteratorTypes = linalgOp.getIteratorTypesArray();
935 llvm::SmallDenseSet<int64_t> outputDims =
937 llvm::SmallDenseSet<int64_t> filterDims =
getPreservedDims(indexingMaps[1]);
951 llvm::SmallDenseSet<int64_t> allLoopDims;
952 for (
auto outputExpr : indexingMaps.back().getResults()) {
953 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
954 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
955 !filterDims.count(outputDim)) {
957 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
959 allLoopDims.insert(outputDim);
962 if (inputExprWalker.convolvedDims.count(outputDim) &&
963 !filterDims.count(outputDim)) {
965 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
967 allLoopDims.insert(outputDim);
970 if (!inputExprWalker.convolvedDims.count(outputDim) &&
971 !inputExprWalker.unConvolvedDims.count(outputDim) &&
972 filterDims.count(outputDim)) {
974 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
976 allLoopDims.insert(outputDim);
979 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
980 filterDims.count(outputDim)) {
982 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
984 allLoopDims.insert(outputDim);
989 for (
auto filterExpr : indexingMaps[1].getResults()) {
990 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
991 if (outputDims.count(filterDim) &&
992 !inputExprWalker.unConvolvedDims.count(filterDim) &&
993 !inputExprWalker.convolvedDims.count(filterDim)) {
997 if (inputExprWalker.convolvedDims.count(filterDim) &&
998 !outputDims.count(filterDim)) {
1000 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1002 if (allLoopDims.count(filterDim))
1004 allLoopDims.insert(filterDim);
1007 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1008 !outputDims.count(filterDim)) {
1010 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1012 if (allLoopDims.count(filterDim))
1014 allLoopDims.insert(filterDim);
1017 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1018 outputDims.count(filterDim)) {
1025 if (allLoopDims.size() != linalgOp.getNumLoops())
1028 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
1033 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
1034 assert(succeeded(res) &&
"unexpected failure to infer convolution dims");
1045 return "expected a LinalgOp";
1047 return "expected op with 2 inputs and 1 output";
1049 return "unexpected input index map for convolutions";
1051 return "expected output/filter indexing maps to be projected permutations";
1053 return "unexpected loop dimension for convolution op";
1055 return "expected all iterators used to access outputs to be parallel";
1057 return "expected all iterators not used to access outputs to be reduction";
1059 return "expected convolved dim to be non-empty";
1063 llvm_unreachable(
"unhandled MatchConvolutionResult case");
1067 bool allowEmptyConvolvedDims) {
1069 linalgOp.getOperation(),
nullptr, allowEmptyConvolvedDims) ==
1085enum class MatchFillResult {
1095 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1097 return MatchFillResult::NotLinalgOp;
1098 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1099 return MatchFillResult::WrongNumOperands;
1101 OpOperand *value = linalgOp.getDpsInputOperand(0);
1102 if (!linalgOp.isScalar(value))
1103 return MatchFillResult::NotScalarInput;
1106 OpOperand *output = linalgOp.getDpsInitOperand(0);
1109 if (scalarType != outputElementType)
1110 return MatchFillResult::TypeMismatch;
1112 return MatchFillResult::Success;
1117 if (res == MatchFillResult::NotLinalgOp)
1118 return op->
emitError(
"expected a LinalgOp");
1119 if (res == MatchFillResult::WrongNumOperands)
1120 return op->
emitError(
"expected op with 1 input and 1 output");
1121 if (res == MatchFillResult::NotScalarInput)
1122 return op->
emitError(
"expected op with scalar input");
1123 if (res == MatchFillResult::TypeMismatch) {
1124 auto linalgOp = cast<linalg::LinalgOp>(op);
1125 Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
1126 Type outputElementType =
1128 return op->
emitOpError(
"expected fill value type (")
1129 << scalarType <<
") to match output element type ("
1130 << outputElementType <<
")";
1143 for (
OpOperand &opOperand : getOperation()->getOpOperands()) {
1144 for (
int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1152 assert(!hasDynamicShape() &&
"expected operands to have static shapes");
1153 for (
OpOperand &opOperand : getOperation()->getOpOperands())
1154 llvm::append_range(res,
getShape(&opOperand));
1159 AffineMap map = getLoopsToShapesMap();
1161 auto viewSizes = createFlatListOfOperandDims(
b, loc);
1162 SmallVector<Range, 4> res(numDims);
1163 for (
unsigned idx = 0; idx < numRes; ++idx) {
1165 if (
auto d = dyn_cast<AffineDimExpr>(
result)) {
1166 if (res[d.getPosition()].offset)
1168 res[d.getPosition()] =
1169 Range{
b.getIndexAttr(0), viewSizes[idx],
b.getIndexAttr(1)};
1180 : positions(std::move(positions)) {}
1195 llvm::SmallBitVector positions;
1198static std::pair<int64_t, int64_t>
1202 for (
OpOperand *input : op.getDpsInputOperands())
1203 inputRankSum += op.getRank(input);
1204 for (
OpOperand &output : op.getDpsInitsMutable())
1205 outputRankSum += op.getRank(&output);
1206 return {inputRankSum, inputRankSum + outputRankSum};
1221 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1229 AffineMap loopToResultsShapeMap = loopsToShapesMap.
getSliceMap(
1230 resultShapesSubMapPos.first,
1231 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1232 AffineMap resultShapesFromInputShapesMap =
1233 loopToResultsShapeMap.
compose(getShapesToLoopsMap());
1237 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.
getNumDims());
1238 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1239 HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
1240 Location loc = getOperation()->getLoc();
1241 IRRewriter rewriter(
b);
1242 SmallVector<OpFoldResult> allResultDimValues =
1243 affine::makeComposedFoldedMultiResultAffineApply(
1244 rewriter, loc, resultShapesFromInputShapesMap,
1245 createFlatListOfOperandDims(
b, loc));
1247 ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.
getResults();
1248 for (OpOperand &opOperand : getDpsInitsMutable()) {
1249 SmallVector<OpFoldResult> shapes;
1250 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1251 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1252 if (!shapedType.isDynamicDim(dim)) {
1254 shapes.push_back(
b.getIndexAttr(shapedType.getDimSize(dim)));
1257 OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
1259 : allResultDimValues[pos];
1264 reifiedReturnShapes.emplace_back(std::move(shapes));
1273 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1274 if (!dpsIface.isDpsInput(opOperand))
1275 return operandNumber;
1276 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1277 assert(!dpsIface.isDpsInit(opOperand));
1280 return cast<DestinationStyleOpInterface>(*this->getOperation())
1281 .getNumDpsInputs() +
1282 operandNumber - start;
1286 LinalgOp linalgOp = cast<LinalgOp>(op);
1288 if (!linalgOp.hasPureTensorSemantics() &&
1290 return op->
emitOpError(
"expected to have pure tensor or buffer semantics");
1294 if (linalgOp.hasDynamicIndexingMaps())
1295 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1299 if (failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
1304 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1305 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1307 unsigned numLoops = linalgOp.getNumLoops();
1311 <<
" dim(s) to match the number of loops";
1314 linalgOp.getReductionDims(redDims);
1316 if (!linalgOp.getShapesToLoopsMap())
1317 return op->
emitOpError(
"expected the shape-to-loops map to be non-null");
1320 if (linalgOp->getNumRegions() != 1 || !linalgOp->getRegion(0).hasOneBlock())
1321 return op->
emitOpError(
"expects to have 1 region with 1 block");
1329 Block &block = linalgOp->getRegion(0).front();
1331 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.
getNumArguments())
1332 return op->
emitOpError(
"expected as many non-induction variable region "
1333 "arguments as the number of input/output operands");
1335 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1337 if (isa<MemRefType, RankedTensorType>(elementType))
1340 if (elementType != argType)
1341 return op->
emitOpError(
"expected type of bb argument #")
1343 <<
" to match element or self type of the corresponding operand ("
1344 << elementType <<
")";
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side-effect-free operations, go up the use-def chain until the first...
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...
static MatchFillResult isFillInterfaceImpl(Operation *op)
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static std::optional< Value > isaExternalFillOp(GenericOp op)
Detects if a linalg.generic operation represents an external scalar input.
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static std::optional< Value > isaInlinedFillOp(GenericOp op)
Detects if a linalg.generic operation represents a fill with an inlined constant.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Affine binary operation expression.
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
bool visit(AffineExpr expr)
See documentation for AffineExprVisitorBase.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
@ NotProjectedPermutations
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that checks whether the given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
@ NotProjectedPermutations
@ NonOutputDimNotReduction
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a broadcast operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::function_ref< Fn > function_ref
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Positions of the loops of a Linalg op that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > m
SmallVector< unsigned, 2 > n
SmallVector< unsigned, 2 > k
Positions of the loops of a Linalg op that correspond to different kinds of a convolution dimension.
SmallVector< unsigned, 2 > depth
SmallVector< unsigned, 2 > outputImage
SmallVector< unsigned, 2 > outputChannel
SmallVector< int64_t, 2 > dilations
SmallVector< int64_t, 2 > strides
SmallVector< unsigned, 2 > inputChannel
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > filterLoop