32 #include "llvm/ADT/SetVector.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
37 #define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
38 #include "mlir/Dialect/Linalg/Passes.h.inc"
41 #define DEBUG_TYPE "linalg-drop-unit-dims"
85 LogicalResult matchAndRewrite(GenericOp genericOp,
87 if (!genericOp.hasPureTensorSemantics())
89 if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
92 auto outputOperands = genericOp.getDpsInitsMutable();
95 if (genericOp.getMatchingBlockArgument(&op).use_empty())
97 candidates.insert(&op);
100 if (candidates.empty())
104 int64_t origNumInput = genericOp.getNumDpsInputs();
108 newIndexingMaps.append(indexingMaps.begin(),
109 std::next(indexingMaps.begin(), origNumInput));
111 newInputOperands.push_back(op->get());
112 newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
114 newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
119 llvm::to_vector(genericOp.getDpsInits());
123 auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
124 auto empty = rewriter.
create<tensor::EmptyOp>(
127 unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
128 newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
131 auto newOp = rewriter.
create<GenericOp>(
132 loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
133 newIndexingMaps, genericOp.getIteratorTypesArray(),
137 Region ®ion = newOp.getRegion();
140 for (
auto bbarg : genericOp.getRegionInputArgs())
144 BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
149 BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
150 if (candidates.count(&op))
156 for (
auto &op : genericOp.getBody()->getOperations()) {
157 rewriter.
clone(op, mapper);
159 rewriter.
replaceOp(genericOp, newOp.getResults());
234 const llvm::SmallDenseSet<unsigned> &unitDims,
236 for (IndexOp indexOp :
237 llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
240 if (unitDims.count(indexOp.getDim()) != 0) {
244 unsigned droppedDims = llvm::count_if(
245 unitDims, [&](
unsigned dim) {
return dim < indexOp.getDim(); });
246 if (droppedDims != 0)
248 indexOp.getDim() - droppedDims);
261 auto origResultType = cast<RankedTensorType>(origDest.
getType());
262 if (rankReductionStrategy ==
264 unsigned rank = origResultType.getRank();
270 loc, result, origDest, offsets, sizes, strides);
273 assert(rankReductionStrategy ==
275 "unknown rank reduction strategy");
277 .
create<tensor::ExpandShapeOp>(loc, origResultType, result, reassociation)
288 if (
auto memrefType = dyn_cast<MemRefType>(operand.
getType())) {
289 if (rankReductionStrategy ==
291 FailureOr<Value> rankReducingExtract =
292 memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
294 assert(succeeded(rankReducingExtract) &&
"not a unit-extent collapse");
295 return *rankReducingExtract;
299 rankReductionStrategy ==
301 "unknown rank reduction strategy");
302 MemRefLayoutAttrInterface layout;
303 auto targetType =
MemRefType::get(targetShape, memrefType.getElementType(),
304 layout, memrefType.getMemorySpace());
305 return rewriter.
create<memref::CollapseShapeOp>(loc, targetType, operand,
308 if (
auto tensorType = dyn_cast<RankedTensorType>(operand.
getType())) {
309 if (rankReductionStrategy ==
311 FailureOr<Value> rankReducingExtract =
312 tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
314 assert(succeeded(rankReducingExtract) &&
"not a unit-extent collapse");
315 return *rankReducingExtract;
319 rankReductionStrategy ==
321 "unknown rank reduction strategy");
324 return rewriter.
create<tensor::CollapseShapeOp>(loc, targetType, operand,
327 llvm_unreachable(
"unsupported operand type");
342 llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
347 AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
351 auto isUnitDim = [&](
unsigned dim) {
352 if (
auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
353 unsigned oldPosition = dimExpr.getPosition();
354 return !oldDimsToNewDimsMap.count(oldPosition) &&
355 (operandShape[dim] == 1);
359 if (operandShape[dim] == 1) {
360 auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
361 return constAffineExpr && constAffineExpr.getValue() == 0;
367 while (dim < operandShape.size() && isUnitDim(dim))
368 reassociationGroup.push_back(dim++);
369 while (dim < operandShape.size()) {
370 assert(!isUnitDim(dim) &&
"expected non unit-extent");
371 reassociationGroup.push_back(dim);
372 AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
373 newIndexExprs.push_back(newExpr);
377 while (dim < operandShape.size() && isUnitDim(dim)) {
378 reassociationGroup.push_back(dim++);
381 reassociationGroup.clear();
385 newIndexExprs, context);
392 if (indexingMaps.empty())
401 "invalid indexing maps for operation");
407 if (allowedUnitDims.empty()) {
409 genericOp,
"control function returns no allowed unit dims to prune");
411 llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
412 allowedUnitDims.end());
413 llvm::SmallDenseSet<unsigned> unitDims;
415 if (
AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
416 if (dims[dimExpr.getPosition()] == 1 &&
417 unitDimsFilter.count(expr.index()))
418 unitDims.insert(expr.index());
425 llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
427 unsigned newDims = 0;
428 for (
auto [index, attr] :
430 if (unitDims.count(index)) {
431 dimReplacements.push_back(
434 newIteratorTypes.push_back(attr);
435 oldDimToNewDimMap[index] = newDims;
436 dimReplacements.push_back(
457 auto hasCollapsibleType = [](
OpOperand &operand) {
458 Type operandType = operand.get().getType();
459 if (
auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
460 return memrefOperandType.getLayout().isIdentity();
462 if (
auto tensorOperandType = dyn_cast<RankedTensorType>(operandType)) {
463 return tensorOperandType.getEncoding() ==
nullptr;
467 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
468 auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
470 if (!hasCollapsibleType(opOperand)) {
473 newIndexingMaps.push_back(newIndexingMap);
474 targetShapes.push_back(llvm::to_vector(shape));
475 collapsed.push_back(
false);
476 reassociations.push_back({});
480 rewriter.
getContext(), genericOp, &opOperand, oldDimToNewDimMap,
482 reassociations.push_back(replacementInfo.reassociation);
483 newIndexingMaps.push_back(replacementInfo.indexMap);
484 targetShapes.push_back(replacementInfo.targetShape);
485 collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
486 indexingMap.getNumResults()));
491 if (newIndexingMaps == indexingMaps ||
501 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
502 int64_t idx = opOperand.getOperandNumber();
503 if (!collapsed[idx]) {
504 newOperands.push_back(opOperand.get());
507 newOperands.push_back(
collapseValue(rewriter, loc, opOperand.get(),
508 targetShapes[idx], reassociations[idx],
509 options.rankReductionStrategy));
519 resultTypes.reserve(genericOp.getNumResults());
520 for (
unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
521 resultTypes.push_back(newOutputs[i].
getType());
522 GenericOp replacementOp =
523 rewriter.
create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
524 newIndexingMaps, newIteratorTypes);
526 replacementOp.getRegion().begin());
535 for (
auto [index, result] :
llvm::enumerate(replacementOp.getResults())) {
536 unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
537 Value origDest = genericOp.getDpsInitOperand(index)->get();
538 if (!collapsed[opOperandIndex]) {
539 resultReplacements.push_back(result);
543 reassociations[opOperandIndex],
544 options.rankReductionStrategy);
545 resultReplacements.push_back(expandedValue);
548 rewriter.
replaceOp(genericOp, resultReplacements);
558 LogicalResult matchAndRewrite(GenericOp genericOp,
578 LogicalResult matchAndRewrite(tensor::PadOp padOp,
582 if (allowedUnitDims.empty()) {
584 padOp,
"control function returns no allowed unit dims to prune");
587 if (padOp.getSourceType().getEncoding()) {
589 padOp,
"cannot collapse dims of tensor with encoding");
596 Value paddingVal = padOp.getConstantPaddingValue();
599 padOp,
"unimplemented: non-constant padding value");
603 int64_t padRank = sourceShape.size();
607 return maybeInt && *maybeInt == 0;
610 llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
611 allowedUnitDims.end());
612 llvm::SmallDenseSet<unsigned> unitDims;
616 for (
const auto [dim, size, low, high] :
617 zip_equal(llvm::seq(
static_cast<int64_t
>(0), padRank), sourceShape,
618 padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
619 if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
620 isStaticZero(high)) {
621 unitDims.insert(dim);
623 newShape.push_back(size);
624 newLowPad.push_back(low);
625 newHighPad.push_back(high);
629 if (unitDims.empty()) {
636 while (dim < padRank && unitDims.contains(dim))
637 reassociationGroup.push_back(dim++);
638 while (dim < padRank) {
639 assert(!unitDims.contains(dim) &&
"expected non unit-extent");
640 reassociationGroup.push_back(dim);
643 while (dim < padRank && unitDims.contains(dim))
644 reassociationGroup.push_back(dim++);
645 reassociationMap.push_back(reassociationGroup);
646 reassociationGroup.clear();
649 Value collapsedSource =
650 collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
651 reassociationMap,
options.rankReductionStrategy);
653 auto newPadOp = rewriter.
create<tensor::PadOp>(
654 padOp.getLoc(),
Type(), collapsedSource, newLowPad,
655 newHighPad, paddingVal, padOp.getNofold());
657 Value dest = padOp.getResult();
658 if (
options.rankReductionStrategy ==
661 int64_t numUnitDims = 0;
662 for (
auto dim : llvm::seq(
static_cast<int64_t
>(0), padRank)) {
663 if (unitDims.contains(dim)) {
669 rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
671 dest = rewriter.
create<tensor::EmptyOp>(
672 padOp.getLoc(), expandedSizes,
673 padOp.getResultType().getElementType());
676 Value expandedValue =
677 expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
678 reassociationMap,
options.rankReductionStrategy);
679 rewriter.
replaceOp(padOp, expandedValue);
690 struct RankReducedExtractSliceOp
694 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
696 RankedTensorType resultType = sliceOp.getType();
698 for (
auto size : resultType.getShape())
701 if (!reassociation ||
702 reassociation->size() ==
static_cast<size_t>(resultType.getRank()))
708 auto rankReducedType = cast<RankedTensorType>(
709 tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
710 reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
714 Value newSlice = rewriter.
create<tensor::ExtractSliceOp>(
715 loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
717 sliceOp, resultType, newSlice, *reassociation);
724 template <
typename InsertOpTy>
728 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
730 RankedTensorType sourceType = insertSliceOp.getSourceType();
732 for (
auto size : sourceType.getShape())
735 if (!reassociation ||
736 reassociation->size() ==
static_cast<size_t>(sourceType.getRank()))
739 Location loc = insertSliceOp.getLoc();
740 tensor::CollapseShapeOp reshapedSource;
746 if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
748 reshapedSource = rewriter.
create<tensor::CollapseShapeOp>(
749 loc, insertSliceOp.getSource(), *reassociation);
752 insertSliceOp, reshapedSource, insertSliceOp.getDest(),
753 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
754 insertSliceOp.getMixedStrides());
767 patterns.
add<DropPadUnitDims>(context,
options);
769 patterns.
add<RankReducedExtractSliceOp,
770 RankReducedInsertSliceOp<tensor::InsertSliceOp>,
771 RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
773 linalg::FillOp::getCanonicalizationPatterns(patterns, context);
774 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
775 tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
776 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
786 options.rankReductionStrategy =
789 patterns.
add<DropPadUnitDims>(context,
options);
791 linalg::FillOp::getCanonicalizationPatterns(patterns, context);
792 tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
800 if (
options.rankReductionStrategy ==
803 }
else if (
options.rankReductionStrategy ==
805 ReassociativeReshape) {
812 patterns.
add<MoveInitOperandsToInput>(patterns.
getContext());
817 struct LinalgFoldUnitExtentDimsPass
818 :
public impl::LinalgFoldUnitExtentDimsPassBase<
819 LinalgFoldUnitExtentDimsPass> {
820 using impl::LinalgFoldUnitExtentDimsPassBase<
821 LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
822 void runOnOperation()
override {
827 if (useRankReducingSlices) {
844 getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
846 bool lastDim = pos == rank - 1;
848 for (int64_t i = 0; i < rank - 1; i++) {
849 if (i == pos || (lastDim && i == pos - 1))
857 return reassociation;
866 auto valType = cast<ShapedType>(val.
getType());
868 collapsedShape.erase(collapsedShape.begin() + pos);
870 rewriter, val.
getLoc(), val, collapsedShape,
871 getReassociationForReshapeAtDim(valType.getRank(), pos),
881 template <
typename FromOpTy,
typename ToOpTy>
889 assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
890 "expected 3 operands and dims");
891 return llvm::map_to_vector(
892 llvm::zip(operands, operandCollapseDims), [&](
auto pair) {
893 return collapseSingletonDimAt(rewriter, std::get<0>(pair),
900 RankedTensorType expandedType, int64_t dim)
const {
901 return rewriter.
create<tensor::ExpandShapeOp>(
902 result.
getLoc(), expandedType, result,
903 getReassociationForReshapeAtDim(expandedType.getRank(), dim));
906 LogicalResult matchAndRewrite(FromOpTy contractionOp,
909 auto loc = contractionOp.
getLoc();
910 auto inputs = contractionOp.getDpsInputs();
911 auto inits = contractionOp.getDpsInits();
912 if (inputs.size() != 2 || inits.size() != 1)
914 "expected 2 inputs and 1 init");
915 auto lhs = inputs[0];
916 auto rhs = inputs[1];
917 auto init = inits[0];
921 if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
923 "no reducable dims found");
926 collapseOperands(rewriter, operands, operandUnitDims);
927 Value collapsedLhs = collapsedOperands[0];
928 Value collapsedRhs = collapsedOperands[1];
929 Value collapsedInit = collapsedOperands[2];
931 if (isa<RankedTensorType>(collapsedInit.
getType()))
932 collapsedResultTy.push_back(collapsedInit.
getType());
933 auto collapsedOp = rewriter.
create<ToOpTy>(
934 loc, collapsedResultTy,
ValueRange{collapsedLhs, collapsedRhs},
936 for (
auto attr : contractionOp->getAttrs()) {
937 if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
939 collapsedOp->
setAttr(attr.getName(), attr.getValue());
942 auto results = contractionOp.getResults();
943 assert(results.size() < 2 &&
"expected at most one result");
944 if (results.empty()) {
945 rewriter.
replaceOp(contractionOp, collapsedOp);
949 expandResult(rewriter, collapsedOp.getResultTensors()[0],
950 cast<RankedTensorType>(results[0].getType()),
951 operandUnitDims[2]));
960 virtual LogicalResult
961 getOperandUnitDims(LinalgOp op,
966 template <
typename FromOpTy,
typename ToOpTy>
967 struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
968 using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
972 getOperandUnitDims(LinalgOp op,
974 FailureOr<ContractionDimensions> maybeContractionDims =
976 if (failed(maybeContractionDims)) {
977 LLVM_DEBUG(llvm::dbgs() <<
"could not infer contraction dims");
982 if (contractionDims.
batch.size() != 1)
984 auto batchDim = contractionDims.
batch[0];
986 op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
987 if (bOperands.size() != 3 || llvm::any_of(bOperands, [](
auto pair) {
988 return cast<ShapedType>(std::get<0>(pair).getType())
989 .getShape()[std::get<1>(pair)] != 1;
991 LLVM_DEBUG(llvm::dbgs() <<
"specified unit dims not found");
996 std::get<1>(bOperands[1]),
997 std::get<1>(bOperands[2])};
1003 template <
typename FromOpTy,
typename ToOpTy>
1004 struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
1005 using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
1008 static bool constexpr reduceLeft =
1009 (std::is_same_v<FromOpTy, BatchMatmulOp> &&
1010 std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1011 (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
1012 std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1013 (std::is_same_v<FromOpTy, MatmulOp> &&
1014 std::is_same_v<ToOpTy, VecmatOp>) ||
1015 (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
1016 std::is_same_v<ToOpTy, VecmatOp>) ||
1017 (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
1021 getOperandUnitDims(LinalgOp op,
1023 FailureOr<ContractionDimensions> maybeContractionDims =
1025 if (failed(maybeContractionDims)) {
1026 LLVM_DEBUG(llvm::dbgs() <<
"could not infer contraction dims");
1031 if constexpr (reduceLeft) {
1032 auto m = contractionDims.
m[0];
1034 op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
1035 if (mOperands.size() != 2)
1037 if (llvm::all_of(mOperands, [](
auto pair) {
1038 return cast<ShapedType>(std::get<0>(pair).
getType())
1039 .getShape()[std::get<1>(pair)] == 1;
1042 std::get<1>(mOperands[1])};
1046 auto n = contractionDims.
n[0];
1048 op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
1049 if (nOperands.size() != 2)
1051 if (llvm::all_of(nOperands, [](
auto pair) {
1052 return cast<ShapedType>(std::get<0>(pair).
getType())
1053 .getShape()[std::get<1>(pair)] == 1;
1056 std::get<1>(nOperands[1])};
1060 LLVM_DEBUG(llvm::dbgs() <<
"specified unit dims not found");
1071 patterns.
add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
1073 .
add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
1076 .
add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
1078 patterns.
add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
1079 patterns.
add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
1082 patterns.
add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
1083 patterns.
add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
1084 patterns.
add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
1085 patterns.
add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
1087 patterns.
add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
1088 patterns.
add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
1089 patterns.
add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
1091 patterns.
add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
1095 patterns.
add<RankReduceMatmul<MatvecOp, DotOp>>(context);
1096 patterns.
add<RankReduceMatmul<VecmatOp, DotOp>>(context);
static Value expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest, ArrayRef< ReassociationIndices > reassociation, ControlDropUnitDims::RankReductionStrategy rankReductionStrategy)
Expand the given value so that the type matches the type of origDest.
static void replaceUnitDimIndexOps(GenericOp genericOp, const llvm::SmallDenseSet< unsigned > &unitDims, RewriterBase &rewriter)
Implements a pass that canonicalizes the uses of unit-extent dimensions for broadcasting.
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(MLIRContext *context, GenericOp genericOp, OpOperand *opOperand, llvm::SmallDenseMap< unsigned, unsigned > &oldDimsToNewDimsMap, ArrayRef< AffineExpr > dimReplacements)
static void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns that are used to canonicalize the use of unit-extent dims for broadcasting.
static void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
static Value collapseValue(RewriterBase &rewriter, Location loc, Value operand, ArrayRef< int64_t > targetShape, ArrayRef< ReassociationIndices > reassociation, ControlDropUnitDims::RankReductionStrategy rankReductionStrategy)
Collapse the given value so that the type matches the type of origOutput.
static llvm::ManagedStatic< PassManagerOptions > options
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
ArrayRef< AffineExpr > getResults() const
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns)
A pattern that converts init operands to input operands.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns)
Adds patterns that reduce the rank of named contraction ops that have unit dimensions in the operand(...
SmallVector< NamedAttribute > getPrunedAttributeList(OpTy op)
Returns an attribute list that excludes pre-defined attributes.
std::optional< SmallVector< ReassociationIndices > > getReassociationMapForFoldingUnitDims(ArrayRef< OpFoldResult > mixedSizes)
Get the reassociation maps to fold the result of a extract_slice (or source of a insert_slice) operat...
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, const ControlDropUnitDims &options)
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with its consumers.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Compute the modified metadata for an operands of operation whose unit dims are being dropped.
SmallVector< ReassociationIndices > reassociation
SmallVector< int64_t > targetShape
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > m
SmallVector< unsigned, 2 > n
Transformation to drop unit-extent dimensions from linalg.generic operations.