32 #include "llvm/ADT/SetVector.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
37 #define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
38 #include "mlir/Dialect/Linalg/Passes.h.inc"
41 #define DEBUG_TYPE "linalg-drop-unit-dims"
85 LogicalResult matchAndRewrite(GenericOp genericOp,
87 if (!genericOp.hasPureTensorSemantics())
89 if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
92 auto outputOperands = genericOp.getDpsInitsMutable();
95 if (genericOp.getMatchingBlockArgument(&op).use_empty())
97 candidates.insert(&op);
100 if (candidates.empty())
104 int64_t origNumInput = genericOp.getNumDpsInputs();
108 newIndexingMaps.append(indexingMaps.begin(),
109 std::next(indexingMaps.begin(), origNumInput));
111 newInputOperands.push_back(op->get());
112 newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
114 newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
119 llvm::to_vector(genericOp.getDpsInits());
123 auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
124 auto empty = rewriter.
create<tensor::EmptyOp>(
127 unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
128 newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
131 auto newOp = rewriter.
create<GenericOp>(
132 loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
133 newIndexingMaps, genericOp.getIteratorTypesArray(),
137 Region ®ion = newOp.getRegion();
140 for (
auto bbarg : genericOp.getRegionInputArgs())
144 BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
149 BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
150 if (candidates.count(&op))
156 for (
auto &op : genericOp.getBody()->getOperations()) {
157 rewriter.
clone(op, mapper);
159 rewriter.
replaceOp(genericOp, newOp.getResults());
230 const llvm::SmallDenseSet<unsigned> &unitDims,
232 for (IndexOp indexOp :
233 llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
236 if (unitDims.count(indexOp.getDim()) != 0) {
240 unsigned droppedDims = llvm::count_if(
241 unitDims, [&](
unsigned dim) {
return dim < indexOp.getDim(); });
242 if (droppedDims != 0)
244 indexOp.getDim() - droppedDims);
257 auto origResultType = cast<RankedTensorType>(origDest.
getType());
258 if (rankReductionStrategy ==
260 unsigned rank = origResultType.getRank();
266 loc, result, origDest, offsets, sizes, strides);
269 assert(rankReductionStrategy ==
271 "unknown rank reduction strategy");
273 .
create<tensor::ExpandShapeOp>(loc, origResultType, result, reassociation)
284 if (
auto memrefType = dyn_cast<MemRefType>(operand.
getType())) {
285 if (rankReductionStrategy ==
287 FailureOr<Value> rankReducingExtract =
288 memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
290 assert(succeeded(rankReducingExtract) &&
"not a unit-extent collapse");
291 return *rankReducingExtract;
295 rankReductionStrategy ==
297 "unknown rank reduction strategy");
298 MemRefLayoutAttrInterface layout;
299 auto targetType =
MemRefType::get(targetShape, memrefType.getElementType(),
300 layout, memrefType.getMemorySpace());
301 return rewriter.
create<memref::CollapseShapeOp>(loc, targetType, operand,
304 if (
auto tensorType = dyn_cast<RankedTensorType>(operand.
getType())) {
305 if (rankReductionStrategy ==
307 FailureOr<Value> rankReducingExtract =
308 tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
310 assert(succeeded(rankReducingExtract) &&
"not a unit-extent collapse");
311 return *rankReducingExtract;
315 rankReductionStrategy ==
317 "unknown rank reduction strategy");
320 return rewriter.
create<tensor::CollapseShapeOp>(loc, targetType, operand,
323 llvm_unreachable(
"unsupported operand type");
338 llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
343 AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
347 auto isUnitDim = [&](
unsigned dim) {
348 if (
auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
349 unsigned oldPosition = dimExpr.getPosition();
350 return !oldDimsToNewDimsMap.count(oldPosition) &&
351 (operandShape[dim] == 1);
355 if (operandShape[dim] == 1) {
356 auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
357 return constAffineExpr && constAffineExpr.getValue() == 0;
363 while (dim < operandShape.size() && isUnitDim(dim))
364 reassociationGroup.push_back(dim++);
365 while (dim < operandShape.size()) {
366 assert(!isUnitDim(dim) &&
"expected non unit-extent");
367 reassociationGroup.push_back(dim);
368 AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
369 newIndexExprs.push_back(newExpr);
373 while (dim < operandShape.size() && isUnitDim(dim)) {
374 reassociationGroup.push_back(dim++);
377 reassociationGroup.clear();
381 newIndexExprs, context);
385 FailureOr<DropUnitDimsResult>
389 if (indexingMaps.empty())
399 "invalid indexing maps for operation");
405 if (allowedUnitDims.empty()) {
407 genericOp,
"control function returns no allowed unit dims to prune");
409 llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
410 allowedUnitDims.end());
411 llvm::SmallDenseSet<unsigned> unitDims;
413 if (
AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
414 if (dims[dimExpr.getPosition()] == 1 &&
415 unitDimsFilter.count(expr.index()))
416 unitDims.insert(expr.index());
423 llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
425 unsigned newDims = 0;
426 for (
auto [index, attr] :
428 if (unitDims.count(index)) {
429 dimReplacements.push_back(
432 newIteratorTypes.push_back(attr);
433 oldDimToNewDimMap[index] = newDims;
434 dimReplacements.push_back(
455 auto hasCollapsibleType = [](
OpOperand &operand) {
456 Type operandType = operand.get().getType();
457 if (
auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
458 return memrefOperandType.getLayout().isIdentity();
460 if (
auto tensorOperandType = dyn_cast<RankedTensorType>(operandType)) {
461 return tensorOperandType.getEncoding() ==
nullptr;
465 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
466 auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
468 if (!hasCollapsibleType(opOperand)) {
471 newIndexingMaps.push_back(newIndexingMap);
472 targetShapes.push_back(llvm::to_vector(shape));
473 collapsed.push_back(
false);
474 reassociations.push_back({});
478 rewriter.
getContext(), genericOp, &opOperand, oldDimToNewDimMap,
480 reassociations.push_back(replacementInfo.reassociation);
481 newIndexingMaps.push_back(replacementInfo.indexMap);
482 targetShapes.push_back(replacementInfo.targetShape);
483 collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
484 indexingMap.getNumResults()));
489 if (newIndexingMaps == indexingMaps ||
500 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
501 int64_t idx = opOperand.getOperandNumber();
502 if (!collapsed[idx]) {
503 newOperands.push_back(opOperand.get());
506 newOperands.push_back(
collapseValue(rewriter, loc, opOperand.get(),
507 targetShapes[idx], reassociations[idx],
508 options.rankReductionStrategy));
518 resultTypes.reserve(genericOp.getNumResults());
519 for (
unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
520 resultTypes.push_back(newOutputs[i].
getType());
521 GenericOp replacementOp =
522 rewriter.
create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
523 newIndexingMaps, newIteratorTypes);
525 replacementOp.getRegion().begin());
533 for (
auto [index, result] :
llvm::enumerate(replacementOp.getResults())) {
534 unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
535 Value origDest = genericOp.getDpsInitOperand(index)->get();
536 if (!collapsed[opOperandIndex]) {
537 resultReplacements.push_back(result);
541 reassociations[opOperandIndex],
542 options.rankReductionStrategy);
543 resultReplacements.push_back(expandedValue);
555 LogicalResult matchAndRewrite(GenericOp genericOp,
557 FailureOr<DropUnitDimsResult> result =
559 if (failed(result)) {
562 rewriter.
replaceOp(genericOp, result->replacements);
581 LogicalResult matchAndRewrite(tensor::PadOp padOp,
585 if (allowedUnitDims.empty()) {
587 padOp,
"control function returns no allowed unit dims to prune");
590 if (padOp.getSourceType().getEncoding()) {
592 padOp,
"cannot collapse dims of tensor with encoding");
599 Value paddingVal = padOp.getConstantPaddingValue();
602 padOp,
"unimplemented: non-constant padding value");
606 int64_t padRank = sourceShape.size();
610 return maybeInt && *maybeInt == 0;
613 llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
614 allowedUnitDims.end());
615 llvm::SmallDenseSet<unsigned> unitDims;
619 for (
const auto [dim, size, low, high] :
620 zip_equal(llvm::seq(
static_cast<int64_t
>(0), padRank), sourceShape,
621 padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
622 if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
623 isStaticZero(high)) {
624 unitDims.insert(dim);
626 newShape.push_back(size);
627 newLowPad.push_back(low);
628 newHighPad.push_back(high);
632 if (unitDims.empty()) {
639 while (dim < padRank && unitDims.contains(dim))
640 reassociationGroup.push_back(dim++);
641 while (dim < padRank) {
642 assert(!unitDims.contains(dim) &&
"expected non unit-extent");
643 reassociationGroup.push_back(dim);
646 while (dim < padRank && unitDims.contains(dim))
647 reassociationGroup.push_back(dim++);
648 reassociationMap.push_back(reassociationGroup);
649 reassociationGroup.clear();
652 Value collapsedSource =
653 collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
654 reassociationMap,
options.rankReductionStrategy);
656 auto newPadOp = rewriter.
create<tensor::PadOp>(
657 padOp.getLoc(),
Type(), collapsedSource, newLowPad,
658 newHighPad, paddingVal, padOp.getNofold());
660 Value dest = padOp.getResult();
661 if (
options.rankReductionStrategy ==
664 int64_t numUnitDims = 0;
665 for (
auto dim : llvm::seq(
static_cast<int64_t
>(0), padRank)) {
666 if (unitDims.contains(dim)) {
672 rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
674 dest = rewriter.
create<tensor::EmptyOp>(
675 padOp.getLoc(), expandedSizes,
676 padOp.getResultType().getElementType());
679 Value expandedValue =
680 expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
681 reassociationMap,
options.rankReductionStrategy);
682 rewriter.
replaceOp(padOp, expandedValue);
693 struct RankReducedExtractSliceOp
697 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
699 RankedTensorType resultType = sliceOp.getType();
701 for (
auto size : resultType.getShape())
704 if (!reassociation ||
705 reassociation->size() ==
static_cast<size_t>(resultType.getRank()))
711 auto rankReducedType = cast<RankedTensorType>(
712 tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
713 reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
717 Value newSlice = rewriter.
create<tensor::ExtractSliceOp>(
718 loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
720 sliceOp, resultType, newSlice, *reassociation);
727 template <
typename InsertOpTy>
731 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
733 RankedTensorType sourceType = insertSliceOp.getSourceType();
735 for (
auto size : sourceType.getShape())
738 if (!reassociation ||
739 reassociation->size() ==
static_cast<size_t>(sourceType.getRank()))
742 Location loc = insertSliceOp.getLoc();
743 tensor::CollapseShapeOp reshapedSource;
749 if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
751 reshapedSource = rewriter.
create<tensor::CollapseShapeOp>(
752 loc, insertSliceOp.getSource(), *reassociation);
755 insertSliceOp, reshapedSource, insertSliceOp.getDest(),
756 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
757 insertSliceOp.getMixedStrides());
768 auto *context =
patterns.getContext();
772 patterns.add<RankReducedExtractSliceOp,
773 RankReducedInsertSliceOp<tensor::InsertSliceOp>,
774 RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
776 linalg::FillOp::getCanonicalizationPatterns(
patterns, context);
777 tensor::CollapseShapeOp::getCanonicalizationPatterns(
patterns, context);
778 tensor::EmptyOp::getCanonicalizationPatterns(
patterns, context);
779 tensor::ExpandShapeOp::getCanonicalizationPatterns(
patterns, context);
788 auto *context =
patterns.getContext();
792 linalg::FillOp::getCanonicalizationPatterns(
patterns, context);
793 tensor::EmptyOp::getCanonicalizationPatterns(
patterns, context);
801 if (
options.rankReductionStrategy ==
804 }
else if (
options.rankReductionStrategy ==
806 ReassociativeReshape) {
818 struct LinalgFoldUnitExtentDimsPass
819 :
public impl::LinalgFoldUnitExtentDimsPassBase<
820 LinalgFoldUnitExtentDimsPass> {
821 using impl::LinalgFoldUnitExtentDimsPassBase<
822 LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
823 void runOnOperation()
override {
828 if (useRankReducingSlices) {
845 getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
847 bool lastDim = pos == rank - 1;
849 for (int64_t i = 0; i < rank - 1; i++) {
850 if (i == pos || (lastDim && i == pos - 1))
858 return reassociation;
867 auto valType = cast<ShapedType>(val.
getType());
869 collapsedShape.erase(collapsedShape.begin() + pos);
871 rewriter, val.
getLoc(), val, collapsedShape,
872 getReassociationForReshapeAtDim(valType.getRank(), pos),
882 template <
typename FromOpTy,
typename ToOpTy>
890 assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
891 "expected 3 operands and dims");
892 return llvm::map_to_vector(
893 llvm::zip(operands, operandCollapseDims), [&](
auto pair) {
894 return collapseSingletonDimAt(rewriter, std::get<0>(pair),
901 RankedTensorType expandedType, int64_t dim)
const {
902 return rewriter.
create<tensor::ExpandShapeOp>(
903 result.
getLoc(), expandedType, result,
904 getReassociationForReshapeAtDim(expandedType.getRank(), dim));
907 LogicalResult matchAndRewrite(FromOpTy contractionOp,
910 auto loc = contractionOp.
getLoc();
911 auto inputs = contractionOp.getDpsInputs();
912 auto inits = contractionOp.getDpsInits();
913 if (inputs.size() != 2 || inits.size() != 1)
915 "expected 2 inputs and 1 init");
916 auto lhs = inputs[0];
917 auto rhs = inputs[1];
918 auto init = inits[0];
922 if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
924 "no reducable dims found");
927 collapseOperands(rewriter, operands, operandUnitDims);
928 Value collapsedLhs = collapsedOperands[0];
929 Value collapsedRhs = collapsedOperands[1];
930 Value collapsedInit = collapsedOperands[2];
932 if (isa<RankedTensorType>(collapsedInit.
getType()))
933 collapsedResultTy.push_back(collapsedInit.
getType());
934 auto collapsedOp = rewriter.
create<ToOpTy>(
935 loc, collapsedResultTy,
ValueRange{collapsedLhs, collapsedRhs},
937 for (
auto attr : contractionOp->getAttrs()) {
938 if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
940 collapsedOp->
setAttr(attr.getName(), attr.getValue());
943 auto results = contractionOp.getResults();
944 assert(results.size() < 2 &&
"expected at most one result");
945 if (results.empty()) {
946 rewriter.
replaceOp(contractionOp, collapsedOp);
950 expandResult(rewriter, collapsedOp.getResultTensors()[0],
951 cast<RankedTensorType>(results[0].getType()),
952 operandUnitDims[2]));
961 virtual LogicalResult
962 getOperandUnitDims(LinalgOp op,
967 template <
typename FromOpTy,
typename ToOpTy>
968 struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
969 using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
973 getOperandUnitDims(LinalgOp op,
975 FailureOr<ContractionDimensions> maybeContractionDims =
977 if (failed(maybeContractionDims)) {
978 LLVM_DEBUG(llvm::dbgs() <<
"could not infer contraction dims");
983 if (contractionDims.
batch.size() != 1)
985 auto batchDim = contractionDims.
batch[0];
987 op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
988 if (bOperands.size() != 3 || llvm::any_of(bOperands, [](
auto pair) {
989 return cast<ShapedType>(std::get<0>(pair).getType())
990 .getShape()[std::get<1>(pair)] != 1;
992 LLVM_DEBUG(llvm::dbgs() <<
"specified unit dims not found");
997 std::get<1>(bOperands[1]),
998 std::get<1>(bOperands[2])};
1004 template <
typename FromOpTy,
typename ToOpTy>
1005 struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
1006 using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
1009 static bool constexpr reduceLeft =
1010 (std::is_same_v<FromOpTy, BatchMatmulOp> &&
1011 std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1012 (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
1013 std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1014 (std::is_same_v<FromOpTy, MatmulOp> &&
1015 std::is_same_v<ToOpTy, VecmatOp>) ||
1016 (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
1017 std::is_same_v<ToOpTy, VecmatOp>) ||
1018 (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
1022 getOperandUnitDims(LinalgOp op,
1024 FailureOr<ContractionDimensions> maybeContractionDims =
1026 if (failed(maybeContractionDims)) {
1027 LLVM_DEBUG(llvm::dbgs() <<
"could not infer contraction dims");
1032 if constexpr (reduceLeft) {
1033 auto m = contractionDims.
m[0];
1035 op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
1036 if (mOperands.size() != 2)
1038 if (llvm::all_of(mOperands, [](
auto pair) {
1039 return cast<ShapedType>(std::get<0>(pair).
getType())
1040 .getShape()[std::get<1>(pair)] == 1;
1043 std::get<1>(mOperands[1])};
1047 auto n = contractionDims.
n[0];
1049 op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
1050 if (nOperands.size() != 2)
1052 if (llvm::all_of(nOperands, [](
auto pair) {
1053 return cast<ShapedType>(std::get<0>(pair).
getType())
1054 .getShape()[std::get<1>(pair)] == 1;
1057 std::get<1>(nOperands[1])};
1061 LLVM_DEBUG(llvm::dbgs() <<
"specified unit dims not found");
1072 patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
1074 .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
1077 .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
1079 patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
1080 patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
1083 patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
1084 patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
1085 patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
1086 patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
1088 patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
1089 patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
1090 patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
1092 patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
1096 patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
1097 patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
static Value expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest, ArrayRef< ReassociationIndices > reassociation, ControlDropUnitDims::RankReductionStrategy rankReductionStrategy)
Expand the given value so that the type matches the type of origDest.
static void replaceUnitDimIndexOps(GenericOp genericOp, const llvm::SmallDenseSet< unsigned > &unitDims, RewriterBase &rewriter)
Implements a pass that canonicalizes the uses of unit-extent dimensions for broadcasting.
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(MLIRContext *context, GenericOp genericOp, OpOperand *opOperand, llvm::SmallDenseMap< unsigned, unsigned > &oldDimsToNewDimsMap, ArrayRef< AffineExpr > dimReplacements)
static void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns that are used to canonicalize the use of unit-extent dims for broadcasting.
static void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
static Value collapseValue(RewriterBase &rewriter, Location loc, Value operand, ArrayRef< int64_t > targetShape, ArrayRef< ReassociationIndices > reassociation, ControlDropUnitDims::RankReductionStrategy rankReductionStrategy)
Collapse the given value so that the type matches the type of origOutput.
static llvm::ManagedStatic< PassManagerOptions > options
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map. AffineMaps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
ArrayRef< AffineExpr > getResults() const
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns)
A pattern that converts init operands to input operands.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns)
Adds patterns that reduce the rank of named contraction ops that have unit dimensions in the operand(...
FailureOr< DropUnitDimsResult > dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, const ControlDropUnitDims &options)
SmallVector< NamedAttribute > getPrunedAttributeList(OpTy op)
Returns an attribute list that excludes pre-defined attributes.
std::optional< SmallVector< ReassociationIndices > > getReassociationMapForFoldingUnitDims(ArrayRef< OpFoldResult > mixedSizes)
Get the reassociation maps to fold the result of an extract_slice (or source of an insert_slice) operat...
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with its consumers.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
const FrozenRewritePatternSet & patterns
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Compute the modified metadata for an operands of operation whose unit dims are being dropped.
SmallVector< ReassociationIndices > reassociation
SmallVector< int64_t > targetShape
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > m
SmallVector< unsigned, 2 > n
Transformation to drop unit-extent dimensions from linalg.generic operations.