32 #include "llvm/ADT/SetVector.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
37 #define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
38 #include "mlir/Dialect/Linalg/Passes.h.inc"
41 #define DEBUG_TYPE "linalg-drop-unit-dims"
85 LogicalResult matchAndRewrite(GenericOp genericOp,
87 if (!genericOp.hasPureTensorSemantics())
89 if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
92 auto outputOperands = genericOp.getDpsInitsMutable();
95 if (genericOp.getMatchingBlockArgument(&op).use_empty())
97 candidates.insert(&op);
100 if (candidates.empty())
104 int64_t origNumInput = genericOp.getNumDpsInputs();
108 newIndexingMaps.append(indexingMaps.begin(),
109 std::next(indexingMaps.begin(), origNumInput));
111 newInputOperands.push_back(op->get());
112 newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
114 newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
119 llvm::to_vector(genericOp.getDpsInits());
123 auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
124 auto empty = rewriter.
create<tensor::EmptyOp>(
127 unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
128 newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
131 auto newOp = rewriter.
create<GenericOp>(
132 loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
133 newIndexingMaps, genericOp.getIteratorTypesArray(),
137 Region &region = newOp.getRegion();
140 for (
auto bbarg : genericOp.getRegionInputArgs())
144 BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
149 BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
150 if (candidates.count(&op))
156 for (
auto &op : genericOp.getBody()->getOperations()) {
157 rewriter.
clone(op, mapper);
159 rewriter.
replaceOp(genericOp, newOp.getResults());
230 const llvm::SmallDenseSet<unsigned> &unitDims,
232 for (IndexOp indexOp :
233 llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
236 if (unitDims.count(indexOp.getDim()) != 0) {
240 unsigned droppedDims = llvm::count_if(
241 unitDims, [&](
unsigned dim) {
return dim < indexOp.getDim(); });
242 if (droppedDims != 0)
244 indexOp.getDim() - droppedDims);
257 auto origResultType = cast<RankedTensorType>(origDest.
getType());
258 if (rankReductionStrategy ==
260 unsigned rank = origResultType.getRank();
266 loc, result, origDest, offsets, sizes, strides);
269 assert(rankReductionStrategy ==
271 "unknown rank reduction strategy");
273 .
create<tensor::ExpandShapeOp>(loc, origResultType, result, reassociation)
284 if (
auto memrefType = dyn_cast<MemRefType>(operand.
getType())) {
285 if (rankReductionStrategy ==
287 FailureOr<Value> rankReducingExtract =
288 memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
290 assert(succeeded(rankReducingExtract) &&
"not a unit-extent collapse");
291 return *rankReducingExtract;
295 rankReductionStrategy ==
297 "unknown rank reduction strategy");
298 MemRefLayoutAttrInterface layout;
299 auto targetType =
MemRefType::get(targetShape, memrefType.getElementType(),
300 layout, memrefType.getMemorySpace());
301 return rewriter.
create<memref::CollapseShapeOp>(loc, targetType, operand,
304 if (
auto tensorType = dyn_cast<RankedTensorType>(operand.
getType())) {
305 if (rankReductionStrategy ==
307 FailureOr<Value> rankReducingExtract =
308 tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
310 assert(succeeded(rankReducingExtract) &&
"not a unit-extent collapse");
311 return *rankReducingExtract;
315 rankReductionStrategy ==
317 "unknown rank reduction strategy");
320 return rewriter.
create<tensor::CollapseShapeOp>(loc, targetType, operand,
323 llvm_unreachable(
"unsupported operand type");
338 llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
343 AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
347 auto isUnitDim = [&](
unsigned dim) {
348 if (
auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
349 unsigned oldPosition = dimExpr.getPosition();
350 return !oldDimsToNewDimsMap.count(oldPosition) &&
351 (operandShape[dim] == 1);
355 if (operandShape[dim] == 1) {
356 auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
357 return constAffineExpr && constAffineExpr.getValue() == 0;
363 while (dim < operandShape.size() && isUnitDim(dim))
364 reassociationGroup.push_back(dim++);
365 while (dim < operandShape.size()) {
366 assert(!isUnitDim(dim) &&
"expected non unit-extent");
367 reassociationGroup.push_back(dim);
368 AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
369 newIndexExprs.push_back(newExpr);
373 while (dim < operandShape.size() && isUnitDim(dim)) {
374 reassociationGroup.push_back(dim++);
377 reassociationGroup.clear();
381 newIndexExprs, context);
385 FailureOr<DropUnitDimsResult>
389 if (indexingMaps.empty())
399 "invalid indexing maps for operation");
405 if (allowedUnitDims.empty()) {
407 genericOp,
"control function returns no allowed unit dims to prune");
409 llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
410 allowedUnitDims.end());
411 llvm::SmallDenseSet<unsigned> unitDims;
413 if (
AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
414 if (dims[dimExpr.getPosition()] == 1 &&
415 unitDimsFilter.count(expr.index()))
416 unitDims.insert(expr.index());
423 llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
425 unsigned newDims = 0;
426 for (
auto [index, attr] :
428 if (unitDims.count(index)) {
429 dimReplacements.push_back(
432 newIteratorTypes.push_back(attr);
433 oldDimToNewDimMap[index] = newDims;
434 dimReplacements.push_back(
455 auto hasCollapsibleType = [](
OpOperand &operand) {
456 Type operandType = operand.get().getType();
457 if (
auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
458 return memrefOperandType.getLayout().isIdentity();
460 if (
auto tensorOperandType = dyn_cast<RankedTensorType>(operandType)) {
461 return tensorOperandType.getEncoding() ==
nullptr;
465 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
466 auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
468 if (!hasCollapsibleType(opOperand)) {
471 newIndexingMaps.push_back(newIndexingMap);
472 targetShapes.push_back(llvm::to_vector(shape));
473 collapsed.push_back(
false);
474 reassociations.push_back({});
478 rewriter.
getContext(), genericOp, &opOperand, oldDimToNewDimMap,
480 reassociations.push_back(replacementInfo.reassociation);
481 newIndexingMaps.push_back(replacementInfo.indexMap);
482 targetShapes.push_back(replacementInfo.targetShape);
483 collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
484 indexingMap.getNumResults()));
489 if (newIndexingMaps == indexingMaps ||
500 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
501 int64_t idx = opOperand.getOperandNumber();
502 if (!collapsed[idx]) {
503 newOperands.push_back(opOperand.get());
506 newOperands.push_back(
collapseValue(rewriter, loc, opOperand.get(),
507 targetShapes[idx], reassociations[idx],
508 options.rankReductionStrategy));
518 resultTypes.reserve(genericOp.getNumResults());
519 for (
unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
520 resultTypes.push_back(newOutputs[i].
getType());
521 GenericOp replacementOp =
522 rewriter.
create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
523 newIndexingMaps, newIteratorTypes);
525 replacementOp.getRegion().begin());
533 for (
auto [index, result] :
llvm::enumerate(replacementOp.getResults())) {
534 unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
535 Value origDest = genericOp.getDpsInitOperand(index)->get();
536 if (!collapsed[opOperandIndex]) {
537 resultReplacements.push_back(result);
541 reassociations[opOperandIndex],
542 options.rankReductionStrategy);
543 resultReplacements.push_back(expandedValue);
555 LogicalResult matchAndRewrite(GenericOp genericOp,
557 FailureOr<DropUnitDimsResult> result =
559 if (failed(result)) {
562 rewriter.
replaceOp(genericOp, result->replacements);
581 LogicalResult matchAndRewrite(tensor::PadOp padOp,
585 if (allowedUnitDims.empty()) {
587 padOp,
"control function returns no allowed unit dims to prune");
590 if (padOp.getSourceType().getEncoding()) {
592 padOp,
"cannot collapse dims of tensor with encoding");
599 Value paddingVal = padOp.getConstantPaddingValue();
602 padOp,
"unimplemented: non-constant padding value");
606 int64_t padRank = sourceShape.size();
610 return maybeInt && *maybeInt == 0;
613 llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
614 allowedUnitDims.end());
615 llvm::SmallDenseSet<unsigned> unitDims;
619 for (
const auto [dim, size, low, high] :
620 zip_equal(llvm::seq(
static_cast<int64_t
>(0), padRank), sourceShape,
621 padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
622 if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
623 isStaticZero(high)) {
624 unitDims.insert(dim);
626 newShape.push_back(size);
627 newLowPad.push_back(low);
628 newHighPad.push_back(high);
632 if (unitDims.empty()) {
639 while (dim < padRank && unitDims.contains(dim))
640 reassociationGroup.push_back(dim++);
641 while (dim < padRank) {
642 assert(!unitDims.contains(dim) &&
"expected non unit-extent");
643 reassociationGroup.push_back(dim);
646 while (dim < padRank && unitDims.contains(dim))
647 reassociationGroup.push_back(dim++);
648 reassociationMap.push_back(reassociationGroup);
649 reassociationGroup.clear();
652 Value collapsedSource =
653 collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
654 reassociationMap,
options.rankReductionStrategy);
656 auto newPadOp = rewriter.
create<tensor::PadOp>(
657 padOp.getLoc(),
Type(), collapsedSource, newLowPad,
658 newHighPad, paddingVal, padOp.getNofold());
660 Value dest = padOp.getResult();
661 if (
options.rankReductionStrategy ==
664 int64_t numUnitDims = 0;
665 for (
auto dim : llvm::seq(
static_cast<int64_t
>(0), padRank)) {
666 if (unitDims.contains(dim)) {
672 rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
674 dest = rewriter.
create<tensor::EmptyOp>(
675 padOp.getLoc(), expandedSizes,
676 padOp.getResultType().getElementType());
679 Value expandedValue =
680 expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
681 reassociationMap,
options.rankReductionStrategy);
682 rewriter.
replaceOp(padOp, expandedValue);
693 struct RankReducedExtractSliceOp
697 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
699 RankedTensorType resultType = sliceOp.getType();
701 for (
auto size : resultType.getShape())
704 if (!reassociation ||
705 reassociation->size() ==
static_cast<size_t>(resultType.getRank()))
711 auto rankReducedType = cast<RankedTensorType>(
712 tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
713 reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
717 Value newSlice = rewriter.
create<tensor::ExtractSliceOp>(
718 loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
720 sliceOp, resultType, newSlice, *reassociation);
727 template <
typename InsertOpTy>
731 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
733 RankedTensorType sourceType = insertSliceOp.getSourceType();
735 for (
auto size : sourceType.getShape())
738 if (!reassociation ||
739 reassociation->size() ==
static_cast<size_t>(sourceType.getRank()))
742 Location loc = insertSliceOp.getLoc();
743 tensor::CollapseShapeOp reshapedSource;
749 if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
751 reshapedSource = rewriter.
create<tensor::CollapseShapeOp>(
752 loc, insertSliceOp.getSource(), *reassociation);
755 insertSliceOp, reshapedSource, insertSliceOp.getDest(),
756 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
757 insertSliceOp.getMixedStrides());
768 auto *context =
patterns.getContext();
772 patterns.add<RankReducedExtractSliceOp,
773 RankReducedInsertSliceOp<tensor::InsertSliceOp>,
774 RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
776 linalg::FillOp::getCanonicalizationPatterns(
patterns, context);
777 tensor::CollapseShapeOp::getCanonicalizationPatterns(
patterns, context);
778 tensor::EmptyOp::getCanonicalizationPatterns(
patterns, context);
779 tensor::ExpandShapeOp::getCanonicalizationPatterns(
patterns, context);
788 auto *context =
patterns.getContext();
792 linalg::FillOp::getCanonicalizationPatterns(
patterns, context);
793 tensor::EmptyOp::getCanonicalizationPatterns(
patterns, context);
801 if (
options.rankReductionStrategy ==
804 }
else if (
options.rankReductionStrategy ==
806 ReassociativeReshape) {
818 struct LinalgFoldUnitExtentDimsPass
819 :
public impl::LinalgFoldUnitExtentDimsPassBase<
820 LinalgFoldUnitExtentDimsPass> {
821 using impl::LinalgFoldUnitExtentDimsPassBase<
822 LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
823 void runOnOperation()
override {
828 if (useRankReducingSlices) {
845 getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
847 bool lastDim = pos == rank - 1;
849 for (int64_t i = 0; i < rank - 1; i++) {
850 if (i == pos || (lastDim && i == pos - 1))
858 return reassociation;
867 auto valType = cast<ShapedType>(val.
getType());
869 collapsedShape.erase(collapsedShape.begin() + pos);
871 rewriter, val.
getLoc(), val, collapsedShape,
872 getReassociationForReshapeAtDim(valType.getRank(), pos),
882 template <
typename FromOpTy,
typename ToOpTy>
890 assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
891 "expected 3 operands and dims");
892 return llvm::map_to_vector(
893 llvm::zip(operands, operandCollapseDims), [&](
auto pair) {
894 return collapseSingletonDimAt(rewriter, std::get<0>(pair),
901 RankedTensorType expandedType, int64_t dim)
const {
902 return rewriter.
create<tensor::ExpandShapeOp>(
903 result.
getLoc(), expandedType, result,
904 getReassociationForReshapeAtDim(expandedType.getRank(), dim));
907 LogicalResult matchAndRewrite(FromOpTy contractionOp,
909 if (contractionOp.hasUserDefinedMaps()) {
911 contractionOp,
"ops with user-defined maps are not supported");
914 auto loc = contractionOp.getLoc();
915 auto inputs = contractionOp.getDpsInputs();
916 auto inits = contractionOp.getDpsInits();
917 if (inputs.size() != 2 || inits.size() != 1)
919 "expected 2 inputs and 1 init");
920 auto lhs = inputs[0];
921 auto rhs = inputs[1];
922 auto init = inits[0];
926 if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
928 "no reducable dims found");
931 collapseOperands(rewriter, operands, operandUnitDims);
932 Value collapsedLhs = collapsedOperands[0];
933 Value collapsedRhs = collapsedOperands[1];
934 Value collapsedInit = collapsedOperands[2];
936 if (isa<RankedTensorType>(collapsedInit.
getType()))
937 collapsedResultTy.push_back(collapsedInit.
getType());
938 auto collapsedOp = rewriter.
create<ToOpTy>(
939 loc, collapsedResultTy,
ValueRange{collapsedLhs, collapsedRhs},
941 for (
auto attr : contractionOp->getAttrs()) {
942 if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
943 attr.getName() ==
"indexing_maps")
945 collapsedOp->
setAttr(attr.getName(), attr.getValue());
948 auto results = contractionOp.getResults();
949 assert(results.size() < 2 &&
"expected at most one result");
950 if (results.empty()) {
951 rewriter.
replaceOp(contractionOp, collapsedOp);
955 expandResult(rewriter, collapsedOp.getResultTensors()[0],
956 cast<RankedTensorType>(results[0].getType()),
957 operandUnitDims[2]));
966 virtual LogicalResult
967 getOperandUnitDims(LinalgOp op,
972 template <
typename FromOpTy,
typename ToOpTy>
973 struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
974 using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
978 getOperandUnitDims(LinalgOp op,
980 FailureOr<ContractionDimensions> maybeContractionDims =
982 if (failed(maybeContractionDims)) {
983 LLVM_DEBUG(llvm::dbgs() <<
"could not infer contraction dims");
988 if (contractionDims.
batch.size() != 1)
990 auto batchDim = contractionDims.
batch[0];
992 op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
993 if (bOperands.size() != 3 || llvm::any_of(bOperands, [](
auto pair) {
994 return cast<ShapedType>(std::get<0>(pair).getType())
995 .getShape()[std::get<1>(pair)] != 1;
997 LLVM_DEBUG(llvm::dbgs() <<
"specified unit dims not found");
1002 std::get<1>(bOperands[1]),
1003 std::get<1>(bOperands[2])};
1009 template <
typename FromOpTy,
typename ToOpTy>
1010 struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
1011 using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
1014 static bool constexpr reduceLeft =
1015 (std::is_same_v<FromOpTy, BatchMatmulOp> &&
1016 std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1017 (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
1018 std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1019 (std::is_same_v<FromOpTy, MatmulOp> &&
1020 std::is_same_v<ToOpTy, VecmatOp>) ||
1021 (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
1022 std::is_same_v<ToOpTy, VecmatOp>) ||
1023 (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
1027 getOperandUnitDims(LinalgOp op,
1029 FailureOr<ContractionDimensions> maybeContractionDims =
1031 if (failed(maybeContractionDims)) {
1032 LLVM_DEBUG(llvm::dbgs() <<
"could not infer contraction dims");
1037 if constexpr (reduceLeft) {
1038 auto m = contractionDims.
m[0];
1040 op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
1041 if (mOperands.size() != 2)
1043 if (llvm::all_of(mOperands, [](
auto pair) {
1044 return cast<ShapedType>(std::get<0>(pair).
getType())
1045 .getShape()[std::get<1>(pair)] == 1;
1048 std::get<1>(mOperands[1])};
1052 auto n = contractionDims.
n[0];
1054 op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
1055 if (nOperands.size() != 2)
1057 if (llvm::all_of(nOperands, [](
auto pair) {
1058 return cast<ShapedType>(std::get<0>(pair).
getType())
1059 .getShape()[std::get<1>(pair)] == 1;
1062 std::get<1>(nOperands[1])};
1066 LLVM_DEBUG(llvm::dbgs() <<
"specified unit dims not found");
1077 patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
1079 .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
1082 .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
1084 patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
1085 patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
1088 patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
1089 patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
1090 patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
1091 patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
1093 patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
1094 patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
1095 patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
1097 patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
1101 patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
1102 patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
static Value expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest, ArrayRef< ReassociationIndices > reassociation, ControlDropUnitDims::RankReductionStrategy rankReductionStrategy)
Expand the given value so that the type matches the type of origDest.
static void replaceUnitDimIndexOps(GenericOp genericOp, const llvm::SmallDenseSet< unsigned > &unitDims, RewriterBase &rewriter)
Implements a pass that canonicalizes the uses of unit-extent dimensions for broadcasting.
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(MLIRContext *context, GenericOp genericOp, OpOperand *opOperand, llvm::SmallDenseMap< unsigned, unsigned > &oldDimsToNewDimsMap, ArrayRef< AffineExpr > dimReplacements)
static void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns that are used to canonicalize the use of unit-extent dims for broadcasting.
static void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
static Value collapseValue(RewriterBase &rewriter, Location loc, Value operand, ArrayRef< int64_t > targetShape, ArrayRef< ReassociationIndices > reassociation, ControlDropUnitDims::RankReductionStrategy rankReductionStrategy)
Collapse the given value so that the type matches the type of origOutput.
static llvm::ManagedStatic< PassManagerOptions > options
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
ArrayRef< AffineExpr > getResults() const
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns)
A pattern that converts init operands to input operands.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns)
Adds patterns that reduce the rank of named contraction ops that have unit dimensions in the operand(...
FailureOr< DropUnitDimsResult > dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, const ControlDropUnitDims &options)
SmallVector< NamedAttribute > getPrunedAttributeList(OpTy op)
Returns an attribute list that excludes pre-defined attributes.
std::optional< SmallVector< ReassociationIndices > > getReassociationMapForFoldingUnitDims(ArrayRef< OpFoldResult > mixedSizes)
Get the reassociation maps to fold the result of an extract_slice (or source of an insert_slice) operat...
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with its consumers.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
const FrozenRewritePatternSet & patterns
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Compute the modified metadata for an operands of operation whose unit dims are being dropped.
SmallVector< ReassociationIndices > reassociation
SmallVector< int64_t > targetShape
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > m
SmallVector< unsigned, 2 > n
Transformation to drop unit-extent dimensions from linalg.generic operations.