22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/SetOperations.h"
24#include "llvm/ADT/SmallBitVector.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/Support/Casting.h"
27#include "llvm/Support/raw_ostream.h"
34#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
43 for (
auto &opOperand : linalgOp->getOpOperands()) {
44 if (llvm::is_contained(droppedOperands, &opOperand))
46 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
48 if (indexingMaps.empty()) {
51 return linalgOp.getNumLoops() == 0;
54 indexingMaps, linalgOp.getContext())) !=
AffineMap();
63 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
66 auto mapRange = op.getIndexingMapsArray();
67 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
68 !mapRange.back().isIdentity()) {
72 Block *body = op.getBlock();
75 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
76 if (!yieldOp || yieldOp.getNumOperands() != 1)
78 return yieldOp->getOperand(0) == body->
getArgument(0);
88 if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
89 op.getNumDpsInputs() != 0)
93 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
96 Block *body = op.getBody();
100 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
101 if (!yieldOp || yieldOp.getNumOperands() != 1)
104 Value yieldOperand = yieldOp->getOperand(0);
116 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
117 !op.isSingleYieldOp())
121 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
122 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
125 OpOperand *value = op.getDpsInputOperand(0);
126 if (!op.isScalar(value))
140std::optional<SmallVector<int64_t>>
143 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
144 !op.isSingleYieldOp())
147 auto srcTy = op.getDpsInputOperand(0)->get().getType();
148 auto dstTy = op.getDpsInitOperand(0)->get().getType();
149 if (!isa<MemRefType, RankedTensorType>(srcTy) ||
150 !isa<MemRefType, RankedTensorType>(dstTy))
156 auto dstMap = op.getIndexingMapsArray()[1];
157 if (!dstMap.isIdentity())
161 auto srcMap = op.getIndexingMapsArray()[0];
163 if (srcMap.getResults().size() >= dstMap.getResults().size())
167 for (
unsigned i = 0; i < srcMap.getNumResults(); ++i) {
168 auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
171 int64_t pos = expr.getPosition();
172 if (i > 0 && pos <= position[i - 1])
174 position.push_back(expr.getPosition());
178 auto numDims = srcMap.getNumDims();
180 for (
auto dim : llvm::seq<int64_t>(0, numDims)) {
181 if (!llvm::is_contained(position, dim))
182 broadcastedDims.push_back(dim);
184 return broadcastedDims;
190std::optional<SmallVector<int64_t>>
195 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
196 !op.isSingleYieldOp())
199 auto mapRange = op.getIndexingMapsArray();
200 if (mapRange.size() != 2)
203 auto mapOfInput = mapRange.front();
204 auto mapOfResult = mapRange.back();
208 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
212 for (
unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
213 auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
214 permutation[expr.getPosition()] = i;
225 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
229 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
230 !llvm::all_of(op.getIndexingMapsArray(),
231 [](
AffineMap map) { return map.isIdentity(); }))
235 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
242 Block *body = op.getBody();
250 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
251 return !(!yieldOp || yieldOp.getNumOperands() != 1 ||
252 yieldOp->getOperand(0).getDefiningOp() != oper);
261 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
271 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
272 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
273 return !(!op.payloadUsesValueFromOperand(inputOpOperand0) ||
274 !op.payloadUsesValueFromOperand(inputOpOperand1));
288 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
289 if (!iface || !iface.hasNoEffect())
299 llvm::raw_ostream &errs) {
301 errs <<
"no terminator in the block";
306 errs <<
"expected block with 3 arguments";
312 errs <<
"expected terminator with 1 operand";
320 errs <<
"expected reduction op to be binary";
329 errs <<
"expected reduction to take block argument #2 as one of the "
330 "operands (modulo unary casts)";
335 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
339 errs <<
"expected elementwise op to be binary";
343 if (!isaPair(elementwiseOp, reductionOp)) {
344 errs <<
"expected reduction/elementwise op kind not satisfied";
357 errs <<
"expected elementwise op to apply to block arguments (modulo unary "
364template <
typename AddOpTy,
typename MulOpTy,
typename... Args>
366 static_assert(
sizeof...(Args) % 2 == 0,
367 "expected an even number of template arguments");
368 if (isa<AddOpTy>(
add) && isa<MulOpTy>(
mul))
371 if constexpr (
sizeof...(Args) > 0)
379template <
typename... Args>
391static llvm::SmallDenseSet<int64_t>
394 utils::IteratorType iter) {
395 assert(iterators.size() == indexingMap.
getNumDims());
396 llvm::SmallDenseSet<int64_t> res;
398 if (
auto d = dyn_cast<AffineDimExpr>(e)) {
399 if (iterators[d.getPosition()] == iter &&
401 return e.isFunctionOfDim(d.getPosition());
403 res.insert(d.getPosition());
410auto par = utils::IteratorType::parallel;
411auto red = utils::IteratorType::reduction;
418static FailureOr<SmallVector<utils::IteratorType>>
424 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
425 iterators[dim.getPosition()] = par;
440static FailureOr<ContractionDimensions>
443 llvm::SmallDenseSet<int64_t> a =
445 llvm::SmallDenseSet<int64_t>
b =
447 llvm::SmallDenseSet<int64_t> c =
451 llvm::SmallDenseSet<int64_t> ac = a;
452 llvm::set_intersect(ac, c);
453 llvm::set_subtract(ac,
b);
455 llvm::SmallDenseSet<int64_t> bc =
b;
456 llvm::set_intersect(bc, c);
457 llvm::set_subtract(bc, a);
459 llvm::SmallDenseSet<int64_t> batches = a;
460 llvm::set_intersect(batches,
b);
461 llvm::set_intersect(batches, c);
464 llvm::SmallDenseSet<int64_t> ra =
466 llvm::SmallDenseSet<int64_t> rb =
468 llvm::set_intersect(ra, rb);
476 llvm::sort(dimensions.
batch);
477 llvm::sort(dimensions.
m);
478 llvm::sort(dimensions.
n);
479 llvm::sort(dimensions.
k);
483FailureOr<ContractionDimensions>
485 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
488 linalgOp.getIteratorTypesArray());
491FailureOr<ContractionDimensions>
493 if (indexingMaps.size() != 3)
496 if (failed(iterators))
515 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
518 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
520 auto mapRange = linalgOp.getIndexingMapsArray();
521 if (linalgOp.getNumReductionLoops() == 0)
523 if (llvm::any_of(mapRange,
529 arith::MulFOp, arith::AddFOp,
530 arith::MulIOp, arith::AddIOp,
531 complex::MulOp, complex::AddOp,
532 arith::AndIOp, arith::OrIOp>(
533 *linalgOp.getBlock())) {
540 assert(succeeded(res) &&
"unexpected failure to infer contraction dims");
550 return "expected a LinalgOp";
552 return "expected op with 2 inputs and 1 output";
554 return "expected at least 1 reduction";
556 return "expected indexing maps to be projected permutations";
558 return "expected add/mul op in the body";
562 llvm_unreachable(
"unhandled MatchContractionResult case");
569 return isa<ContractionOpInterface>(op) ||
602 return isa<T>(
lhs) ? cast<T>(
lhs) : (isa<T>(
rhs) ? cast<T>(
rhs) :
nullptr);
614struct ConvAccessExprWalker
617 llvm::SmallDenseSet<int64_t> convolvedDims;
619 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
621 llvm::SmallDenseSet<int64_t> unConvolvedDims;
623 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
627 void clearMultiUseDims(AffineMap map) {
628 for (
int dimPos = 0, e = map.
getNumDims(); dimPos < e; ++dimPos) {
629 if (llvm::count_if(map.
getResults(), [dimPos](AffineExpr e) {
630 return e.isFunctionOfDim(dimPos);
632 convolvedDims.erase(dimPos);
633 unConvolvedDims.erase(dimPos);
636 auto it = convolvedDimMapping.find(dimPos);
637 if (it != convolvedDimMapping.end()) {
638 int64_t pairedDim = it->second;
639 convolvedDims.erase(pairedDim);
640 unConvolvedDims.erase(pairedDim);
641 strideAndDilationMapping.erase(pairedDim);
642 convolvedDimMapping.erase(dimPos);
643 convolvedDimMapping.erase(pairedDim);
649 LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
651 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
654 unConvolvedDims.insert(position);
658 LogicalResult visitSymbolExpr(AffineSymbolExpr expr) {
return failure(); }
660 LogicalResult visitConstantExpr(AffineConstantExpr expr) {
return failure(); }
662 LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
664 if (binaryExpr.
getKind() != AffineExprKind::Add)
666 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getLHS());
667 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getRHS());
670 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
671 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
675 FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
676 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
678 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
681 strideAndDilationMapping[dim] =
683 convolvedDims.insert(dim);
686 if (
auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
687 if (symbolMulExpr.getKind() != AffineExprKind::Mul)
689 auto lhsExpr = symbolMulExpr.getLHS();
690 auto rhsExpr = symbolMulExpr.getRHS();
699 if (!mulExpr || !dimExpr)
702 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
704 strideAndDilationMapping[dim] = mulExpr;
705 convolvedDims.insert(dim);
715 "expected map to have projected permutations");
716 llvm::SmallDenseSet<int64_t> preservedDims;
718 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
719 return preservedDims;
725 for (
auto e : exprs) {
726 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
727 assert(constantExpr &&
"Found non-constant stride/dilation");
728 vals.push_back(constantExpr.getValue());
740static FailureOr<ConvolutionDimensions>
742 ConvAccessExprWalker &inputExprWalker,
743 bool allowEmptyConvolvedDims) {
745 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
747 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
749 filterMap, linalgOp.getIteratorTypesArray(), par);
751 outputMap, linalgOp.getIteratorTypesArray(), par);
754 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
755 llvm::set_intersect(batch, outputDims);
756 llvm::set_subtract(batch, filterDims);
759 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
760 llvm::set_intersect(oi, outputDims);
763 llvm::SmallDenseSet<int64_t> oc = filterDims;
764 llvm::set_intersect(oc, outputDims);
765 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
768 llvm::SmallDenseSet<int64_t> depth = filterDims;
769 llvm::set_intersect(depth, outputDims);
770 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
772 llvm::SmallDenseSet<int64_t> filterReducedDims =
774 linalgOp.getIteratorTypesArray(), red);
777 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
778 llvm::set_intersect(fl, filterReducedDims);
781 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
782 llvm::set_intersect(ic, filterReducedDims);
784 if (oi.empty() && !allowEmptyConvolvedDims)
797 llvm::sort(dimensions.
batch);
802 llvm::sort(dimensions.
depth);
806 if (!nativeStrides) {
809 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
812 dimensions.
strides = llvm::to_vector<2>(nativeStrides.getValues<
int64_t>());
814 auto nativeDilations =
816 if (!nativeDilations) {
819 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
823 llvm::to_vector<2>(nativeDilations.getValues<
int64_t>());
852FailureOr<ConvolutionDimensions>
854 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
857 auto indexingMaps = linalgOp.getIndexingMapsArray();
860 ConvAccessExprWalker inputExprWalker;
861 for (
AffineExpr expr : indexingMaps[0].getResults())
862 (
void)inputExprWalker.visit(expr);
863 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
886 bool allowEmptyConvolvedDims) {
887 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
890 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
893 auto indexingMaps = linalgOp.getIndexingMapsArray();
896 ConvAccessExprWalker inputExprWalker;
897 if (llvm::any_of(indexingMaps[0].getResults(),
899 return failed(inputExprWalker.visit(expr));
905 if (!indexingMaps[1].isProjectedPermutation() ||
906 !indexingMaps.back().isProjectedPermutation())
909 auto iteratorTypes = linalgOp.getIteratorTypesArray();
911 llvm::SmallDenseSet<int64_t> outputDims =
913 llvm::SmallDenseSet<int64_t> filterDims =
getPreservedDims(indexingMaps[1]);
927 llvm::SmallDenseSet<int64_t> allLoopDims;
928 for (
auto outputExpr : indexingMaps.back().getResults()) {
929 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
930 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
931 !filterDims.count(outputDim)) {
933 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
935 allLoopDims.insert(outputDim);
938 if (inputExprWalker.convolvedDims.count(outputDim) &&
939 !filterDims.count(outputDim)) {
941 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
943 allLoopDims.insert(outputDim);
946 if (!inputExprWalker.convolvedDims.count(outputDim) &&
947 !inputExprWalker.unConvolvedDims.count(outputDim) &&
948 filterDims.count(outputDim)) {
950 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
952 allLoopDims.insert(outputDim);
955 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
956 filterDims.count(outputDim)) {
958 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
960 allLoopDims.insert(outputDim);
965 for (
auto filterExpr : indexingMaps[1].getResults()) {
966 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
967 if (outputDims.count(filterDim) &&
968 !inputExprWalker.unConvolvedDims.count(filterDim) &&
969 !inputExprWalker.convolvedDims.count(filterDim)) {
973 if (inputExprWalker.convolvedDims.count(filterDim) &&
974 !outputDims.count(filterDim)) {
976 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
978 if (allLoopDims.count(filterDim))
980 allLoopDims.insert(filterDim);
983 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
984 !outputDims.count(filterDim)) {
986 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
988 if (allLoopDims.count(filterDim))
990 allLoopDims.insert(filterDim);
993 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
994 outputDims.count(filterDim)) {
1001 if (allLoopDims.size() != linalgOp.getNumLoops())
1004 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
1009 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
1010 assert(succeeded(res) &&
"unexpected failure to infer convolution dims");
1021 return "expected a LinalgOp";
1023 return "expected op with 2 inputs and 1 output";
1025 return "unexpected input index map for convolutions";
1027 return "expected output/filter indexing maps to be projected permutations";
1029 return "unexpected loop dimension for convolution op";
1031 return "expected all iterators used to access outputs to be parallel";
1033 return "expected all iterators not used to access outputs to be reduction";
1035 return "expected convolved dim to be non-empty";
1039 llvm_unreachable(
"unhandled MatchConvolutionResult case");
1043 bool allowEmptyConvolvedDims) {
1045 linalgOp.getOperation(),
nullptr, allowEmptyConvolvedDims) ==
1061enum class MatchFillResult {
1071 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1073 return MatchFillResult::NotLinalgOp;
1074 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1075 return MatchFillResult::WrongNumOperands;
1077 OpOperand *value = linalgOp.getDpsInputOperand(0);
1078 if (!linalgOp.isScalar(value))
1079 return MatchFillResult::NotScalarInput;
1082 OpOperand *output = linalgOp.getDpsInitOperand(0);
1085 if (scalarType != outputElementType)
1086 return MatchFillResult::TypeMismatch;
1088 return MatchFillResult::Success;
1093 if (res == MatchFillResult::NotLinalgOp)
1094 return op->
emitError(
"expected a LinalgOp");
1095 if (res == MatchFillResult::WrongNumOperands)
1096 return op->
emitError(
"expected op with 1 input and 1 output");
1097 if (res == MatchFillResult::NotScalarInput)
1098 return op->
emitError(
"expected op with scalar input");
1099 if (res == MatchFillResult::TypeMismatch) {
1100 auto linalgOp = cast<linalg::LinalgOp>(op);
1101 Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
1102 Type outputElementType =
1104 return op->
emitOpError(
"expected fill value type (")
1105 << scalarType <<
") to match output element type ("
1106 << outputElementType <<
")";
1119 for (
OpOperand &opOperand : getOperation()->getOpOperands()) {
1120 for (
int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1128 assert(!hasDynamicShape() &&
"expected operands to have static shapes");
1129 for (
OpOperand &opOperand : getOperation()->getOpOperands())
1130 llvm::append_range(res,
getShape(&opOperand));
1135 AffineMap map = getLoopsToShapesMap();
1137 auto viewSizes = createFlatListOfOperandDims(
b, loc);
1138 SmallVector<Range, 4> res(numDims);
1139 for (
unsigned idx = 0; idx < numRes; ++idx) {
1141 if (
auto d = dyn_cast<AffineDimExpr>(
result)) {
1142 if (res[d.getPosition()].offset)
1144 res[d.getPosition()] =
1145 Range{
b.getIndexAttr(0), viewSizes[idx],
b.getIndexAttr(1)};
1156 : positions(std::move(positions)) {}
1171 llvm::SmallBitVector positions;
1174static std::pair<int64_t, int64_t>
1178 for (
OpOperand *input : op.getDpsInputOperands())
1179 inputRankSum += op.getRank(input);
1180 for (
OpOperand &output : op.getDpsInitsMutable())
1181 outputRankSum += op.getRank(&output);
1182 return {inputRankSum, inputRankSum + outputRankSum};
1197 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1205 AffineMap loopToResultsShapeMap = loopsToShapesMap.
getSliceMap(
1206 resultShapesSubMapPos.first,
1207 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1208 AffineMap resultShapesFromInputShapesMap =
1209 loopToResultsShapeMap.
compose(getShapesToLoopsMap());
1213 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.
getNumDims());
1214 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1215 HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
1216 Location loc = getOperation()->getLoc();
1217 IRRewriter rewriter(
b);
1218 SmallVector<OpFoldResult> allResultDimValues =
1219 affine::makeComposedFoldedMultiResultAffineApply(
1220 rewriter, loc, resultShapesFromInputShapesMap,
1221 createFlatListOfOperandDims(
b, loc));
1223 ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.
getResults();
1224 for (OpOperand &opOperand : getDpsInitsMutable()) {
1225 SmallVector<OpFoldResult> shapes;
1226 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1227 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1228 if (!shapedType.isDynamicDim(dim)) {
1230 shapes.push_back(
b.getIndexAttr(shapedType.getDimSize(dim)));
1233 OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
1235 : allResultDimValues[pos];
1240 reifiedReturnShapes.emplace_back(std::move(shapes));
1249 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1250 if (!dpsIface.isDpsInput(opOperand))
1251 return operandNumber;
1252 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1253 assert(!dpsIface.isDpsInit(opOperand));
1256 return cast<DestinationStyleOpInterface>(*this->getOperation())
1257 .getNumDpsInputs() +
1258 operandNumber - start;
1262 LinalgOp linalgOp = cast<LinalgOp>(op);
1264 if (!linalgOp.hasPureTensorSemantics() &&
1266 return op->
emitOpError(
"expected to have pure tensor or buffer semantics");
1270 if (linalgOp.hasDynamicIndexingMaps())
1271 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1275 if (failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
1280 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1281 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1283 unsigned numLoops = linalgOp.getNumLoops();
1287 <<
" dim(s) to match the number of loops";
1290 linalgOp.getReductionDims(redDims);
1292 if (!linalgOp.getShapesToLoopsMap())
1293 return op->
emitOpError(
"expected the shape-to-loops map to be non-null");
1296 if (linalgOp->getNumRegions() != 1 || !linalgOp->getRegion(0).hasOneBlock())
1297 return op->
emitOpError(
"expects to have 1 region with 1 block");
1305 Block &block = linalgOp->getRegion(0).front();
1307 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.
getNumArguments())
1308 return op->
emitOpError(
"expected as many non-induction variable region "
1309 "arguments as the number of input/output operands");
1311 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1313 if (isa<MemRefType, RankedTensorType>(elementType))
1316 if (elementType != argType)
1317 return op->
emitOpError(
"expected type of bb argument #")
1319 <<
" to match element or self type of the corresponding operand ("
1320 << elementType <<
")";
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free, go up the use-def chain until the first...
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...
static MatchFillResult isFillInterfaceImpl(Operation *op)
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static std::optional< Value > isaExternalFillOp(GenericOp op)
Detects if a linalg.generic operation represents an external scalar input.
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static std::optional< Value > isaInlinedFillOp(GenericOp op)
Detects if a linalg.generic operation represents a fill with an inlined constant.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Affine binary operation expression.
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
bool visit(AffineExpr expr)
See documentation for AffineExprVisitorBase.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
@ NotProjectedPermutations
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that check if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
@ NotProjectedPermutations
@ NonOutputDimNotReduction
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.broadcast.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::function_ref< Fn > function_ref
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > m
SmallVector< unsigned, 2 > n
SmallVector< unsigned, 2 > k
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.
SmallVector< unsigned, 2 > depth
SmallVector< unsigned, 2 > outputImage
SmallVector< unsigned, 2 > outputChannel
SmallVector< int64_t, 2 > dilations
SmallVector< int64_t, 2 > strides
SmallVector< unsigned, 2 > inputChannel
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > filterLoop