22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/SetOperations.h"
24#include "llvm/ADT/SmallBitVector.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/Support/Casting.h"
27#include "llvm/Support/raw_ostream.h"
34#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
43 for (
auto &opOperand : linalgOp->getOpOperands()) {
44 if (llvm::is_contained(droppedOperands, &opOperand))
46 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
48 if (indexingMaps.empty()) {
51 return linalgOp.getNumLoops() == 0;
54 indexingMaps, linalgOp.getContext())) !=
AffineMap();
63 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
66 auto mapRange = op.getIndexingMapsArray();
67 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
68 !mapRange.back().isIdentity()) {
72 Block *body = op.getBlock();
75 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
76 if (!yieldOp || yieldOp.getNumOperands() != 1)
78 return yieldOp->getOperand(0) == body->
getArgument(0);
88 if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
89 op.getNumDpsInputs() != 0)
93 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
96 Block *body = op.getBody();
100 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
101 if (!yieldOp || yieldOp.getNumOperands() != 1)
104 Value yieldOperand = yieldOp->getOperand(0);
116 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
117 !op.isSingleYieldOp())
121 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
122 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
125 OpOperand *value = op.getDpsInputOperand(0);
126 if (!op.isScalar(value))
140std::optional<SmallVector<int64_t>>
142 if (
auto broadcastOp = dyn_cast<BroadcastOp>(linalgOp.getOperation()))
144 broadcastOp.getDimensions().end());
146 auto op = dyn_cast<GenericOp>(linalgOp.getOperation());
151 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
152 !op.isSingleYieldOp())
155 auto srcTy = op.getDpsInputOperand(0)->get().getType();
156 auto dstTy = op.getDpsInitOperand(0)->get().getType();
157 if (!isa<MemRefType, RankedTensorType>(srcTy) ||
158 !isa<MemRefType, RankedTensorType>(dstTy))
164 auto dstMap = op.getIndexingMapsArray()[1];
165 if (!dstMap.isIdentity())
169 auto srcMap = op.getIndexingMapsArray()[0];
171 if (srcMap.getResults().size() >= dstMap.getResults().size())
175 for (
unsigned i = 0; i < srcMap.getNumResults(); ++i) {
176 auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
179 int64_t pos = expr.getPosition();
180 if (i > 0 && pos <= position[i - 1])
182 position.push_back(expr.getPosition());
186 auto numDims = srcMap.getNumDims();
188 for (
auto dim : llvm::seq<int64_t>(0, numDims)) {
189 if (!llvm::is_contained(position, dim))
190 broadcastedDims.push_back(dim);
192 return broadcastedDims;
198std::optional<SmallVector<int64_t>>
203 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
204 !op.isSingleYieldOp())
207 auto mapRange = op.getIndexingMapsArray();
208 if (mapRange.size() != 2)
211 auto mapOfInput = mapRange.front();
212 auto mapOfResult = mapRange.back();
216 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
220 for (
unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
221 auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
222 permutation[expr.getPosition()] = i;
232 bool allowNonIdentityMaps) {
234 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
239 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
240 (!allowNonIdentityMaps &&
241 !llvm::all_of(op.getIndexingMapsArray(),
242 [](
AffineMap map) { return map.isIdentity(); })))
246 if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
253 Block *body = op.getBody();
265 auto yieldOp = dyn_cast<linalg::YieldOp>(body->
back());
266 return !(!yieldOp || yieldOp.getNumOperands() != 1 ||
267 yieldOp->getOperand(0).getDefiningOp() != oper);
271 bool allowNonIdentityMaps) {
277 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
288 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
289 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
290 return !(!op.payloadUsesValueFromOperand(inputOpOperand0) ||
291 !op.payloadUsesValueFromOperand(inputOpOperand1));
305 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
306 if (!iface || !iface.hasNoEffect())
316 llvm::raw_ostream &errs) {
318 errs <<
"no terminator in the block";
323 errs <<
"expected block with 3 arguments";
329 errs <<
"expected terminator with 1 operand";
337 errs <<
"expected reduction op to be binary";
346 errs <<
"expected reduction to take block argument #2 as one of the "
347 "operands (modulo unary casts)";
352 isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
356 errs <<
"expected elementwise op to be binary";
360 if (!isaPair(elementwiseOp, reductionOp)) {
361 errs <<
"expected reduction/elementwise op kind not satisfied";
374 errs <<
"expected elementwise op to apply to block arguments (modulo unary "
381template <
typename AddOpTy,
typename MulOpTy,
typename... Args>
383 static_assert(
sizeof...(Args) % 2 == 0,
384 "expected an even number of template arguments");
385 if (isa<AddOpTy>(
add) && isa<MulOpTy>(
mul))
388 if constexpr (
sizeof...(Args) > 0)
396template <
typename... Args>
408static llvm::SmallDenseSet<int64_t>
411 utils::IteratorType iter) {
412 assert(iterators.size() == indexingMap.
getNumDims());
413 llvm::SmallDenseSet<int64_t> res;
415 if (
auto d = dyn_cast<AffineDimExpr>(e)) {
416 if (iterators[d.getPosition()] == iter &&
418 return e.isFunctionOfDim(d.getPosition());
420 res.insert(d.getPosition());
427auto par = utils::IteratorType::parallel;
428auto red = utils::IteratorType::reduction;
435static FailureOr<SmallVector<utils::IteratorType>>
441 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
442 iterators[dim.getPosition()] = par;
457static FailureOr<ContractionDimensions>
460 llvm::SmallDenseSet<int64_t> a =
462 llvm::SmallDenseSet<int64_t>
b =
464 llvm::SmallDenseSet<int64_t> c =
468 llvm::SmallDenseSet<int64_t> ac = a;
469 llvm::set_intersect(ac, c);
470 llvm::set_subtract(ac,
b);
472 llvm::SmallDenseSet<int64_t> bc =
b;
473 llvm::set_intersect(bc, c);
474 llvm::set_subtract(bc, a);
476 llvm::SmallDenseSet<int64_t> batches = a;
477 llvm::set_intersect(batches,
b);
478 llvm::set_intersect(batches, c);
481 llvm::SmallDenseSet<int64_t> ra =
483 llvm::SmallDenseSet<int64_t> rb =
485 llvm::set_intersect(ra, rb);
493 llvm::sort(dimensions.
batch);
494 llvm::sort(dimensions.
m);
495 llvm::sort(dimensions.
n);
496 llvm::sort(dimensions.
k);
500FailureOr<ContractionDimensions>
502 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
505 linalgOp.getIteratorTypesArray());
508FailureOr<ContractionDimensions>
510 if (indexingMaps.size() != 3)
513 if (failed(iterators))
532 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
535 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
537 auto mapRange = linalgOp.getIndexingMapsArray();
538 if (linalgOp.getNumReductionLoops() == 0)
540 if (llvm::any_of(mapRange,
546 arith::MulFOp, arith::AddFOp,
547 arith::MulIOp, arith::AddIOp,
548 complex::MulOp, complex::AddOp,
549 arith::AndIOp, arith::OrIOp>(
550 *linalgOp.getBlock())) {
557 assert(succeeded(res) &&
"unexpected failure to infer contraction dims");
567 return "expected a LinalgOp";
569 return "expected op with 2 inputs and 1 output";
571 return "expected at least 1 reduction";
573 return "expected indexing maps to be projected permutations";
575 return "expected add/mul op in the body";
579 llvm_unreachable(
"unhandled MatchContractionResult case");
586 return isa<ContractionOpInterface>(op) ||
619 return isa<T>(
lhs) ? cast<T>(
lhs) : (isa<T>(
rhs) ? cast<T>(
rhs) :
nullptr);
631struct ConvAccessExprWalker
634 llvm::SmallDenseSet<int64_t> convolvedDims;
636 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
638 llvm::SmallDenseSet<int64_t> unConvolvedDims;
640 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
644 void clearMultiUseDims(AffineMap map) {
645 for (
int dimPos = 0, e = map.
getNumDims(); dimPos < e; ++dimPos) {
646 if (llvm::count_if(map.
getResults(), [dimPos](AffineExpr e) {
647 return e.isFunctionOfDim(dimPos);
649 convolvedDims.erase(dimPos);
650 unConvolvedDims.erase(dimPos);
653 auto it = convolvedDimMapping.find(dimPos);
654 if (it != convolvedDimMapping.end()) {
655 int64_t pairedDim = it->second;
656 convolvedDims.erase(pairedDim);
657 unConvolvedDims.erase(pairedDim);
658 strideAndDilationMapping.erase(pairedDim);
659 convolvedDimMapping.erase(dimPos);
660 convolvedDimMapping.erase(pairedDim);
666 LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
668 if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
671 unConvolvedDims.insert(position);
675 LogicalResult visitSymbolExpr(AffineSymbolExpr expr) {
return failure(); }
677 LogicalResult visitConstantExpr(AffineConstantExpr expr) {
return failure(); }
679 LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
681 if (binaryExpr.
getKind() != AffineExprKind::Add)
683 auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getLHS());
684 auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.
getRHS());
687 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
688 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
692 FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
693 if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
695 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
698 strideAndDilationMapping[dim] =
700 convolvedDims.insert(dim);
703 if (
auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
704 if (symbolMulExpr.getKind() != AffineExprKind::Mul)
706 auto lhsExpr = symbolMulExpr.getLHS();
707 auto rhsExpr = symbolMulExpr.getRHS();
716 if (!mulExpr || !dimExpr)
719 if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
721 strideAndDilationMapping[dim] = mulExpr;
722 convolvedDims.insert(dim);
732 "expected map to have projected permutations");
733 llvm::SmallDenseSet<int64_t> preservedDims;
735 preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
736 return preservedDims;
742 for (
auto e : exprs) {
743 auto constantExpr = dyn_cast<AffineConstantExpr>(e);
744 assert(constantExpr &&
"Found non-constant stride/dilation");
745 vals.push_back(constantExpr.getValue());
763static FailureOr<ConvolutionDimensions>
765 ConvAccessExprWalker &inputExprWalker,
766 bool allowEmptyConvolvedDims) {
768 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
770 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
772 filterMap, linalgOp.getIteratorTypesArray(), par);
774 outputMap, linalgOp.getIteratorTypesArray(), par);
777 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
778 llvm::set_intersect(batch, outputDims);
779 llvm::set_subtract(batch, filterDims);
782 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
783 llvm::set_intersect(oi, outputDims);
786 llvm::SmallDenseSet<int64_t> oc = filterDims;
787 llvm::set_intersect(oc, outputDims);
788 llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
791 llvm::SmallDenseSet<int64_t> depth = filterDims;
792 llvm::set_intersect(depth, outputDims);
793 llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
795 llvm::SmallDenseSet<int64_t> filterReducedDims =
797 linalgOp.getIteratorTypesArray(), red);
800 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
801 llvm::set_intersect(fl, filterReducedDims);
804 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
805 llvm::set_intersect(ic, filterReducedDims);
807 if (oi.empty() && !allowEmptyConvolvedDims)
821 llvm::sort(dimensions.
batch);
825 llvm::sort(dimensions.
depth);
831 dimensions.
filterLoop.push_back(inputExprWalker.convolvedDimMapping[oiDim]);
835 if (!nativeStrides) {
838 strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
841 dimensions.
strides = llvm::to_vector<2>(nativeStrides.getValues<
int64_t>());
843 auto nativeDilations =
845 if (!nativeDilations) {
848 dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
852 llvm::to_vector<2>(nativeDilations.getValues<
int64_t>());
885FailureOr<ConvolutionDimensions>
887 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
890 auto indexingMaps = linalgOp.getIndexingMapsArray();
893 ConvAccessExprWalker inputExprWalker;
894 for (
AffineExpr expr : indexingMaps[0].getResults())
895 (
void)inputExprWalker.visit(expr);
896 inputExprWalker.clearMultiUseDims(indexingMaps[0]);
919 bool allowEmptyConvolvedDims) {
920 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
923 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
926 auto indexingMaps = linalgOp.getIndexingMapsArray();
929 ConvAccessExprWalker inputExprWalker;
930 if (llvm::any_of(indexingMaps[0].getResults(),
932 return failed(inputExprWalker.visit(expr));
938 if (!indexingMaps[1].isProjectedPermutation() ||
939 !indexingMaps.back().isProjectedPermutation())
942 auto iteratorTypes = linalgOp.getIteratorTypesArray();
944 llvm::SmallDenseSet<int64_t> outputDims =
946 llvm::SmallDenseSet<int64_t> filterDims =
getPreservedDims(indexingMaps[1]);
960 llvm::SmallDenseSet<int64_t> allLoopDims;
961 for (
auto outputExpr : indexingMaps.back().getResults()) {
962 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
963 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
964 !filterDims.count(outputDim)) {
966 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
968 allLoopDims.insert(outputDim);
971 if (inputExprWalker.convolvedDims.count(outputDim) &&
972 !filterDims.count(outputDim)) {
974 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
976 allLoopDims.insert(outputDim);
979 if (!inputExprWalker.convolvedDims.count(outputDim) &&
980 !inputExprWalker.unConvolvedDims.count(outputDim) &&
981 filterDims.count(outputDim)) {
983 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
985 allLoopDims.insert(outputDim);
988 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
989 filterDims.count(outputDim)) {
991 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
993 allLoopDims.insert(outputDim);
998 for (
auto filterExpr : indexingMaps[1].getResults()) {
999 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
1000 if (outputDims.count(filterDim) &&
1001 !inputExprWalker.unConvolvedDims.count(filterDim) &&
1002 !inputExprWalker.convolvedDims.count(filterDim)) {
1006 if (inputExprWalker.convolvedDims.count(filterDim) &&
1007 !outputDims.count(filterDim)) {
1009 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1011 if (allLoopDims.count(filterDim))
1013 allLoopDims.insert(filterDim);
1016 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1017 !outputDims.count(filterDim)) {
1019 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
1021 if (allLoopDims.count(filterDim))
1023 allLoopDims.insert(filterDim);
1026 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
1027 outputDims.count(filterDim)) {
1034 if (allLoopDims.size() != linalgOp.getNumLoops())
1037 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
1042 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
1043 assert(succeeded(res) &&
"unexpected failure to infer convolution dims");
1054 return "expected a LinalgOp";
1056 return "expected op with 2 inputs and 1 output";
1058 return "unexpected input index map for convolutions";
1060 return "expected output/filter indexing maps to be projected permutations";
1062 return "unexpected loop dimension for convolution op";
1064 return "expected all iterators used to access outputs to be parallel";
1066 return "expected all iterators not used to access outputs to be reduction";
1068 return "expected convolved dim to be non-empty";
1072 llvm_unreachable(
"unhandled MatchConvolutionResult case");
1076 bool allowEmptyConvolvedDims) {
1078 linalgOp.getOperation(),
nullptr, allowEmptyConvolvedDims) ==
1094enum class MatchFillResult {
1104 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1106 return MatchFillResult::NotLinalgOp;
1107 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1108 return MatchFillResult::WrongNumOperands;
1110 OpOperand *value = linalgOp.getDpsInputOperand(0);
1111 if (!linalgOp.isScalar(value))
1112 return MatchFillResult::NotScalarInput;
1115 OpOperand *output = linalgOp.getDpsInitOperand(0);
1118 if (scalarType != outputElementType)
1119 return MatchFillResult::TypeMismatch;
1121 return MatchFillResult::Success;
1126 if (res == MatchFillResult::NotLinalgOp)
1127 return op->
emitError(
"expected a LinalgOp");
1128 if (res == MatchFillResult::WrongNumOperands)
1129 return op->
emitError(
"expected op with 1 input and 1 output");
1130 if (res == MatchFillResult::NotScalarInput)
1131 return op->
emitError(
"expected op with scalar input");
1132 if (res == MatchFillResult::TypeMismatch) {
1133 auto linalgOp = cast<linalg::LinalgOp>(op);
1134 Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
1135 Type outputElementType =
1137 return op->
emitOpError(
"expected fill value type (")
1138 << scalarType <<
") to match output element type ("
1139 << outputElementType <<
")";
1152 for (
OpOperand &opOperand : getOperation()->getOpOperands()) {
1153 for (
int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1161 assert(!hasDynamicShape() &&
"expected operands to have static shapes");
1162 for (
OpOperand &opOperand : getOperation()->getOpOperands())
1163 llvm::append_range(res,
getShape(&opOperand));
1168 AffineMap map = getLoopsToShapesMap();
1170 auto viewSizes = createFlatListOfOperandDims(
b, loc);
1171 SmallVector<Range, 4> res(numDims);
1172 for (
unsigned idx = 0; idx < numRes; ++idx) {
1174 if (
auto d = dyn_cast<AffineDimExpr>(
result)) {
1175 if (res[d.getPosition()].offset)
1177 res[d.getPosition()] =
1178 Range{
b.getIndexAttr(0), viewSizes[idx],
b.getIndexAttr(1)};
1189 : positions(std::move(positions)) {}
1204 llvm::SmallBitVector positions;
1207static std::pair<int64_t, int64_t>
1211 for (
OpOperand *input : op.getDpsInputOperands())
1212 inputRankSum += op.getRank(input);
1213 for (
OpOperand &output : op.getDpsInitsMutable())
1214 outputRankSum += op.getRank(&output);
1215 return {inputRankSum, inputRankSum + outputRankSum};
1230 AffineMap loopsToShapesMap = getLoopsToShapesMap();
1238 AffineMap loopToResultsShapeMap = loopsToShapesMap.
getSliceMap(
1239 resultShapesSubMapPos.first,
1240 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1241 AffineMap resultShapesFromInputShapesMap =
1242 loopToResultsShapeMap.
compose(getShapesToLoopsMap());
1246 llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.
getNumDims());
1247 outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1248 HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
1249 Location loc = getOperation()->getLoc();
1250 IRRewriter rewriter(
b);
1251 SmallVector<OpFoldResult> allResultDimValues =
1252 affine::makeComposedFoldedMultiResultAffineApply(
1253 rewriter, loc, resultShapesFromInputShapesMap,
1254 createFlatListOfOperandDims(
b, loc));
1256 ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.
getResults();
1257 for (OpOperand &opOperand : getDpsInitsMutable()) {
1258 SmallVector<OpFoldResult> shapes;
1259 for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1260 auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1261 if (!shapedType.isDynamicDim(dim)) {
1263 shapes.push_back(
b.getIndexAttr(shapedType.getDimSize(dim)));
1266 OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
1268 : allResultDimValues[pos];
1273 reifiedReturnShapes.emplace_back(std::move(shapes));
1282 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1283 if (!dpsIface.isDpsInput(opOperand))
1284 return operandNumber;
1285 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1286 assert(!dpsIface.isDpsInit(opOperand));
1289 return cast<DestinationStyleOpInterface>(*this->getOperation())
1290 .getNumDpsInputs() +
1291 operandNumber - start;
1295 LinalgOp linalgOp = cast<LinalgOp>(op);
1297 if (!linalgOp.hasPureTensorSemantics() &&
1299 return op->
emitOpError(
"expected to have pure tensor or buffer semantics");
1303 if (linalgOp.hasDynamicIndexingMaps())
1304 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1308 if (failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
1313 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
1314 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1316 unsigned numLoops = linalgOp.getNumLoops();
1320 <<
" dim(s) to match the number of loops";
1323 linalgOp.getReductionDims(redDims);
1325 if (!linalgOp.getShapesToLoopsMap())
1326 return op->
emitOpError(
"expected the shape-to-loops map to be non-null");
1329 if (linalgOp->getNumRegions() != 1 || !linalgOp->getRegion(0).hasOneBlock())
1330 return op->
emitOpError(
"expects to have 1 region with 1 block");
1338 Block &block = linalgOp->getRegion(0).front();
1340 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.
getNumArguments())
1341 return op->
emitOpError(
"expected as many non-induction variable region "
1342 "arguments as the number of input/output operands");
1344 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1346 if (isa<MemRefType, RankedTensorType>(elementType))
1349 if (elementType != argType)
1350 return op->
emitOpError(
"expected type of bb argument #")
1352 <<
" to match element or self type of the corresponding operand ("
1353 << elementType <<
")";
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation within linalgOp.
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free operations, go up the use-def chain until the first value that is not defined by such an op.
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template arguments.
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprWalker.
static MatchFillResult isFillInterfaceImpl(Operation *op)
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by the template arguments.
static std::optional< Value > isaExternalFillOp(GenericOp op)
Detects if a linalg.generic operation represents an external scalar input.
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity, bool allowNonIdentityMaps)
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type iter that are indexed by the indexingMap as a permutation.
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static std::optional< Value > isaInlinedFillOp(GenericOp op)
Detects if a linalg.generic operation represents a fill with an inlined constant.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Affine binary operation expression.
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
bool visit(AffineExpr expr)
See documentation for AffineExprVisitorBase.
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
MLIRContext * getContext() const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
A symbolic identifier appearing in an affine expression.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the different kinds of dimensions when present.
@ NotProjectedPermutations
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that checks if given operands can be dropped, i.e. the remaining operands can compute the loop bounds of the op.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the different kinds of dimensions when present.
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
@ NotProjectedPermutations
@ NonOutputDimNotReduction
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a convolution subcomputation within linalgOp.
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a broadcast operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation within linalgOp.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp, bool allowNonIdentityMaps=false)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op, e.g. linalg.exp.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op, e.g. linalg.sub.
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::function_ref< Fn > function_ref
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > m
SmallVector< unsigned, 2 > n
SmallVector< unsigned, 2 > k
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.
SmallVector< unsigned, 2 > depth
SmallVector< unsigned, 2 > outputImage
SmallVector< unsigned, 2 > outputChannel
SmallVector< int64_t, 2 > dilations
SmallVector< int64_t, 2 > strides
SmallVector< unsigned, 2 > inputChannel
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > filterLoop