#include "llvm/Support/Debug.h"

#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"

#define DEBUG_TYPE "linalg-drop-unit-dims"
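// Overview sketch (editor's summary, not a comment from the original file):
// the patterns below shrink unit-extent dimensions out of linalg ops, e.g.
// turning operands of type tensor<1x5xf32> into tensor<5xf32>, and reconnect
// to the original types at the boundary via tensor.collapse_shape /
// tensor.expand_shape, or rank-reducing tensor.extract_slice /
// tensor.insert_slice, depending on the chosen RankReductionStrategy.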
/// MoveInitOperandsToInput: convert init operands whose block arguments are
/// used in the body of an all-parallel linalg.generic into input operands.
LogicalResult matchAndRewrite(GenericOp genericOp,
                              PatternRewriter &rewriter) const override {
  if (!genericOp.hasPureTensorSemantics())
    return failure();
  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
    return failure();

  auto outputOperands = genericOp.getDpsInitsMutable();
  llvm::SetVector<OpOperand *> candidates;
  for (OpOperand &op : outputOperands) {
    if (genericOp.getMatchingBlockArgument(&op).use_empty())
      continue;
    candidates.insert(&op);
  }
  if (candidates.empty())
    return failure();

  // Compute the modified indexing maps.
  int64_t origNumInput = genericOp.getNumDpsInputs();
  SmallVector<Value> newInputOperands = genericOp.getDpsInputs();
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  SmallVector<AffineMap> newIndexingMaps;
  newIndexingMaps.append(indexingMaps.begin(),
                         std::next(indexingMaps.begin(), origNumInput));
  for (OpOperand *op : candidates) {
    newInputOperands.push_back(op->get());
    newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
  }
  newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
                         indexingMaps.end());

  Location loc = genericOp.getLoc();
  SmallVector<Value> newOutputOperands =
      llvm::to_vector(genericOp.getDpsInits());
  // Each moved init is replaced by an equivalently shaped tensor.empty.
  for (OpOperand *op : candidates) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfterValue(op->get());
    auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
    auto empty = tensor::EmptyOp::create(
        rewriter, loc, tensor::getMixedSizes(rewriter, loc, op->get()),
        elemType);
    unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
    newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
  }

  auto newOp = GenericOp::create(
      rewriter, loc, genericOp.getResultTypes(), newInputOperands,
      newOutputOperands, newIndexingMaps, genericOp.getIteratorTypesArray(),
      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));

  // Rebuild the body: moved inits become extra input block arguments.
  OpBuilder::InsertionGuard guard(rewriter);
  Region &region = newOp.getRegion();
  Block *block = rewriter.createBlock(&region);
  IRMapping mapper;
  for (auto bbarg : genericOp.getRegionInputArgs())
    mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
  for (OpOperand *op : candidates) {
    BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
    mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
  }
  for (OpOperand &op : outputOperands) {
    BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
    if (candidates.count(&op))
      block->addArgument(bbarg.getType(), loc);
    else
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
  }
  for (auto &op : genericOp.getBody()->getOperations())
    rewriter.clone(op, mapper);
  rewriter.replaceOp(genericOp, newOp.getResults());
  return success();
}
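// Illustrative before/after for the pattern above (hand-written sketch, not
// verbatim from the source): a used init of an all-parallel generic becomes
// an input, and a fresh tensor.empty takes its place as init:
//
//   %res = linalg.generic {indexing_maps = [#map, #map],
//                          iterator_types = ["parallel"]}
//            outs(%init : tensor<8xf32>) { ^bb0(%out: f32): ... }
//
// becomes, roughly:
//
//   %empty = tensor.empty() : tensor<8xf32>
//   %res = linalg.generic {indexing_maps = [#map, #map, #map],
//                          iterator_types = ["parallel"]}
//            ins(%init : tensor<8xf32>)
//            outs(%empty : tensor<8xf32>) { ^bb0(%in: f32, %out: f32): ... }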
/// Update `linalg.index` ops after the unit dimensions in `unitDims` have
/// been dropped from the iteration space of `genericOp`.
static void
replaceUnitDimIndexOps(GenericOp genericOp,
                       const llvm::SmallDenseSet<unsigned> &unitDims,
                       RewriterBase &rewriter) {
  for (IndexOp indexOp :
       llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.getDim()) != 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.getDim() - droppedDims);
    }
  }
}
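// Example (sketch): with unitDims = {1} in a 3-d iteration space,
// linalg.index 1 becomes a constant 0, linalg.index 2 is renumbered to
// linalg.index 1, and linalg.index 0 is left unchanged.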
static Value
expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
            ArrayRef<ReassociationIndices> reassociation,
            ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
  // There are no results for memref outputs.
  auto origResultType = cast<RankedTensorType>(origDest.getType());
  if (rankReductionStrategy ==
      ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
    unsigned rank = origResultType.getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, loc, origDest);
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    return rewriter.createOrFold<tensor::InsertSliceOp>(
        loc, result, origDest, offsets, sizes, strides);
  }

  assert(rankReductionStrategy ==
             ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
         "unknown rank reduction strategy");
  return tensor::ExpandShapeOp::create(rewriter, loc, origResultType, result,
                                       reassociation);
}
static Value
collapseValue(RewriterBase &rewriter, Location loc, Value operand,
              ArrayRef<int64_t> targetShape,
              ArrayRef<ReassociationIndices> reassociation,
              ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
  if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
    if (rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      FailureOr<Value> rankReducingExtract =
          memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                targetShape);
      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
      return *rankReducingExtract;
    }

    assert(rankReductionStrategy ==
               ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
           "unknown rank reduction strategy");
    MemRefLayoutAttrInterface layout;
    auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
                                      layout, memrefType.getMemorySpace());
    return memref::CollapseShapeOp::create(rewriter, loc, targetType, operand,
                                           reassociation);
  }
  if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
    if (rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      FailureOr<Value> rankReducingExtract =
          tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                     targetShape);
      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
      return *rankReducingExtract;
    }

    assert(rankReductionStrategy ==
               ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
           "unknown rank reduction strategy");
    auto targetType =
        RankedTensorType::get(targetShape, tensorType.getElementType());
    return tensor::CollapseShapeOp::create(rewriter, loc, targetType, operand,
                                           reassociation);
  }
  llvm_unreachable("unsupported operand type");
}
/// Compute the modified metadata (indexing map, reassociation, target shape)
/// for an operand of `op` whose unit-extent dimensions are being dropped.
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
    MLIRContext *context, IndexingMapOpInterface op, OpOperand *opOperand,
    llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
    ArrayRef<AffineExpr> dimReplacements) {
  UnitExtentReplacementInfo info;
  ReassociationIndices reassociationGroup;
  SmallVector<AffineExpr> newIndexExprs;
  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
  SmallVector<int64_t> operandShape = op.getStaticOperandShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();

  auto isUnitDim = [&](unsigned dim) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
      unsigned oldPosition = dimExpr.getPosition();
      return !oldDimsToNewDimsMap.count(oldPosition) &&
             (operandShape[dim] == 1);
    }
    // Handle the case where the shape is 1 and is accessed by a constant 0.
    if (operandShape[dim] == 1) {
      auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
      return constAffineExpr && constAffineExpr.getValue() == 0;
    }
    return false;
  };

  unsigned dim = 0;
  while (dim < operandShape.size() && isUnitDim(dim))
    reassociationGroup.push_back(dim++);
  while (dim < operandShape.size()) {
    assert(!isUnitDim(dim) && "expected non unit-extent");
    reassociationGroup.push_back(dim);
    AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
    newIndexExprs.push_back(newExpr);
    info.targetShape.push_back(operandShape[dim]);
    ++dim;
    // Fold all following unit-extent dimensions into the same group.
    while (dim < operandShape.size() && isUnitDim(dim)) {
      reassociationGroup.push_back(dim++);
    }
    info.reassociation.push_back(reassociationGroup);
    reassociationGroup.clear();
  }
  info.indexMap = AffineMap::get(oldDimsToNewDimsMap.size(),
                                 indexingMap.getNumSymbols(),
                                 newIndexExprs, context);
  return info;
}
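// Example (sketch): for an operand of type tensor<1x?x1x?xf32> indexed by
// (d0, d1, d2, d3) -> (d0, d1, d2, d3) where d0 and d2 are dropped unit
// loops, the result is indexMap (d0, d1) -> (d0, d1), targetShape [?, ?],
// and reassociation [[0, 1, 2], [3]]: the leading unit dim and the unit dim
// following the first kept dim both fold into the first group.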
FailureOr<DropUnitDimsResult>
linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
                     const DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
                     const ControlDropUnitDims &options) {
  auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op.getOperation());
  if (!dpsOp) {
    return rewriter.notifyMatchFailure(
        op, "op should implement DestinationStyleOpInterface");
  }

  SmallVector<AffineMap> indexingMaps = op.getIndexingMapsArray();
  if (indexingMaps.empty())
    return failure();

  // 1. Check if any of the iteration dimensions are unit-trip count. They
  //    will end up being unit-trip count if they are used to index into a
  //    unit-dim tensor/memref.
  AffineMap invertedMap =
      inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext()));
  if (!invertedMap) {
    return rewriter.notifyMatchFailure(op,
                                       "invalid indexing maps for operation");
  }
  SmallVector<int64_t> allShapesSizes;
  for (OpOperand &opOperand : op->getOpOperands())
    llvm::append_range(allShapesSizes, op.getStaticOperandShape(&opOperand));

  // 1a. Get the allowed list of dimensions to drop from the `options`.
  SmallVector<unsigned> allowedUnitDims = options.controlFn(op);
  if (allowedUnitDims.empty()) {
    return rewriter.notifyMatchFailure(
        op, "control function returns no allowed unit dims to prune");
  }
  llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                               allowedUnitDims.end());
  llvm::SmallDenseSet<unsigned> unitDims;
  for (const auto &expr : llvm::enumerate(invertedMap.getResults())) {
    if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
      if (allShapesSizes[dimExpr.getPosition()] == 1 &&
          unitDimsFilter.count(expr.index()))
        unitDims.insert(expr.index());
    }
  }

  // 2. Compute the new loop numbering: unit dims are replaced by the
  //    constant 0, the remaining dims are renumbered densely.
  llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
  SmallVector<AffineExpr> dimReplacements;
  unsigned newDims = 0;
  for (auto index : llvm::seq<int64_t>(op.getStaticLoopRanges().size())) {
    if (unitDims.count(index)) {
      dimReplacements.push_back(
          getAffineConstantExpr(0, rewriter.getContext()));
    } else {
      oldDimToNewDimMap[index] = newDims;
      dimReplacements.push_back(
          getAffineDimExpr(newDims, rewriter.getContext()));
      newDims++;
    }
  }

  // 3. For each operand, compute the new indexing map, target shape, and
  //    reassociation, skipping operands whose type cannot be collapsed.
  SmallVector<AffineMap> newIndexingMaps;
  SmallVector<SmallVector<ReassociationIndices>> reassociations;
  SmallVector<SmallVector<int64_t>> targetShapes;
  SmallVector<bool> collapsed;
  auto hasCollapsibleType = [](OpOperand &operand) {
    Type operandType = operand.get().getType();
    if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
      return memrefOperandType.getLayout().isIdentity();
    }
    if (auto tensorOperandType = dyn_cast<RankedTensorType>(operandType)) {
      return tensorOperandType.getEncoding() == nullptr;
    }
    return false;
  };
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto indexingMap = op.getMatchingIndexingMap(&opOperand);
    SmallVector<int64_t> shape = op.getStaticOperandShape(&opOperand);
    if (!hasCollapsibleType(opOperand)) {
      AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
          dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
      newIndexingMaps.push_back(newIndexingMap);
      targetShapes.push_back(llvm::to_vector(shape));
      collapsed.push_back(false);
      reassociations.push_back({});
      continue;
    }
    auto replacementInfo = dropUnitExtentFromOperandMetadata(
        rewriter.getContext(), op, &opOperand, oldDimToNewDimMap,
        dimReplacements);
    reassociations.push_back(replacementInfo.reassociation);
    newIndexingMaps.push_back(replacementInfo.indexMap);
    targetShapes.push_back(replacementInfo.targetShape);
    collapsed.push_back(replacementInfo.indexMap.getNumResults() !=
                        indexingMap.getNumResults());
  }

  // Abort if the indexing maps of the result operation are not invertible
  // (i.e. not legal) or if no dimension was reduced.
  if (newIndexingMaps == indexingMaps ||
      !inversePermutation(
          concatAffineMaps(newIndexingMaps, rewriter.getContext())))
    return failure();

  Location loc = op.getLoc();
  // 4. For each operand, collapse to the target shape if needed, either
  //    through reshapes or rank-reducing slices as specified in `options`.
  SmallVector<Value> newOperands;
  for (OpOperand &opOperand : op->getOpOperands()) {
    int64_t idx = opOperand.getOperandNumber();
    if (!collapsed[idx]) {
      newOperands.push_back(opOperand.get());
      continue;
    }
    newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(),
                                        targetShapes[idx], reassociations[idx],
                                        options.rankReductionStrategy));
  }

  // 5. Create the replacement op with the dropped loops and operands.
  IndexingMapOpInterface replacementOp = droppedUnitDimsBuilder(
      loc, rewriter, op, newOperands, newIndexingMaps, unitDims);

  // 6. Expand results whose type changed back to the original type.
  SmallVector<Value> resultReplacements;
  for (auto [index, result] : llvm::enumerate(replacementOp->getResults())) {
    unsigned opOperandIndex = index + dpsOp.getNumDpsInputs();
    Value origDest = dpsOp.getDpsInitOperand(index)->get();
    if (!collapsed[opOperandIndex]) {
      resultReplacements.push_back(result);
      continue;
    }
    Value expandedValue = expandValue(rewriter, loc, result, origDest,
                                      reassociations[opOperandIndex],
                                      options.rankReductionStrategy);
    resultReplacements.push_back(expandedValue);
  }

  return DropUnitDimsResult{replacementOp, resultReplacements};
}
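// Illustrative end-to-end effect (sketch): a generic over tensor<1x?xf32>
// with map (d0, d1) -> (d0, d1) and iterators ["parallel", "parallel"] is
// rewritten to a generic over tensor<?xf32> with map (d0) -> (d0) and a
// single parallel iterator; operands are collapsed on the way in and the
// results expanded back to the original types on the way out.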
FailureOr<DropUnitDimsResult>
linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
                     const ControlDropUnitDims &options) {
  DroppedUnitDimsBuilder build =
      [](Location loc, OpBuilder &b, IndexingMapOpInterface op,
         ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
         const llvm::SmallDenseSet<unsigned> &droppedDims)
      -> IndexingMapOpInterface {
    auto genericOp = cast<GenericOp>(op);
    // Compute the iterator types of the modified op by dropping the
    // one-trip count loops.
    SmallVector<utils::IteratorType> newIteratorTypes;
    for (auto [index, attr] :
         llvm::enumerate(genericOp.getIteratorTypesArray())) {
      if (!droppedDims.count(index))
        newIteratorTypes.push_back(attr);
    }

    ArrayRef<Value> newInputs =
        newOperands.take_front(genericOp.getNumDpsInputs());
    ArrayRef<Value> newOutputs =
        newOperands.take_back(genericOp.getNumDpsInits());
    SmallVector<Type> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newOutputs[i].getType());
    GenericOp replacementOp =
        GenericOp::create(b, loc, resultTypes, newInputs, newOutputs,
                          newIndexingMaps, newIteratorTypes);
    b.cloneRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
                        replacementOp.getRegion().begin());
    // Replace `linalg.index` ops that refer to the dropped unit dimensions.
    IRRewriter rewriter(b);
    replaceUnitDimIndexOps(replacementOp, droppedDims, rewriter);
    return replacementOp;
  };

  return dropUnitDims(rewriter, genericOp, build, options);
}
/// DropUnitDims: pattern wrapper around linalg::dropUnitDims above.
LogicalResult matchAndRewrite(GenericOp genericOp,
                              PatternRewriter &rewriter) const override {
  FailureOr<DropUnitDimsResult> result =
      dropUnitDims(rewriter, genericOp, options);
  if (failed(result))
    return failure();
  rewriter.replaceOp(genericOp, result->replacements);
  return success();
}
/// DropPadUnitDims: drop static unit-extent dimensions with zero padding
/// from tensor.pad.
LogicalResult matchAndRewrite(tensor::PadOp padOp,
                              PatternRewriter &rewriter) const override {
  // 1a. Get the allowed list of dimensions to drop from the `options`.
  SmallVector<unsigned> allowedUnitDims = options.controlFn(padOp);
  if (allowedUnitDims.empty()) {
    return rewriter.notifyMatchFailure(
        padOp, "control function returns no allowed unit dims to prune");
  }

  if (padOp.getSourceType().getEncoding()) {
    return rewriter.notifyMatchFailure(
        padOp, "cannot collapse dims of tensor with encoding");
  }

  // Fail for non-constant padding values. The body of the pad could depend
  // on the padding indices and/or properties of the padded tensor.
  Value paddingVal = padOp.getConstantPaddingValue();
  if (!paddingVal) {
    return rewriter.notifyMatchFailure(
        padOp, "unimplemented: non-constant padding value");
  }

  ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
  ArrayRef<int64_t> resultShape = padOp.getResultType().getShape();
  int64_t padRank = sourceShape.size();

  auto isStaticZero = [](OpFoldResult f) {
    return getConstantIntValue(f) == 0;
  };

  llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                               allowedUnitDims.end());
  llvm::SmallDenseSet<unsigned> unitDims;
  SmallVector<int64_t> newShape;
  SmallVector<int64_t> newResultShape;
  SmallVector<OpFoldResult> newLowPad;
  SmallVector<OpFoldResult> newHighPad;
  for (const auto [dim, size, outSize, low, high] : zip_equal(
           llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
           resultShape, padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
    if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
        isStaticZero(high)) {
      unitDims.insert(dim);
    } else {
      newShape.push_back(size);
      newResultShape.push_back(outSize);
      newLowPad.push_back(low);
      newHighPad.push_back(high);
    }
  }

  if (unitDims.empty()) {
    return rewriter.notifyMatchFailure(padOp, "no unit dims to fold");
  }

  ReassociationIndices reassociationGroup;
  SmallVector<ReassociationIndices> reassociationMap;
  int64_t dim = 0;
  while (dim < padRank && unitDims.contains(dim))
    reassociationGroup.push_back(dim++);
  while (dim < padRank) {
    assert(!unitDims.contains(dim) && "expected non unit-extent");
    reassociationGroup.push_back(dim);
    dim++;
    // Fold all following dimensions that are unit-extent.
    while (dim < padRank && unitDims.contains(dim))
      reassociationGroup.push_back(dim++);
    reassociationMap.push_back(reassociationGroup);
    reassociationGroup.clear();
  }

  Value collapsedSource =
      collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
                    reassociationMap, options.rankReductionStrategy);

  auto newResultType = RankedTensorType::get(
      newResultShape, padOp.getResultType().getElementType());
  auto newPadOp = tensor::PadOp::create(
      rewriter, padOp.getLoc(), /*result=*/newResultType, collapsedSource,
      newLowPad, newHighPad, paddingVal, padOp.getNofold());

  Value dest = padOp.getResult();
  if (options.rankReductionStrategy ==
      ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
    SmallVector<OpFoldResult> expandedSizes;
    int64_t numUnitDims = 0;
    for (auto dim : llvm::seq(static_cast<int64_t>(0), padRank)) {
      if (unitDims.contains(dim)) {
        expandedSizes.push_back(rewriter.getIndexAttr(1));
        numUnitDims++;
        continue;
      }
      expandedSizes.push_back(tensor::getMixedSize(
          rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
    }
    dest = tensor::EmptyOp::create(rewriter, padOp.getLoc(), expandedSizes,
                                   padOp.getResultType().getElementType());
  }

  Value expandedValue =
      expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
                  reassociationMap, options.rankReductionStrategy);
  rewriter.replaceOp(padOp, expandedValue);
  return success();
}
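// Illustrative before/after for the pad pattern (sketch):
//
//   %0 = tensor.pad %arg0 low[0, 1] high[0, 2] { ... }
//       : tensor<1x16xf32> to tensor<1x19xf32>
//
// becomes, roughly:
//
//   %c = tensor.collapse_shape %arg0 [[0, 1]]
//       : tensor<1x16xf32> into tensor<16xf32>
//   %p = tensor.pad %c low[1] high[2] { ... }
//       : tensor<16xf32> to tensor<19xf32>
//   %e = tensor.expand_shape %p [[0, 1]]
//       : tensor<19xf32> into tensor<1x19xf32>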
struct RankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> targetShape;
    for (auto size : resultType.getShape())
      targetShape.push_back(rewriter.getIndexAttr(size));
    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();

    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    auto rankReducedType = cast<RankedTensorType>(
        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
            reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
            strides));

    Location loc = sliceOp.getLoc();
    Value newSlice = tensor::ExtractSliceOp::create(
        rewriter, loc, rankReducedType, sliceOp.getSource(), offsets, sizes,
        strides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSlice, *reassociation);
    return success();
  }
};
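// Illustrative before/after (sketch): a slice producing unit dims is
// rank-reduced and re-expanded:
//
//   %0 = tensor.extract_slice %t[0, 0, 0] [1, 4, 1] [1, 1, 1]
//       : tensor<8x4x2xf32> to tensor<1x4x1xf32>
//
// becomes, roughly:
//
//   %s = tensor.extract_slice %t[0, 0, 0] [1, 4, 1] [1, 1, 1]
//       : tensor<8x4x2xf32> to tensor<4xf32>
//   %0 = tensor.expand_shape %s [[0, 1, 2]]
//       : tensor<4xf32> into tensor<1x4x1xf32>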
template <typename InsertOpTy>
struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertSliceOp.getSourceType();
    SmallVector<OpFoldResult> targetShape;
    for (auto size : sourceType.getShape())
      targetShape.push_back(rewriter.getIndexAttr(size));
    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();

    Location loc = insertSliceOp.getLoc();
    tensor::CollapseShapeOp reshapedSource;
    {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp
      // is that the insertion point is just before the ParallelCombiningOp
      // in the parallel case.
      if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      reshapedSource = tensor::CollapseShapeOp::create(
          rewriter, loc, insertSliceOp.getSource(), *reassociation);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, reshapedSource, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
static void
populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
                                              ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  patterns.add<DropUnitDims>(context, options);
  patterns.add<DropPadUnitDims>(context, options);
  patterns.add<RankReducedExtractSliceOp,
               RankReducedInsertSliceOp<tensor::InsertSliceOp>,
               RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
      context);
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}
static void
populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
                                            ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  options.rankReductionStrategy =
      ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
  patterns.add<DropUnitDims>(context, options);
  patterns.add<DropPadUnitDims>(context, options);
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
  if (options.rankReductionStrategy ==
      linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
    populateFoldUnitExtentDimsViaSlicesPatterns(patterns, options);
  } else if (options.rankReductionStrategy ==
             linalg::ControlDropUnitDims::RankReductionStrategy::
                 ReassociativeReshape) {
    populateFoldUnitExtentDimsViaReshapesPatterns(patterns, options);
  }
}
struct LinalgFoldUnitExtentDimsPass
    : public impl::LinalgFoldUnitExtentDimsPassBase<
          LinalgFoldUnitExtentDimsPass> {
  using impl::LinalgFoldUnitExtentDimsPassBase<
      LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    RewritePatternSet patterns(context);
    ControlDropUnitDims options;
    if (useRankReducingSlices) {
      options.rankReductionStrategy =
          linalg::ControlDropUnitDims::RankReductionStrategy::
              ExtractInsertSlice;
    }
    linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
    populateMoveInitOperandsToInputPattern(patterns);
    (void)applyPatternsGreedily(op, std::move(patterns));
  }
};
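// Usage note (assumption inferred from the pass definition above, not text
// from the original file): the pass is exposed to mlir-opt as
// --linalg-fold-unit-extent-dims, with the `useRankReducingSlices` option
// presumably spelled use-rank-reducing-slices, e.g.
//   mlir-opt --linalg-fold-unit-extent-dims=use-rank-reducing-slices in.mlir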
/// Returns reassociation indices for collapsing/expanding a tensor of rank
/// `rank` at position `pos`.
static SmallVector<ReassociationIndices>
getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
  SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
  bool lastDim = pos == rank - 1;
  if (rank > 2) {
    for (int64_t i = 0; i < rank - 1; i++) {
      if (i == pos || (lastDim && i == pos - 1))
        reassociation[i] = ReassociationIndices{i, i + 1};
      else if (i < pos)
        reassociation[i] = ReassociationIndices{i};
      else
        reassociation[i] = ReassociationIndices{i + 1};
    }
  }
  return reassociation;
}

/// Returns `val` with the unit dim at `pos` collapsed away.
static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
                                    int64_t pos) {
  if (pos < 0)
    return val;
  auto valType = cast<ShapedType>(val.getType());
  SmallVector<int64_t> collapsedShape(valType.getShape());
  collapsedShape.erase(collapsedShape.begin() + pos);
  return collapseValue(
      rewriter, val.getLoc(), val, collapsedShape,
      getReassociationForReshapeAtDim(valType.getRank(), pos),
      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
}
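// Example (sketch): getReassociationForReshapeAtDim(4, 1) returns
// [[0], [1, 2], [3]], i.e. dim 1 is folded into the group of dim 2, so
// collapseSingletonDimAt on a tensor<2x1x3x4xf32> at pos 1 yields a
// tensor<2x3x4xf32> via tensor.collapse_shape.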
/// Base class for rank-reduction patterns on named contraction ops with unit
/// dimensions. Each pattern converts one named op to another named op and
/// reduces one iteration-space dimension at a time; multiple dims are
/// reduced by recursive application of these patterns.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
  using OpRewritePattern<FromOpTy>::OpRewritePattern;

  /// Collapse all collapsible operands.
  SmallVector<Value>
  collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
                   ArrayRef<int64_t> operandCollapseDims) const {
    assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
           "expected 3 operands and dims");
    return llvm::map_to_vector(
        llvm::zip(operands, operandCollapseDims), [&](auto pair) {
          return collapseSingletonDimAt(rewriter, std::get<0>(pair),
                                        std::get<1>(pair));
        });
  }

  /// Expand the result tensor back to the original rank.
  Value expandResult(PatternRewriter &rewriter, Value result,
                     RankedTensorType expandedType, int64_t dim) const {
    return tensor::ExpandShapeOp::create(
        rewriter, result.getLoc(), expandedType, result,
        getReassociationForReshapeAtDim(expandedType.getRank(), dim));
  }

  LogicalResult matchAndRewrite(FromOpTy contractionOp,
                                PatternRewriter &rewriter) const override {
    if (contractionOp.hasUserDefinedMaps()) {
      return rewriter.notifyMatchFailure(
          contractionOp, "ops with user-defined maps are not supported");
    }

    auto loc = contractionOp.getLoc();
    auto inputs = contractionOp.getDpsInputs();
    auto inits = contractionOp.getDpsInits();
    if (inputs.size() != 2 || inits.size() != 1)
      return rewriter.notifyMatchFailure(contractionOp,
                                         "expected 2 inputs and 1 init");
    auto lhs = inputs[0];
    auto rhs = inputs[1];
    auto init = inits[0];
    SmallVector<Value> operands{lhs, rhs, init};

    SmallVector<int64_t> operandUnitDims;
    if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
      return rewriter.notifyMatchFailure(contractionOp,
                                         "no reducible dims found");

    SmallVector<Value> collapsedOperands =
        collapseOperands(rewriter, operands, operandUnitDims);
    Value collapsedLhs = collapsedOperands[0];
    Value collapsedRhs = collapsedOperands[1];
    Value collapsedInit = collapsedOperands[2];
    SmallVector<Type, 1> collapsedResultTy;
    if (isa<RankedTensorType>(collapsedInit.getType()))
      collapsedResultTy.push_back(collapsedInit.getType());
    auto collapsedOp = ToOpTy::create(rewriter, loc, collapsedResultTy,
                                      ValueRange{collapsedLhs, collapsedRhs},
                                      ValueRange{collapsedInit});
    for (auto attr : contractionOp->getAttrs()) {
      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
          attr.getName() == "indexing_maps")
        continue;
      collapsedOp->setAttr(attr.getName(), attr.getValue());
    }

    auto results = contractionOp.getResults();
    assert(results.size() < 2 && "expected at most one result");
    if (results.empty()) {
      rewriter.replaceOp(contractionOp, collapsedOp);
    } else {
      rewriter.replaceOp(
          contractionOp,
          expandResult(rewriter, collapsedOp.getResultTensors()[0],
                       cast<RankedTensorType>(results[0].getType()),
                       operandUnitDims[2]));
    }
    return success();
  }

  /// Populate `operandUnitDims` with 3 indices, one per operand, giving the
  /// unit dim to collapse; a negative index means "do not collapse".
  virtual LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
};
/// Patterns for unbatching batched contraction ops with a unit batch size.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Look for a unit batch dim to collapse.
  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
      return failure();
    }
    ContractionDimensions contractionDims = *maybeContractionDims;

    if (contractionDims.batch.size() != 1)
      return failure();
    auto batchDim = contractionDims.batch[0];
    SmallVector<std::pair<Value, unsigned>> bOperands;
    op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
    if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
          return cast<ShapedType>(std::get<0>(pair).getType())
                     .getShape()[std::get<1>(pair)] != 1;
        })) {
      LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
      return failure();
    }

    operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
                                           std::get<1>(bOperands[1]),
                                           std::get<1>(bOperands[2])};
    return success();
  }
};
/// Patterns for reducing non-batch dimensions.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Helper for determining whether the lhs/init (m) or rhs/init (n)
  /// dimension is reduced.
  static bool constexpr reduceLeft =
      (std::is_same_v<FromOpTy, BatchMatmulOp> &&
       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
      (std::is_same_v<FromOpTy, MatmulOp> &&
       std::is_same_v<ToOpTy, VecmatOp>) ||
      (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);

  /// Look for a unit m or n dim to collapse.
  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
      return failure();
    }
    ContractionDimensions contractionDims = *maybeContractionDims;

    if constexpr (reduceLeft) {
      auto m = contractionDims.m[0];
      SmallVector<std::pair<Value, unsigned>> mOperands;
      op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
      if (mOperands.size() != 2)
        return failure();
      if (llvm::all_of(mOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
                                               std::get<1>(mOperands[1])};
        return success();
      }
    } else {
      auto n = contractionDims.n[0];
      SmallVector<std::pair<Value, unsigned>> nOperands;
      op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
      if (nOperands.size() != 2)
        return failure();
      if (llvm::all_of(nOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
                                               std::get<1>(nOperands[1])};
        return success();
      }
    }
    LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
    return failure();
  }
};
void mlir::linalg::populateContractionOpRankReducingPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  // Unbatching patterns for unit batch size.
  patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
  patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
  patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);

  // Non-batch rank-1 reducing patterns.
  patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
  patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);

  // Batch rank-1 reducing patterns.
  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);

  // Non-batch rank-0 reducing patterns.
  patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
  patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
}
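// Illustrative effect (sketch): with these patterns,
//
//   %0 = linalg.batch_matmul
//          ins(%a, %b : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
//          outs(%c : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
//
// is rewritten to collapse the unit batch dimension, apply linalg.matmul on
// the collapsed operands, and expand the result back to tensor<1x?x?xf32>.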