32 #include "llvm/ADT/SetVector.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
37 #define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
38 #include "mlir/Dialect/Linalg/Passes.h.inc"
41 #define DEBUG_TYPE "linalg-drop-unit-dims"
85 LogicalResult matchAndRewrite(GenericOp genericOp,
87 if (!genericOp.hasPureTensorSemantics())
89 if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
92 auto outputOperands = genericOp.getDpsInitsMutable();
95 if (genericOp.getMatchingBlockArgument(&op).use_empty())
97 candidates.insert(&op);
100 if (candidates.empty())
104 int64_t origNumInput = genericOp.getNumDpsInputs();
108 newIndexingMaps.append(indexingMaps.begin(),
109 std::next(indexingMaps.begin(), origNumInput));
111 newInputOperands.push_back(op->get());
112 newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
114 newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
119 llvm::to_vector(genericOp.getDpsInits());
123 auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
124 auto empty = rewriter.
create<tensor::EmptyOp>(
127 unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
128 newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
131 auto newOp = rewriter.
create<GenericOp>(
132 loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
133 newIndexingMaps, genericOp.getIteratorTypesArray(),
137 Region ®ion = newOp.getRegion();
140 for (
auto bbarg : genericOp.getRegionInputArgs())
144 BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
149 BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
150 if (candidates.count(&op))
156 for (
auto &op : genericOp.getBody()->getOperations()) {
157 rewriter.
clone(op, mapper);
159 rewriter.
replaceOp(genericOp, newOp.getResults());
230 const llvm::SmallDenseSet<unsigned> &unitDims,
232 for (IndexOp indexOp :
233 llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
236 if (unitDims.count(indexOp.getDim()) != 0) {
240 unsigned droppedDims = llvm::count_if(
241 unitDims, [&](
unsigned dim) {
return dim < indexOp.getDim(); });
242 if (droppedDims != 0)
244 indexOp.getDim() - droppedDims);
257 auto origResultType = cast<RankedTensorType>(origDest.
getType());
258 if (rankReductionStrategy ==
260 unsigned rank = origResultType.getRank();
266 loc, result, origDest, offsets, sizes, strides);
269 assert(rankReductionStrategy ==
271 "unknown rank reduction strategy");
273 .
create<tensor::ExpandShapeOp>(loc, origResultType, result, reassociation)
284 if (
auto memrefType = dyn_cast<MemRefType>(operand.
getType())) {
285 if (rankReductionStrategy ==
287 FailureOr<Value> rankReducingExtract =
288 memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
290 assert(succeeded(rankReducingExtract) &&
"not a unit-extent collapse");
291 return *rankReducingExtract;
295 rankReductionStrategy ==
297 "unknown rank reduction strategy");
298 MemRefLayoutAttrInterface layout;
299 auto targetType =
MemRefType::get(targetShape, memrefType.getElementType(),
300 layout, memrefType.getMemorySpace());
301 return rewriter.
create<memref::CollapseShapeOp>(loc, targetType, operand,
304 if (
auto tensorType = dyn_cast<RankedTensorType>(operand.
getType())) {
305 if (rankReductionStrategy ==
307 FailureOr<Value> rankReducingExtract =
308 tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
310 assert(succeeded(rankReducingExtract) &&
"not a unit-extent collapse");
311 return *rankReducingExtract;
315 rankReductionStrategy ==
317 "unknown rank reduction strategy");
320 return rewriter.
create<tensor::CollapseShapeOp>(loc, targetType, operand,
323 llvm_unreachable(
"unsupported operand type");
338 llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
343 AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
347 auto isUnitDim = [&](
unsigned dim) {
348 if (
auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
349 unsigned oldPosition = dimExpr.getPosition();
350 return !oldDimsToNewDimsMap.count(oldPosition) &&
351 (operandShape[dim] == 1);
355 if (operandShape[dim] == 1) {
356 auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
357 return constAffineExpr && constAffineExpr.getValue() == 0;
363 while (dim < operandShape.size() && isUnitDim(dim))
364 reassociationGroup.push_back(dim++);
365 while (dim < operandShape.size()) {
366 assert(!isUnitDim(dim) &&
"expected non unit-extent");
367 reassociationGroup.push_back(dim);
368 AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
369 newIndexExprs.push_back(newExpr);
373 while (dim < operandShape.size() && isUnitDim(dim)) {
374 reassociationGroup.push_back(dim++);
377 reassociationGroup.clear();
381 newIndexExprs, context);
385 FailureOr<DropUnitDimsResult>
389 if (indexingMaps.empty())
398 "invalid indexing maps for operation");
404 if (allowedUnitDims.empty()) {
406 genericOp,
"control function returns no allowed unit dims to prune");
408 llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
409 allowedUnitDims.end());
410 llvm::SmallDenseSet<unsigned> unitDims;
412 if (
AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
413 if (dims[dimExpr.getPosition()] == 1 &&
414 unitDimsFilter.count(expr.index()))
415 unitDims.insert(expr.index());
422 llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
424 unsigned newDims = 0;
425 for (
auto [index, attr] :
427 if (unitDims.count(index)) {
428 dimReplacements.push_back(
431 newIteratorTypes.push_back(attr);
432 oldDimToNewDimMap[index] = newDims;
433 dimReplacements.push_back(
454 auto hasCollapsibleType = [](
OpOperand &operand) {
455 Type operandType = operand.get().getType();
456 if (
auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
457 return memrefOperandType.getLayout().isIdentity();
459 if (
auto tensorOperandType = dyn_cast<RankedTensorType>(operandType)) {
460 return tensorOperandType.getEncoding() ==
nullptr;
464 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
465 auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
467 if (!hasCollapsibleType(opOperand)) {
470 newIndexingMaps.push_back(newIndexingMap);
471 targetShapes.push_back(llvm::to_vector(shape));
472 collapsed.push_back(
false);
473 reassociations.push_back({});
477 rewriter.
getContext(), genericOp, &opOperand, oldDimToNewDimMap,
479 reassociations.push_back(replacementInfo.reassociation);
480 newIndexingMaps.push_back(replacementInfo.indexMap);
481 targetShapes.push_back(replacementInfo.targetShape);
482 collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
483 indexingMap.getNumResults()));
488 if (newIndexingMaps == indexingMaps ||
498 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
499 int64_t idx = opOperand.getOperandNumber();
500 if (!collapsed[idx]) {
501 newOperands.push_back(opOperand.get());
504 newOperands.push_back(
collapseValue(rewriter, loc, opOperand.get(),
505 targetShapes[idx], reassociations[idx],
506 options.rankReductionStrategy));
516 resultTypes.reserve(genericOp.getNumResults());
517 for (
unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
518 resultTypes.push_back(newOutputs[i].
getType());
519 GenericOp replacementOp =
520 rewriter.
create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
521 newIndexingMaps, newIteratorTypes);
523 replacementOp.getRegion().begin());
531 for (
auto [index, result] :
llvm::enumerate(replacementOp.getResults())) {
532 unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
533 Value origDest = genericOp.getDpsInitOperand(index)->get();
534 if (!collapsed[opOperandIndex]) {
535 resultReplacements.push_back(result);
539 reassociations[opOperandIndex],
540 options.rankReductionStrategy);
541 resultReplacements.push_back(expandedValue);
553 LogicalResult matchAndRewrite(GenericOp genericOp,
555 FailureOr<DropUnitDimsResult> result =
557 if (failed(result)) {
560 rewriter.
replaceOp(genericOp, result->replacements);
579 LogicalResult matchAndRewrite(tensor::PadOp padOp,
583 if (allowedUnitDims.empty()) {
585 padOp,
"control function returns no allowed unit dims to prune");
588 if (padOp.getSourceType().getEncoding()) {
590 padOp,
"cannot collapse dims of tensor with encoding");
597 Value paddingVal = padOp.getConstantPaddingValue();
600 padOp,
"unimplemented: non-constant padding value");
604 int64_t padRank = sourceShape.size();
608 return maybeInt && *maybeInt == 0;
611 llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
612 allowedUnitDims.end());
613 llvm::SmallDenseSet<unsigned> unitDims;
617 for (
const auto [dim, size, low, high] :
618 zip_equal(llvm::seq(
static_cast<int64_t
>(0), padRank), sourceShape,
619 padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
620 if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
621 isStaticZero(high)) {
622 unitDims.insert(dim);
624 newShape.push_back(size);
625 newLowPad.push_back(low);
626 newHighPad.push_back(high);
630 if (unitDims.empty()) {
637 while (dim < padRank && unitDims.contains(dim))
638 reassociationGroup.push_back(dim++);
639 while (dim < padRank) {
640 assert(!unitDims.contains(dim) &&
"expected non unit-extent");
641 reassociationGroup.push_back(dim);
644 while (dim < padRank && unitDims.contains(dim))
645 reassociationGroup.push_back(dim++);
646 reassociationMap.push_back(reassociationGroup);
647 reassociationGroup.clear();
650 Value collapsedSource =
651 collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
652 reassociationMap,
options.rankReductionStrategy);
654 auto newPadOp = rewriter.
create<tensor::PadOp>(
655 padOp.getLoc(),
Type(), collapsedSource, newLowPad,
656 newHighPad, paddingVal, padOp.getNofold());
658 Value dest = padOp.getResult();
659 if (
options.rankReductionStrategy ==
662 int64_t numUnitDims = 0;
663 for (
auto dim : llvm::seq(
static_cast<int64_t
>(0), padRank)) {
664 if (unitDims.contains(dim)) {
670 rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
672 dest = rewriter.
create<tensor::EmptyOp>(
673 padOp.getLoc(), expandedSizes,
674 padOp.getResultType().getElementType());
677 Value expandedValue =
678 expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
679 reassociationMap,
options.rankReductionStrategy);
680 rewriter.
replaceOp(padOp, expandedValue);
691 struct RankReducedExtractSliceOp
695 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
697 RankedTensorType resultType = sliceOp.getType();
699 for (
auto size : resultType.getShape())
702 if (!reassociation ||
703 reassociation->size() ==
static_cast<size_t>(resultType.getRank()))
709 auto rankReducedType = cast<RankedTensorType>(
710 tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
711 reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
715 Value newSlice = rewriter.
create<tensor::ExtractSliceOp>(
716 loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
718 sliceOp, resultType, newSlice, *reassociation);
725 template <
typename InsertOpTy>
729 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
731 RankedTensorType sourceType = insertSliceOp.getSourceType();
733 for (
auto size : sourceType.getShape())
736 if (!reassociation ||
737 reassociation->size() ==
static_cast<size_t>(sourceType.getRank()))
740 Location loc = insertSliceOp.getLoc();
741 tensor::CollapseShapeOp reshapedSource;
747 if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
749 reshapedSource = rewriter.
create<tensor::CollapseShapeOp>(
750 loc, insertSliceOp.getSource(), *reassociation);
753 insertSliceOp, reshapedSource, insertSliceOp.getDest(),
754 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
755 insertSliceOp.getMixedStrides());
768 patterns.
add<DropPadUnitDims>(context,
options);
770 patterns.
add<RankReducedExtractSliceOp,
771 RankReducedInsertSliceOp<tensor::InsertSliceOp>,
772 RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
774 linalg::FillOp::getCanonicalizationPatterns(patterns, context);
775 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
776 tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
777 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
788 patterns.
add<DropPadUnitDims>(context,
options);
790 linalg::FillOp::getCanonicalizationPatterns(patterns, context);
791 tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
799 if (
options.rankReductionStrategy ==
802 }
else if (
options.rankReductionStrategy ==
804 ReassociativeReshape) {
811 patterns.
add<MoveInitOperandsToInput>(patterns.
getContext());
816 struct LinalgFoldUnitExtentDimsPass
817 :
public impl::LinalgFoldUnitExtentDimsPassBase<
818 LinalgFoldUnitExtentDimsPass> {
819 using impl::LinalgFoldUnitExtentDimsPassBase<
820 LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
821 void runOnOperation()
override {
826 if (useRankReducingSlices) {
843 getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
845 bool lastDim = pos == rank - 1;
847 for (int64_t i = 0; i < rank - 1; i++) {
848 if (i == pos || (lastDim && i == pos - 1))
856 return reassociation;
865 auto valType = cast<ShapedType>(val.
getType());
867 collapsedShape.erase(collapsedShape.begin() + pos);
869 rewriter, val.
getLoc(), val, collapsedShape,
870 getReassociationForReshapeAtDim(valType.getRank(), pos),
880 template <
typename FromOpTy,
typename ToOpTy>
888 assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
889 "expected 3 operands and dims");
890 return llvm::map_to_vector(
891 llvm::zip(operands, operandCollapseDims), [&](
auto pair) {
892 return collapseSingletonDimAt(rewriter, std::get<0>(pair),
899 RankedTensorType expandedType, int64_t dim)
const {
900 return rewriter.
create<tensor::ExpandShapeOp>(
901 result.
getLoc(), expandedType, result,
902 getReassociationForReshapeAtDim(expandedType.getRank(), dim));
905 LogicalResult matchAndRewrite(FromOpTy contractionOp,
908 auto loc = contractionOp.
getLoc();
909 auto inputs = contractionOp.getDpsInputs();
910 auto inits = contractionOp.getDpsInits();
911 if (inputs.size() != 2 || inits.size() != 1)
913 "expected 2 inputs and 1 init");
914 auto lhs = inputs[0];
915 auto rhs = inputs[1];
916 auto init = inits[0];
920 if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
922 "no reducable dims found");
925 collapseOperands(rewriter, operands, operandUnitDims);
926 Value collapsedLhs = collapsedOperands[0];
927 Value collapsedRhs = collapsedOperands[1];
928 Value collapsedInit = collapsedOperands[2];
930 if (isa<RankedTensorType>(collapsedInit.
getType()))
931 collapsedResultTy.push_back(collapsedInit.
getType());
932 auto collapsedOp = rewriter.
create<ToOpTy>(
933 loc, collapsedResultTy,
ValueRange{collapsedLhs, collapsedRhs},
935 for (
auto attr : contractionOp->getAttrs()) {
936 if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
938 collapsedOp->
setAttr(attr.getName(), attr.getValue());
941 auto results = contractionOp.getResults();
942 assert(results.size() < 2 &&
"expected at most one result");
943 if (results.empty()) {
944 rewriter.
replaceOp(contractionOp, collapsedOp);
948 expandResult(rewriter, collapsedOp.getResultTensors()[0],
949 cast<RankedTensorType>(results[0].getType()),
950 operandUnitDims[2]));
959 virtual LogicalResult
960 getOperandUnitDims(LinalgOp op,
965 template <
typename FromOpTy,
typename ToOpTy>
966 struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
967 using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
971 getOperandUnitDims(LinalgOp op,
973 FailureOr<ContractionDimensions> maybeContractionDims =
975 if (failed(maybeContractionDims)) {
976 LLVM_DEBUG(llvm::dbgs() <<
"could not infer contraction dims");
981 if (contractionDims.
batch.size() != 1)
983 auto batchDim = contractionDims.
batch[0];
985 op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
986 if (bOperands.size() != 3 || llvm::any_of(bOperands, [](
auto pair) {
987 return cast<ShapedType>(std::get<0>(pair).getType())
988 .getShape()[std::get<1>(pair)] != 1;
990 LLVM_DEBUG(llvm::dbgs() <<
"specified unit dims not found");
995 std::get<1>(bOperands[1]),
996 std::get<1>(bOperands[2])};
1002 template <
typename FromOpTy,
typename ToOpTy>
1003 struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
1004 using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
1007 static bool constexpr reduceLeft =
1008 (std::is_same_v<FromOpTy, BatchMatmulOp> &&
1009 std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1010 (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
1011 std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1012 (std::is_same_v<FromOpTy, MatmulOp> &&
1013 std::is_same_v<ToOpTy, VecmatOp>) ||
1014 (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
1015 std::is_same_v<ToOpTy, VecmatOp>) ||
1016 (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
1020 getOperandUnitDims(LinalgOp op,
1022 FailureOr<ContractionDimensions> maybeContractionDims =
1024 if (failed(maybeContractionDims)) {
1025 LLVM_DEBUG(llvm::dbgs() <<
"could not infer contraction dims");
1030 if constexpr (reduceLeft) {
1031 auto m = contractionDims.
m[0];
1033 op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
1034 if (mOperands.size() != 2)
1036 if (llvm::all_of(mOperands, [](
auto pair) {
1037 return cast<ShapedType>(std::get<0>(pair).
getType())
1038 .getShape()[std::get<1>(pair)] == 1;
1041 std::get<1>(mOperands[1])};
1045 auto n = contractionDims.
n[0];
1047 op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
1048 if (nOperands.size() != 2)
1050 if (llvm::all_of(nOperands, [](
auto pair) {
1051 return cast<ShapedType>(std::get<0>(pair).
getType())
1052 .getShape()[std::get<1>(pair)] == 1;
1055 std::get<1>(nOperands[1])};
1059 LLVM_DEBUG(llvm::dbgs() <<
"specified unit dims not found");
1070 patterns.
add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
1072 .
add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
1075 .
add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
1077 patterns.
add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
1078 patterns.
add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
1081 patterns.
add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
1082 patterns.
add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
1083 patterns.
add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
1084 patterns.
add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
1086 patterns.
add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
1087 patterns.
add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
1088 patterns.
add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
1090 patterns.
add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
1094 patterns.
add<RankReduceMatmul<MatvecOp, DotOp>>(context);
1095 patterns.
add<RankReduceMatmul<VecmatOp, DotOp>>(context);
static Value expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest, ArrayRef< ReassociationIndices > reassociation, ControlDropUnitDims::RankReductionStrategy rankReductionStrategy)
Expand the given value so that the type matches the type of origDest.
static void replaceUnitDimIndexOps(GenericOp genericOp, const llvm::SmallDenseSet< unsigned > &unitDims, RewriterBase &rewriter)
Implements a pass that canonicalizes the uses of unit-extent dimensions for broadcasting.
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(MLIRContext *context, GenericOp genericOp, OpOperand *opOperand, llvm::SmallDenseMap< unsigned, unsigned > &oldDimsToNewDimsMap, ArrayRef< AffineExpr > dimReplacements)
static void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns that are used to canonicalize the use of unit-extent dims for broadcasting.
static void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
static Value collapseValue(RewriterBase &rewriter, Location loc, Value operand, ArrayRef< int64_t > targetShape, ArrayRef< ReassociationIndices > reassociation, ControlDropUnitDims::RankReductionStrategy rankReductionStrategy)
Collapse the given value so that the type matches the type of origOutput.
static llvm::ManagedStatic< PassManagerOptions > options
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
ArrayRef< AffineExpr > getResults() const
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns)
A pattern that converts init operands to input operands.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns)
Adds patterns that reduce the rank of named contraction ops that have unit dimensions in the operand(...
FailureOr< DropUnitDimsResult > dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, const ControlDropUnitDims &options)
SmallVector< NamedAttribute > getPrunedAttributeList(OpTy op)
Returns an attribute list that excludes pre-defined attributes.
std::optional< SmallVector< ReassociationIndices > > getReassociationMapForFoldingUnitDims(ArrayRef< OpFoldResult > mixedSizes)
Get the reassociation maps to fold the result of a extract_slice (or source of a insert_slice) operat...
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with its consumers.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Compute the modified metadata for an operands of operation whose unit dims are being dropped.
SmallVector< ReassociationIndices > reassociation
SmallVector< int64_t > targetShape
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > m
SmallVector< unsigned, 2 > n
Transformation to drop unit-extent dimensions from linalg.generic operations.