22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/SetOperations.h"
24#include "llvm/ADT/SmallBitVector.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/Support/Casting.h"
27#include "llvm/Support/raw_ostream.h"
34#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
43 for (
auto &opOperand : linalgOp->getOpOperands()) {
44 if (llvm::is_contained(droppedOperands, &opOperand))
46 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
48 if (indexingMaps.empty()) {
51 return linalgOp.getNumLoops() == 0;
54 indexingMaps, linalgOp.getContext())) !=
AffineMap();
63 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
66 auto mapRange = op.getIndexingMapsArray();
67 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
68 !mapRange.back().isIdentity()) {
72 Block *body = op.getBlock();
75 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
76 if (!yieldOp || yieldOp.getNumOperands() != 1)
78 return yieldOp->getOperand(0) == body->
getArgument(0);
88 if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
89 op.getNumDpsInputs() != 0)
93 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
96 Block *body = op.getBody();
100 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
101 if (!yieldOp || yieldOp.getNumOperands() != 1)
104 Value yieldOperand = yieldOp->getOperand(0);
116 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
117 !op.isSingleYieldOp())
121 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
122 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
125 OpOperand *value = op.getDpsInputOperand(0);
126 if (!op.isScalar(value))
140std::optional<SmallVector<int64_t>>
142 if (
auto broadcastOp = dyn_cast<BroadcastOp>(linalgOp.getOperation()))
144 broadcastOp.getDimensions().end());
146 auto op = dyn_cast<GenericOp>(linalgOp.getOperation());
151 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
152 !op.isSingleYieldOp())
155 auto srcTy = op.getDpsInputOperand(0)->get().getType();
156 auto dstTy = op.getDpsInitOperand(0)->get().getType();
157 if (!isa<MemRefType, RankedTensorType>(srcTy) ||
158 !isa<MemRefType, RankedTensorType>(dstTy))
164 auto dstMap = op.getIndexingMapsArray()[1];
165 if (!dstMap.isIdentity())
169 auto srcMap = op.getIndexingMapsArray()[0];
171 if (srcMap.getResults().size() >= dstMap.getResults().size())
175 for (
unsigned i = 0; i < srcMap.getNumResults(); ++i) {
176 auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
179 int64_t pos = expr.getPosition();
180 if (i > 0 && pos <= position[i - 1])
182 position.push_back(expr.getPosition());
186 auto numDims = srcMap.getNumDims();
188 for (
auto dim : llvm::seq<int64_t>(0, numDims)) {
189 if (!llvm::is_contained(position, dim))
190 broadcastedDims.push_back(dim);
192 return broadcastedDims;
198std::optional<SmallVector<int64_t>>
203 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
204 !op.isSingleYieldOp())
207 auto mapRange = op.getIndexingMapsArray();
208 if (mapRange.size() != 2)
211 auto mapOfInput = mapRange.front();
212 auto mapOfResult = mapRange.back();
216 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
220 for (
unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
221 auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
222 permutation[expr.getPosition()] = i;
232 bool allowNonIdentityMaps) {
234 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
239 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
240 (!allowNonIdentityMaps &&
241 !llvm::all_of(op.getIndexingMapsArray(),
242 [](
AffineMap map) { return map.isIdentity(); })))
246 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
253 Block *body = op.getBody();
265 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
266 return !(!yieldOp || yieldOp.getNumOperands() != 1 ||
267 yieldOp->getOperand(0).getDefiningOp() != oper);
271 bool allowNonIdentityMaps) {
277 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
283 bool allowNonIdentityMaps) {
289 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
290 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
291 return !(!op.payloadUsesValueFromOperand(inputOpOperand0) ||
292 !op.payloadUsesValueFromOperand(inputOpOperand1));
306 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
307 if (!iface || !iface.hasNoEffect())
317 llvm::raw_ostream &errs) {
319 errs <<
"no terminator in the block";
324 errs <<
"expected block with 3 arguments";
330 errs <<
"expected terminator with 1 operand";
338 errs <<
"expected reduction op to be binary";
347 errs <<
"expected reduction to take block argument #2 as one of the "
348 "operands (modulo unary casts)";
353 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
357 errs <<
"expected elementwise op to be binary";
361 if (!isaPair(elementwiseOp, reductionOp)) {
362 errs <<
"expected reduction/elementwise op kind not satisfied";
375 errs <<
"expected elementwise op to apply to block arguments (modulo unary "
382template <
typename AddOpTy,
typename MulOpTy,
typename... Args>
384 static_assert(
sizeof...(Args) % 2 == 0,
385 "expected an even number of template arguments");
386 if (isa<AddOpTy>(
add) && isa<MulOpTy>(
mul))
389 if constexpr (
sizeof...(Args) > 0)
397template <
typename... Args>
409static llvm::SmallDenseSet<int64_t>
412 utils::IteratorType iter) {
413 assert(iterators.size() == indexingMap.
getNumDims());
414 llvm::SmallDenseSet<int64_t> res;
416 if (
auto d = dyn_cast<AffineDimExpr>(e)) {
417 if (iterators[d.getPosition()] == iter &&
419 return e.isFunctionOfDim(d.getPosition());
421 res.insert(d.getPosition());
428auto par = utils::IteratorType::parallel;
429auto red = utils::IteratorType::reduction;
436static FailureOr<SmallVector<utils::IteratorType>>
442 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
443 iterators[dim.getPosition()] = par;
458static FailureOr<ContractionDimensions>
461 llvm::SmallDenseSet<int64_t> a =
463 llvm::SmallDenseSet<int64_t>
b =
465 llvm::SmallDenseSet<int64_t> c =
469 llvm::SmallDenseSet<int64_t> ac = a;
470 llvm::set_intersect(ac, c);
471 llvm::set_subtract(ac,
b);
473 llvm::SmallDenseSet<int64_t> bc =
b;
474 llvm::set_intersect(bc, c);
475 llvm::set_subtract(bc, a);
477 llvm::SmallDenseSet<int64_t> batches = a;
478 llvm::set_intersect(batches,
b);
479 llvm::set_intersect(batches, c);
482 llvm::SmallDenseSet<int64_t> ra =
484 llvm::SmallDenseSet<int64_t> rb =
486 llvm::set_intersect(ra, rb);
494 llvm::sort(dimensions.
batch);
495 llvm::sort(dimensions.
m);
496 llvm::sort(dimensions.
n);
497 llvm::sort(dimensions.
k);
501FailureOr<ContractionDimensions>
503 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
506 linalgOp.getIteratorTypesArray());
509FailureOr<ContractionDimensions>
511 if (indexingMaps.size() != 3)
514 if (failed(iterators))
533 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
536 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
538 auto mapRange = linalgOp.getIndexingMapsArray();
539 if (linalgOp.getNumReductionLoops() == 0)
541 if (llvm::any_of(mapRange,
547 arith::MulFOp, arith::AddFOp,
548 arith::MulIOp, arith::AddIOp,
549 complex::MulOp, complex::AddOp,
550 arith::AndIOp, arith::OrIOp>(
551 *linalgOp.getBlock())) {
558 assert(succeeded(res) &&
"unexpected failure to infer contraction dims");
568 return "expected a LinalgOp";
570 return "expected op with 2 inputs and 1 output";
572 return "expected at least 1 reduction";
574 return "expected indexing maps to be projected permutations";
576 return "expected add/mul op in the body";
580 llvm_unreachable(
"unhandled MatchContractionResult case");
587 return isa<ContractionOpInterface>(op) ||
620 return isa<T>(
lhs) ? cast<T>(
lhs) : (isa<T>(
rhs) ? cast<T>(
rhs) :
nullptr);
632struct ConvAccessExprWalker
635 llvm::SmallDenseSet<int64_t> convolvedDims;
637 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
639 llvm::SmallDenseSet<int64_t> unConvolvedDims;
641 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
645 void clearMultiUseDims(AffineMap map) {
646 for (
int dimPos = 0, e = map.
getNumDims(); dimPos < e; ++dimPos) {
647 if (llvm::count_if(map.
getResults(), [dimPos](AffineExpr e) {
648 return e.isFunctionOfDim(dimPos);
650 convolvedDims.erase(dimPos);
651 unConvolvedDims.erase(dimPos);
654 auto it = convolvedDimMapping.find(dimPos);
655 if (it != convolvedDimMapping.end()) {
656 int64_t pairedDim = it->second;
657 convolvedDims.erase(pairedDim);
658 unConvolvedDims.erase(pairedDim);
659 strideAndDilationMapping.erase(pairedDim);
660 convolvedDimMapping.erase(dimPos);
661 convolvedDimMapping.erase(pairedDim);
667 LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
669 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
672 unConvolvedDims.insert(position);
676 LogicalResult visitSymbolExpr(AffineSymbolExpr expr) {
return failure(); }
678 LogicalResult visitConstantExpr(AffineConstantExpr expr) {
return failure(); }
680 LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
682 if (binaryExpr.
getKind() != AffineExprKind::Add)
684 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getLHS());
685 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getRHS());
688 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
689 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
693 FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
694 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
696 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
699 strideAndDilationMapping[dim] =
701 convolvedDims.insert(dim);
704 if (
auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
705 if (symbolMulExpr.getKind() != AffineExprKind::Mul)
707 auto lhsExpr = symbolMulExpr.getLHS();
708 auto rhsExpr = symbolMulExpr.getRHS();
717 if (!mulExpr || !dimExpr)
720 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
722 strideAndDilationMapping[dim] = mulExpr;
723 convolvedDims.insert(dim);
733 "expected map to have projected permutations");
734 llvm::SmallDenseSet<int64_t> preservedDims;
736 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
737 return preservedDims;
743 for (
auto e : exprs) {
744 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
745 assert(constantExpr &&
"Found non-constant stride/dilation");
746 vals.push_back(constantExpr.getValue());
771 ConvAccessExprWalker &inputExprWalker,
bool allowEmptyConvolvedDims,
774 AffineMap outputMap = indexingMaps.back();
775 llvm::SmallDenseSet<int64_t> filterDims =
777 llvm::SmallDenseSet<int64_t> outputDims =
781 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
782 llvm::set_intersect(batch, outputDims);
783 llvm::set_subtract(batch, filterDims);
786 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
787 llvm::set_intersect(oi, outputDims);
790 llvm::SmallDenseSet<int64_t> oc = filterDims;
791 llvm::set_intersect(oc, outputDims);
792 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
795 llvm::SmallDenseSet<int64_t> depth = filterDims;
796 llvm::set_intersect(depth, outputDims);
797 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
799 llvm::SmallDenseSet<int64_t> filterReducedDims =
803 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
804 llvm::set_intersect(fl, filterReducedDims);
807 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
808 llvm::set_intersect(ic, filterReducedDims);
810 if (oi.empty() && !allowEmptyConvolvedDims)
824 llvm::sort(dimensions.
batch);
828 llvm::sort(dimensions.
depth);
834 dimensions.
filterLoop.push_back(inputExprWalker.convolvedDimMapping[oiDim]);
837 if (!nativeStrides) {
840 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
843 dimensions.
strides = llvm::to_vector<2>(nativeStrides.getValues<
int64_t>());
845 if (!nativeDilations) {
848 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
852 llvm::to_vector<2>(nativeDilations.getValues<
int64_t>());
885FailureOr<ConvolutionDimensions>
887 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
890 auto indexingMaps = linalgOp.getIndexingMapsArray();
893 ConvAccessExprWalker inputExprWalker;
894 for (
AffineExpr expr : indexingMaps[0].getResults())
895 (
void)inputExprWalker.visit(expr);
896 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
899 indexingMaps, linalgOp.getIteratorTypesArray(), inputExprWalker,
905FailureOr<ConvolutionDimensions>
907 if (indexingMaps.size() != 3)
911 FailureOr<SmallVector<utils::IteratorType>> iterators =
913 if (failed(iterators))
917 ConvAccessExprWalker inputExprWalker;
918 for (
AffineExpr expr : indexingMaps[0].getResults())
919 (
void)inputExprWalker.visit(expr);
920 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
946 bool allowEmptyConvolvedDims) {
947 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
950 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
953 auto indexingMaps = linalgOp.getIndexingMapsArray();
956 ConvAccessExprWalker inputExprWalker;
957 if (llvm::any_of(indexingMaps[0].getResults(),
959 return failed(inputExprWalker.visit(expr));
965 if (!indexingMaps[1].isProjectedPermutation() ||
966 !indexingMaps.back().isProjectedPermutation())
969 auto iteratorTypes = linalgOp.getIteratorTypesArray();
971 llvm::SmallDenseSet<int64_t> outputDims =
973 llvm::SmallDenseSet<int64_t> filterDims =
getPreservedDims(indexingMaps[1]);
987 llvm::SmallDenseSet<int64_t> allLoopDims;
988 for (
auto outputExpr : indexingMaps.back().getResults()) {
989 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
990 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
991 !filterDims.count(outputDim)) {
993 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
995 allLoopDims.insert(outputDim);
998 if (inputExprWalker.convolvedDims.count(outputDim) &&
999 !filterDims.count(outputDim)) {
1001 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
1003 allLoopDims.insert(outputDim);
1006 if (!inputExprWalker.convolvedDims.count(outputDim) &&
1007 !inputExprWalker.unConvolvedDims.count(outputDim) &&
1008 filterDims.count(outputDim)) {
1010 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
1012 allLoopDims.insert(outputDim);
1015 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
1016 filterDims.count(outputDim)) {
1018 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
1020 allLoopDims.insert(outputDim);
1025 for (
auto filterExpr : indexingMaps[1].getResults()) {
1026 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
1027 if (outputDims.count(filterDim) &&
1028 !inputExprWalker.unConvolvedDims.count(filterDim) &&
1029 !inputExprWalker.convolvedDims.count(filterDim)) {
1033 if (inputExprWalker.convolvedDims.count(filterDim) &&
1034 !outputDims.count(filterDim)) {
1036 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1038 if (allLoopDims.count(filterDim))
1040 allLoopDims.insert(filterDim);
1043 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1044 !outputDims.count(filterDim)) {
1046 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1048 if (allLoopDims.count(filterDim))
1050 allLoopDims.insert(filterDim);
1053 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1054 outputDims.count(filterDim)) {
1061 if (allLoopDims.size() != linalgOp.getNumLoops())
1064 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
1069 indexingMaps, iteratorTypes, inputExprWalker, allowEmptyConvolvedDims,
1072 assert(succeeded(res) &&
"unexpected failure to infer convolution dims");
1083 return "expected a LinalgOp";
1085 return "expected op with 2 inputs and 1 output";
1087 return "unexpected input index map for convolutions";
1089 return "expected output/filter indexing maps to be projected permutations";
1091 return "unexpected loop dimension for convolution op";
1093 return "expected all iterators used to access outputs to be parallel";
1095 return "expected all iterators not used to access outputs to be reduction";
1097 return "expected convolved dim to be non-empty";
1101 llvm_unreachable(
"unhandled MatchConvolutionResult case");
1105 bool allowEmptyConvolvedDims) {
1107 linalgOp.getOperation(),
nullptr, allowEmptyConvolvedDims) ==
1123enum class MatchFillResult {
1133 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1135 return MatchFillResult::NotLinalgOp;
1136 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1137 return MatchFillResult::WrongNumOperands;
1139 OpOperand *value = linalgOp.getDpsInputOperand(0);
1140 if (!linalgOp.isScalar(value))
1141 return MatchFillResult::NotScalarInput;
1144 OpOperand *output = linalgOp.getDpsInitOperand(0);
1147 if (scalarType != outputElementType)
1148 return MatchFillResult::TypeMismatch;
1150 return MatchFillResult::Success;
1155 if (res == MatchFillResult::NotLinalgOp)
1156 return op->
emitError(
"expected a LinalgOp");
1157 if (res == MatchFillResult::WrongNumOperands)
1158 return op->
emitError(
"expected op with 1 input and 1 output");
1159 if (res == MatchFillResult::NotScalarInput)
1160 return op->
emitError(
"expected op with scalar input");
1161 if (res == MatchFillResult::TypeMismatch) {
1162 auto linalgOp = cast<linalg::LinalgOp>(op);
1163 Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
1164 Type outputElementType =
1166 return op->
emitOpError(
"expected fill value type (")
1167 << scalarType <<
") to match output element type ("
1168 << outputElementType <<
")";
1181 for (
OpOperand &opOperand : getOperation()->getOpOperands()) {
1182 for (
int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1190 assert(!hasDynamicShape() &&
"expected operands to have static shapes");
1191 for (
OpOperand &opOperand : getOperation()->getOpOperands())
1192 llvm::append_range(res,
getShape(&opOperand));
1197 AffineMap map = getLoopsToShapesMap();
1199 auto viewSizes = createFlatListOfOperandDims(
b, loc);
1200 SmallVector<Range, 4> res(numDims);
1201 for (
unsigned idx = 0; idx < numRes; ++idx) {
1203 if (
auto d = dyn_cast<AffineDimExpr>(
result)) {
1204 if (res[d.getPosition()].offset)
1206 res[d.getPosition()] =
1207 Range{
b.getIndexAttr(0), viewSizes[idx],
b.getIndexAttr(1)};
1218 : positions(std::move(positions)) {}
1233 llvm::SmallBitVector positions;
1236static std::pair<int64_t, int64_t>
1240 for (
OpOperand *input : op.getDpsInputOperands())
1241 inputRankSum += op.getRank(input);
1242 for (
OpOperand &output : op.getDpsInitsMutable())
1243 outputRankSum += op.getRank(&output);
1244 return {inputRankSum, inputRankSum + outputRankSum};
1259 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1267 AffineMap loopToResultsShapeMap = loopsToShapesMap.
getSliceMap(
1268 resultShapesSubMapPos.first,
1269 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1270 AffineMap resultShapesFromInputShapesMap =
1271 loopToResultsShapeMap.
compose(getShapesToLoopsMap());
1275 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.
getNumDims());
1276 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1277 HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
1278 Location loc = getOperation()->getLoc();
1279 IRRewriter rewriter(
b);
1280 SmallVector<OpFoldResult> allResultDimValues =
1282 rewriter, loc, resultShapesFromInputShapesMap,
1283 createFlatListOfOperandDims(
b, loc));
1285 ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.
getResults();
1286 for (OpOperand &opOperand : getDpsInitsMutable()) {
1287 SmallVector<OpFoldResult> shapes;
1288 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1289 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1290 if (!shapedType.isDynamicDim(dim)) {
1292 shapes.push_back(
b.getIndexAttr(shapedType.getDimSize(dim)));
1295 OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
1297 : allResultDimValues[pos];
1302 reifiedReturnShapes.emplace_back(std::move(shapes));
1311 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1312 if (!dpsIface.isDpsInput(opOperand))
1313 return operandNumber;
1314 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1315 assert(!dpsIface.isDpsInit(opOperand));
1318 return cast<DestinationStyleOpInterface>(*this->getOperation())
1319 .getNumDpsInputs() +
1320 operandNumber - start;
1324 LinalgOp linalgOp = cast<LinalgOp>(op);
1326 if (!linalgOp.hasPureTensorSemantics() &&
1328 return op->
emitOpError(
"expected to have pure tensor or buffer semantics");
1332 if (linalgOp.hasDynamicIndexingMaps())
1333 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1337 if (failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
1342 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1343 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1345 unsigned numLoops = linalgOp.getNumLoops();
1349 <<
" dim(s) to match the number of loops";
1352 linalgOp.getReductionDims(redDims);
1354 if (!linalgOp.getShapesToLoopsMap())
1355 return op->
emitOpError(
"expected the shape-to-loops map to be non-null");
1358 if (linalgOp->getNumRegions() != 1 || !linalgOp->getRegion(0).hasOneBlock())
1359 return op->
emitOpError(
"expects to have 1 region with 1 block");
1367 Block &block = linalgOp->getRegion(0).front();
1369 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.
getNumArguments())
1370 return op->
emitOpError(
"expected as many non-induction variable region "
1371 "arguments as the number of input/output operands");
1373 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1375 if (isa<MemRefType, RankedTensorType>(elementType))
1378 if (elementType != argType)
1379 return op->
emitOpError(
"expected type of bb argument #")
1381 <<
" to match element or self type of the corresponding operand ("
1382 << elementType <<
")";
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free, go up the use-def chain until the first...
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims, DenseIntElementsAttr nativeStrides, DenseIntElementsAttr nativeDilations)
Classifies dimensions in the indexingMaps used by a convolution subcomputation, as captured by inputE...
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static MatchFillResult isFillInterfaceImpl(Operation *op)
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static std::optional< Value > isaExternalFillOp(GenericOp op)
Detects if a linalg.generic operation represents an external scalar input.
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity, bool allowNonIdentityMaps)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static std::optional< Value > isaInlinedFillOp(GenericOp op)
Detects if a linalg.generic operation represents a fill with an inlined constant.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Affine binary operation expression.
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
bool visit(AffineExpr expr)
See documentation for AffineExprVisitorBase.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
@ NotProjectedPermutations
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that check if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
@ NotProjectedPermutations
@ NonOutputDimNotReduction
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp, bool allowNonIdentityMaps=false)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a broadcast operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp, bool allowNonIdentityMaps=false)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op,...
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::function_ref< Fn > function_ref
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > m
SmallVector< unsigned, 2 > n
SmallVector< unsigned, 2 > k
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.
SmallVector< unsigned, 2 > depth
SmallVector< unsigned, 2 > outputImage
SmallVector< unsigned, 2 > outputChannel
SmallVector< int64_t, 2 > dilations
SmallVector< int64_t, 2 > strides
SmallVector< unsigned, 2 > inputChannel
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > filterLoop