// ... (earlier includes elided)
#include "llvm/Support/Debug.h"

#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"

#define DEBUG_TYPE "linalg-drop-unit-dims"
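// The code below is an excerpt of the implementation of
// LinalgFoldUnitExtentDimsPass, the pass that canonicalizes unit-extent
// dimensions away from linalg ops on tensors and memrefs; elided code is
// marked with "// ...".
//
// Illustrative sketch (assumed shapes, not taken from this file): a
// linalg.generic whose operands have type tensor<1x8xf32>, and whose first
// loop therefore has extent 1, can be rewritten to operate on tensor<8xf32>,
// reconciling the types with tensor.collapse_shape on the way in and
// tensor.expand_shape (or rank-reducing extract/insert_slice, depending on
// the configured RankReductionStrategy) on the way out.
//
// The first excerpt converts init operands to input operands (the pattern
// exposed via populateMoveInitOperandsToInputPattern): an init operand of a
// fully parallel linalg.generic whose block argument is read in the body is
// re-registered as an input, and a fresh tensor.empty takes its place as the
// init.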
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasPureTensorSemantics())
      return failure();
    if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();

    // Collect the init operands whose block argument is used in the body.
    auto outputOperands = genericOp.getDpsInitsMutable();
    SetVector<OpOperand *> candidates;
    for (OpOperand &op : outputOperands) {
      if (genericOp.getMatchingBlockArgument(&op).use_empty())
        continue;
      candidates.insert(&op);
    }
    if (candidates.empty())
      return failure();

    // Append the candidates to the input operands and splice their indexing
    // maps in between the input and output maps.
    int64_t origNumInput = genericOp.getNumDpsInputs();
    SmallVector<Value> newInputOperands = genericOp.getDpsInputs();
    SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
    SmallVector<AffineMap> newIndexingMaps;
    newIndexingMaps.append(indexingMaps.begin(),
                           std::next(indexingMaps.begin(), origNumInput));
    for (OpOperand *op : candidates) {
      newInputOperands.push_back(op->get());
      newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
    }
    newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
                           indexingMaps.end());

    // Replace each candidate init with a fresh tensor.empty of the same shape.
    Location loc = genericOp.getLoc();
    SmallVector<Value> newOutputOperands =
        llvm::to_vector(genericOp.getDpsInits());
    for (OpOperand *op : candidates) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointAfterValue(op->get());
      auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
      auto empty = tensor::EmptyOp::create(
          rewriter, loc, tensor::getMixedSizes(rewriter, loc, op->get()),
          elemType);
      unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
      newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
    }

    auto newOp = GenericOp::create(
        rewriter, loc, genericOp.getResultTypes(), newInputOperands,
        newOutputOperands, newIndexingMaps, genericOp.getIteratorTypesArray(),
        /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));

    // Rebuild the region: remap the original block arguments onto block
    // arguments of the new op.
    OpBuilder::InsertionGuard guard(rewriter);
    Region &region = newOp.getRegion();
    Block *block = rewriter.createBlock(&region);
    IRMapping mapper;
    for (auto bbarg : genericOp.getRegionInputArgs())
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    for (OpOperand *op : candidates) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }
    for (OpOperand &op : outputOperands) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
      if (candidates.count(&op))
        block->addArgument(bbarg.getType(), loc);
      else
        mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }

    for (auto &op : genericOp.getBody()->getOperations())
      rewriter.clone(op, mapper);
    rewriter.replaceOp(genericOp, newOp.getResults());
    return success();
  }
};
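// replaceUnitDimIndexOps rewrites linalg.index ops inside a linalg.generic
// whose iteration-space unit dimensions `unitDims` have been dropped: an
// index of a dropped dimension becomes a constant 0, and every other index
// is shifted down by the number of dropped dimensions below it.
//
// Hypothetical example: if unit dims {0, 2} are dropped from a 4-d iteration
// space, linalg.index 0 and linalg.index 2 fold to constant 0, while
// linalg.index 1 becomes linalg.index 0 and linalg.index 3 becomes
// linalg.index 1.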
static void
replaceUnitDimIndexOps(GenericOp genericOp,
                       const llvm::SmallDenseSet<unsigned> &unitDims,
                       RewriterBase &rewriter) {
  for (IndexOp indexOp :
       llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.getDim()) != 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.getDim() - droppedDims);
    }
  }
}
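// The ControlDropUnitDims hooks below implement the two supported
// RankReductionStrategy variants. Collapsing a unit dimension of an operand
// and expanding it back on the result can be done either with reassociative
// reshapes or with rank-reducing slices; roughly (a sketch with an assumed
// tensor<1x4xf32> operand):
//
//   // ReassociativeReshape
//   %c = tensor.collapse_shape %t [[0, 1]] : tensor<1x4xf32> into tensor<4xf32>
//   %e = tensor.expand_shape %r [[0, 1]] output_shape [1, 4]
//          : tensor<4xf32> into tensor<1x4xf32>
//
//   // ExtractInsertSlice
//   %c = tensor.extract_slice %t[0, 0] [1, 4] [1, 1]
//          : tensor<1x4xf32> to tensor<4xf32>
//   %e = tensor.insert_slice %r into %dest[0, 0] [1, 4] [1, 1]
//          : tensor<4xf32> into tensor<1x4xf32>
//
// The first helper expands a collapsed result back to its original
// destination type; collapseValue below is its counterpart for operands and
// also handles memref operands via memref.collapse_shape / memref.subview.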
FailureOr<Value>
ControlDropUnitDims::expandValue(RewriterBase &rewriter, Location loc,
                                 Value result, Value origDest,
                                 ArrayRef<ReassociationIndices> reassociation,
                                 const ControlDropUnitDims &options) {
  auto origResultType = cast<RankedTensorType>(origDest.getType());
  if (origResultType.getEncoding() != nullptr) {
    // ...
    return failure();
  }
  if (options.rankReductionStrategy ==
      ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
    unsigned rank = origResultType.getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, loc, origDest);
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    return tensor::InsertSliceOp::create(rewriter, loc, result, origDest,
                                         offsets, sizes, strides)
        .getResult();
  }

  assert(options.rankReductionStrategy ==
             ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
         "unknown rank reduction strategy");
  return tensor::ExpandShapeOp::create(rewriter, loc, origResultType, result,
                                       reassociation)
      .getResult();
}
FailureOr<Value>
ControlDropUnitDims::collapseValue(RewriterBase &rewriter, Location loc,
                                   Value operand, ArrayRef<int64_t> targetShape,
                                   ArrayRef<ReassociationIndices> reassociation,
                                   const ControlDropUnitDims &options) {
  if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
    if (!memrefType.getLayout().isIdentity()) {
      // ...
      return failure();
    }
    if (options.rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      FailureOr<Value> rankReducingExtract =
          memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                targetShape);
      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
      return *rankReducingExtract;
    }
    assert(options.rankReductionStrategy ==
               ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
           "unknown rank reduction strategy");
    MemRefLayoutAttrInterface layout;
    auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
                                      layout, memrefType.getMemorySpace());
    return memref::CollapseShapeOp::create(rewriter, loc, targetType, operand,
                                           reassociation)
        .getResult();
  }
  if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
    if (tensorType.getEncoding() != nullptr) {
      // ...
      return failure();
    }
    if (options.rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      FailureOr<Value> rankReducingExtract =
          tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                     targetShape);
      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
      return *rankReducingExtract;
    }
    assert(options.rankReductionStrategy ==
               ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
           "unknown rank reduction strategy");
    auto targetType =
        RankedTensorType::get(targetShape, tensorType.getElementType());
    return tensor::CollapseShapeOp::create(rewriter, loc, targetType, operand,
                                           reassociation)
        .getResult();
  }
  llvm_unreachable("unsupported operand type");
}
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
    MLIRContext *context, IndexingMapOpInterface op, OpOperand *opOperand,
    llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
    ArrayRef<AffineExpr> dimReplacements) {
  UnitExtentReplacementInfo info;
  ReassociationIndices reassociationGroup;
  SmallVector<AffineExpr> newIndexExprs;
  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();
  auto operandShape = op.getStaticOperandShape(opOperand);

  auto isUnitDim = [&](unsigned dim) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
      unsigned oldPosition = dimExpr.getPosition();
      return !oldDimsToNewDimsMap.count(oldPosition) &&
             (operandShape[dim] == 1);
    }
    // Handle the other case where the shape is 1, and is accessed using a
    // constant 0.
    if (operandShape[dim] == 1) {
      AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
      auto constAffineExpr = dyn_cast<AffineConstantExpr>(newExpr);
      return constAffineExpr && constAffineExpr.getValue() == 0;
    }
    return false;
  };

  // Leading unit dimensions fold into the first reassociation group; every
  // non-unit dimension then starts a group that absorbs the unit dimensions
  // following it.
  unsigned dim = 0;
  while (dim < operandShape.size() && isUnitDim(dim))
    reassociationGroup.push_back(dim++);
  while (dim < operandShape.size()) {
    assert(!isUnitDim(dim) && "expected non unit-extent");
    reassociationGroup.push_back(dim);
    AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
    newIndexExprs.push_back(newExpr);
    info.targetShape.push_back(operandShape[dim]);
    ++dim;
    while (dim < operandShape.size() && isUnitDim(dim)) {
      reassociationGroup.push_back(dim++);
    }
    info.reassociation.push_back(reassociationGroup);
    reassociationGroup.clear();
  }
  info.indexMap =
      AffineMap::get(indexingMap.getNumDims(), indexingMap.getNumSymbols(),
                     newIndexExprs, context);
  return info;
}
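// dropUnitDims drops unit-extent iteration dimensions from an op with
// indexing maps and from its operands: it identifies the unit loops the
// control function allows, rewrites the indexing maps, collapses the
// affected operands via options.collapseFn, rebuilds the op through the
// caller-provided DroppedUnitDimsBuilder, and finally expands the results
// back to their original types via options.expandFn.
//
// Rough before/after sketch (assumed shapes), using the reassociative
// reshape strategy:
//
//   %0 = linalg.generic ... ins(%a : tensor<1x8xf32>)
//                           outs(%b : tensor<1x8xf32>) {...}
// becomes
//   %ca = tensor.collapse_shape %a [[0, 1]] : tensor<1x8xf32> into tensor<8xf32>
//   %cb = tensor.collapse_shape %b [[0, 1]] : tensor<1x8xf32> into tensor<8xf32>
//   %r  = linalg.generic ... ins(%ca : tensor<8xf32>)
//                            outs(%cb : tensor<8xf32>) {...}
//   %0  = tensor.expand_shape %r [[0, 1]] output_shape [1, 8]
//           : tensor<8xf32> into tensor<1x8xf32>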
FailureOr<DropUnitDimsResult>
linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
                     const DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
                     const ControlDropUnitDims &options) {
  auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op.getOperation());
  if (!dpsOp)
    return rewriter.notifyMatchFailure(
        op, "op should implement DestinationStyleOpInterface");

  SmallVector<AffineMap> indexingMaps = op.getIndexingMapsArray();
  if (indexingMaps.empty())
    return failure();

  // Invert the concatenated indexing maps so that each iteration-space
  // dimension can be matched against the static operand sizes below.
  AffineMap invertedMap =
      inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext()));
  if (!invertedMap)
    return rewriter.notifyMatchFailure(op,
                                       "invalid indexing maps for operation");

  SmallVector<int64_t> allShapesSizes;
  for (OpOperand &opOperand : op->getOpOperands())
    llvm::append_range(allShapesSizes, op.getStaticOperandShape(&opOperand));

  // 1. Compute the set of unit-extent loop dimensions that the control
  //    function allows us to drop.
  SmallVector<unsigned> allowedUnitDims = options.controlFn(op);
  if (allowedUnitDims.empty()) {
    return rewriter.notifyMatchFailure(
        op, "control function returns no allowed unit dims to prune");
  }
  llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                               allowedUnitDims.end());
  llvm::SmallDenseSet<unsigned> unitDims;
  for (const auto &expr : enumerate(invertedMap.getResults())) {
    if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
      if (allShapesSizes[dimExpr.getPosition()] == 1 &&
          unitDimsFilter.count(expr.index()))
        unitDims.insert(expr.index());
    }
  }

  // 2. Map each original loop dimension to its position after the unit
  //    dimensions are removed; dropped dims are replaced by the constant 0.
  llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
  SmallVector<AffineExpr> dimReplacements;
  unsigned newDims = 0;
  for (auto index : llvm::seq<int64_t>(op.getStaticLoopRanges().size())) {
    if (unitDims.count(index)) {
      dimReplacements.push_back(
          getAffineConstantExpr(0, rewriter.getContext()));
    } else {
      oldDimToNewDimMap[index] = newDims;
      dimReplacements.push_back(
          getAffineDimExpr(newDims, rewriter.getContext()));
      newDims++;
    }
  }

  // 3. Compute the modified metadata (indexing map, target shape and
  //    reassociation) for every operand.
  SmallVector<AffineMap> newIndexingMaps;
  SmallVector<SmallVector<ReassociationIndices>> reassociations;
  SmallVector<SmallVector<int64_t>> targetShapes;
  SmallVector<bool> collapsed;
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto indexingMap = op.getMatchingIndexingMap(&opOperand);
    auto replacementInfo = dropUnitExtentFromOperandMetadata(
        rewriter.getContext(), op, &opOperand, oldDimToNewDimMap,
        dimReplacements);
    reassociations.push_back(replacementInfo.reassociation);
    newIndexingMaps.push_back(replacementInfo.indexMap);
    targetShapes.push_back(replacementInfo.targetShape);
    collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
                          indexingMap.getNumResults()));
  }

  // 4. Abort if nothing changed.
  if (newIndexingMaps == indexingMaps /* || ... */)
    return failure();

  // 5. Collapse the operands whose rank changed.
  Location loc = op.getLoc();
  SmallVector<Value> newOperands;
  for (OpOperand &opOperand : op->getOpOperands()) {
    int64_t idx = opOperand.getOperandNumber();
    if (!collapsed[idx]) {
      newOperands.push_back(opOperand.get());
      continue;
    }
    FailureOr<Value> collapsed =
        options.collapseFn(rewriter, loc, opOperand.get(), targetShapes[idx],
                           reassociations[idx], options);
    if (failed(collapsed)) {
      return failure();
    }
    newOperands.push_back(collapsed.value());
  }

  // 6. Build the replacement op from the collapsed operands.
  IndexingMapOpInterface replacementOp = droppedUnitDimsBuilder(
      loc, rewriter, op, newOperands, newIndexingMaps, unitDims);

  // 7. Expand the results back to the original destination types.
  SmallVector<Value> resultReplacements;
  for (auto [index, result] : llvm::enumerate(replacementOp->getResults())) {
    unsigned opOperandIndex = index + dpsOp.getNumDpsInputs();
    Value origDest = dpsOp.getDpsInitOperand(index)->get();
    if (!collapsed[opOperandIndex]) {
      resultReplacements.push_back(result);
      continue;
    }
    FailureOr<Value> expanded =
        options.expandFn(rewriter, loc, result, origDest,
                         reassociations[opOperandIndex], options);
    if (failed(expanded)) {
      return failure();
    }
    resultReplacements.push_back(expanded.value());
  }

  return DropUnitDimsResult{replacementOp, resultReplacements};
}
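// The linalg.generic overload of dropUnitDims supplies a
// DroppedUnitDimsBuilder that rebuilds the generic op from the collapsed
// operands and indexing maps, dropping the iterator types of the removed
// loops and cloning the original region into the new op.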
FailureOr<DropUnitDimsResult>
linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
                     const ControlDropUnitDims &options) {
  DroppedUnitDimsBuilder droppedUnitDimsBuilder =
      [](Location loc, OpBuilder &b, IndexingMapOpInterface op,
         ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
         const llvm::SmallDenseSet<unsigned> &droppedDims)
      -> IndexingMapOpInterface {
    auto genericOp = cast<GenericOp>(op);
    SmallVector<utils::IteratorType> newIteratorTypes;
    for (auto [index, attr] :
         llvm::enumerate(genericOp.getIteratorTypesArray())) {
      if (!droppedDims.count(index))
        newIteratorTypes.push_back(attr);
    }

    // Split the collapsed operands back into inputs and outputs.
    ArrayRef<Value> newInputs =
        newOperands.take_front(genericOp.getNumDpsInputs());
    ArrayRef<Value> newOutputs =
        newOperands.drop_front(genericOp.getNumDpsInputs());

    // The result types match the types of the (collapsed) outputs.
    SmallVector<Type> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newOutputs[i].getType());
    GenericOp replacementOp =
        GenericOp::create(b, loc, resultTypes, newInputs, newOutputs,
                          newIndexingMaps, newIteratorTypes);
    b.cloneRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
                        replacementOp.getRegion().begin());
    // ...
    return replacementOp;
  };

  return dropUnitDims(rewriter, genericOp, droppedUnitDimsBuilder, options);
}
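// Rewrite-pattern wrapper: fold unit-extent dimensions of a linalg.generic
// by calling dropUnitDims and replacing the op with the expanded results.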
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    FailureOr<DropUnitDimsResult> result =
        dropUnitDims(rewriter, genericOp, options);
    if (failed(result))
      return failure();
    rewriter.replaceOp(genericOp, result->replacements);
    return success();
  }
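// DropPadUnitDims applies the same idea to tensor.pad: unit-extent source
// dimensions that are padded by a static 0 on both sides are collapsed away,
// the padding is performed on the collapsed tensor, and the result is
// expanded (or inserted) back into the original padded type.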
struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
  DropPadUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
                  PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), options(std::move(options)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<unsigned> allowedUnitDims = options.controlFn(padOp);
    if (allowedUnitDims.empty()) {
      return rewriter.notifyMatchFailure(
          padOp, "control function returns no allowed unit dims to prune");
    }

    if (padOp.getSourceType().getEncoding()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot collapse dims of tensor with encoding");
    }

    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal) {
      return rewriter.notifyMatchFailure(
          padOp, "unimplemented: non-constant padding value");
    }

    ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
    ArrayRef<int64_t> resultShape = padOp.getResultType().getShape();
    int64_t padRank = sourceShape.size();

    auto isStaticZero = [](OpFoldResult f) {
      return getConstantIntValue(f) == static_cast<int64_t>(0);
    };

    // Collect the unit source dims that are padded by zero on both sides;
    // the remaining dims form the collapsed pad.
    llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                                 allowedUnitDims.end());
    llvm::SmallDenseSet<unsigned> unitDims;
    SmallVector<int64_t> newShape;
    SmallVector<int64_t> newResultShape;
    SmallVector<OpFoldResult> newLowPad;
    SmallVector<OpFoldResult> newHighPad;
    for (const auto [dim, size, outSize, low, high] : zip_equal(
             llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
             resultShape, padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
      if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
          isStaticZero(high)) {
        unitDims.insert(dim);
      } else {
        newShape.push_back(size);
        newResultShape.push_back(outSize);
        newLowPad.push_back(low);
        newHighPad.push_back(high);
      }
    }

    if (unitDims.empty()) {
      return failure();
    }

    ReassociationIndices reassociationGroup;
    SmallVector<ReassociationIndices> reassociationMap;
    int64_t dim = 0;
    while (dim < padRank && unitDims.contains(dim))
      reassociationGroup.push_back(dim++);
    while (dim < padRank) {
      assert(!unitDims.contains(dim) && "expected non unit-extent");
      reassociationGroup.push_back(dim);
      dim++;
      // Fold all following dimensions that are unit-extent.
      while (dim < padRank && unitDims.contains(dim))
        reassociationGroup.push_back(dim++);
      reassociationMap.push_back(reassociationGroup);
      reassociationGroup.clear();
    }

    FailureOr<Value> collapsedSource =
        options.collapseFn(rewriter, padOp.getLoc(), padOp.getSource(),
                           newShape, reassociationMap, options);
    if (failed(collapsedSource)) {
      return failure();
    }

    auto newResultType = RankedTensorType::get(
        newResultShape, padOp.getResultType().getElementType());
    auto newPadOp = tensor::PadOp::create(
        rewriter, padOp.getLoc(), newResultType,
        collapsedSource.value(), newLowPad, newHighPad, paddingVal,
        padOp.getNofold());

    Value dest = padOp.getResult();
    if (options.rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      SmallVector<OpFoldResult> expandedSizes;
      int64_t numUnitDims = 0;
      for (auto dim : llvm::seq(static_cast<int64_t>(0), padRank)) {
        if (unitDims.contains(dim)) {
          expandedSizes.push_back(rewriter.getIndexAttr(1));
          numUnitDims++;
          continue;
        }
        expandedSizes.push_back(tensor::getMixedSize(
            rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
      }
      dest = tensor::EmptyOp::create(rewriter, padOp.getLoc(), expandedSizes,
                                     padOp.getResultType().getElementType());
    }

    FailureOr<Value> expandedValue =
        options.expandFn(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
                         reassociationMap, options);
    if (failed(expandedValue)) {
      return failure();
    }
    rewriter.replaceOp(padOp, expandedValue.value());
    return success();
  }

  ControlDropUnitDims options;
};
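// The two patterns below use getReassociationMapForFoldingUnitDims to fold
// unit dims of tensor.extract_slice results and of tensor.insert_slice /
// tensor.parallel_insert_slice sources: the slice is rewritten in a
// rank-reduced form and the unit dims are reintroduced with
// tensor.expand_shape / tensor.collapse_shape around it.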
struct RankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> targetShape;
    for (auto size : resultType.getShape())
      targetShape.push_back(rewriter.getIndexAttr(size));
    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();

    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    auto rankReducedType = cast<RankedTensorType>(
        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
            reassociation->size(), sliceOp.getSourceType(), sizes));

    Location loc = sliceOp.getLoc();
    Value newSlice = tensor::ExtractSliceOp::create(
        rewriter, loc, rankReducedType, sliceOp.getSource(), offsets, sizes,
        strides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSlice, *reassociation);
    return success();
  }
};

template <typename InsertOpTy>
struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertSliceOp.getSourceType();
    SmallVector<OpFoldResult> targetShape;
    for (auto size : sourceType.getShape())
      targetShape.push_back(rewriter.getIndexAttr(size));
    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();

    Location loc = insertSliceOp.getLoc();
    tensor::CollapseShapeOp reshapedSource;
    {
      OpBuilder::InsertionGuard g(rewriter);
      // For ParallelInsertSliceOp, the collapsed source must be created
      // outside of the enclosing parallel-combining region.
      if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      reshapedSource = tensor::CollapseShapeOp::create(
          rewriter, loc, insertSliceOp.getSource(), *reassociation);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, reshapedSource, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
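// The populate functions below wire the patterns up:
//  - populateFoldUnitExtentDimsPatterns adds the unit-dim folding patterns
//    for linalg ops on tensors and memrefs;
//  - populateFoldUnitExtentDimsCanonicalizationPatterns adds the cleanup
//    patterns (reshape, fill and empty canonicalizations, plus the
//    rank-reduced slice patterns when the reassociative-reshape strategy is
//    used).
// LinalgFoldUnitExtentDimsPass drives them, selecting the ExtractInsertSlice
// strategy when the `useRankReducingSlices` option is set.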
void linalg::populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns,
                                                ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  // ...
}

void linalg::populateFoldUnitExtentDimsCanonicalizationPatterns(
    RewritePatternSet &patterns, ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  bool reassociativeReshape =
      options.rankReductionStrategy ==
      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape;
  if (reassociativeReshape) {
    patterns.add<RankReducedExtractSliceOp,
                 RankReducedInsertSliceOp<tensor::InsertSliceOp>,
                 RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
        context);
    tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
    tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
  }
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  // ...
}

struct LinalgFoldUnitExtentDimsPass
    : public impl::LinalgFoldUnitExtentDimsPassBase<
          LinalgFoldUnitExtentDimsPass> {
  using impl::LinalgFoldUnitExtentDimsPassBase<
      LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    ControlDropUnitDims options;
    if (useRankReducingSlices) {
      options.rankReductionStrategy = linalg::ControlDropUnitDims::
          RankReductionStrategy::ExtractInsertSlice;
    }
    // ...
    RewritePatternSet patterns(context);
    populateFoldUnitExtentDimsPatterns(patterns, options);
    // ...
  }
};
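// A minimal usage sketch (assuming a pass that already has an Operation *op
// and wants the slice-based strategy; not a verbatim copy of the pass above):
//
//   ControlDropUnitDims options;
//   options.rankReductionStrategy =
//       linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
//   RewritePatternSet patterns(op->getContext());
//   linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
//   (void)applyPatternsGreedily(op, std::move(patterns));
//
// The helpers below support the contraction-op rank-reduction patterns:
// getReassociationForReshapeAtDim builds the reassociation that merges
// dimension `pos` into its neighbor (e.g. rank = 4, pos = 1 should give
// [[0], [1, 2], [3]]), and collapseSingletonDimAt collapses that unit
// dimension of a value using the configured collapseFn.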
static SmallVector<ReassociationIndices>
getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
  SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
  bool lastDim = pos == rank - 1;
  for (int64_t i = 0; i < rank - 1; i++) {
    if (i == pos || (lastDim && i == pos - 1))
      reassociation[i] = ReassociationIndices{i, i + 1};
    else if (i < pos)
      reassociation[i] = ReassociationIndices{i};
    else
      reassociation[i] = ReassociationIndices{i + 1};
  }
  return reassociation;
}

static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
                                    int64_t pos,
                                    const ControlDropUnitDims &control) {
  auto valType = cast<ShapedType>(val.getType());
  SmallVector<int64_t> collapsedShape(valType.getShape());
  collapsedShape.erase(collapsedShape.begin() + pos);
  FailureOr<Value> collapsed = control.collapseFn(
      rewriter, val.getLoc(), val, collapsedShape,
      getReassociationForReshapeAtDim(valType.getRank(), pos), control);
  assert(llvm::succeeded(collapsed) && "Collapsing the value failed");
  return collapsed.value();
}
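// RankReduceContractionOps and its two specializations rewrite named
// contraction ops whose iteration space has a unit dimension into lower-rank
// named ops (see populateContractionOpRankReducingPatterns below).
// Illustrative examples (assumed shapes):
//  - linalg.batch_matmul with batch size 1, e.g. operands tensor<1x4x8xf32>
//    and tensor<1x8x16xf32>, becomes linalg.matmul on the collapsed
//    tensor<4x8xf32> and tensor<8x16xf32>;
//  - linalg.matmul with M == 1 becomes linalg.vecmat, and with N == 1
//    becomes linalg.matvec;
//  - linalg.matvec / linalg.vecmat with the remaining non-reduction
//    dimension equal to 1 become linalg.dot.
// The unit dimension of each operand is collapsed with collapseSingletonDimAt
// and the single result, if any, is expanded back with
// getReassociationForReshapeAtDim; the -1 entries produced by
// RankReduceMatmul appear to mark the operand whose rank is left unchanged.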
template <typename FromOpTy, typename ToOpTy>
struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
  using OpRewritePattern<FromOpTy>::OpRewritePattern;

  /// Collapse all collapsable operands.
  SmallVector<Value>
  collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
                   ArrayRef<int64_t> operandCollapseDims) const {
    assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
           "expected 3 operands and dims");
    return llvm::map_to_vector(
        llvm::zip(operands, operandCollapseDims), [&](auto pair) {
          return collapseSingletonDimAt(rewriter, std::get<0>(pair),
                                        std::get<1>(pair), /*control=*/{});
        });
  }

  /// Expand result tensor.
  Value expandResult(PatternRewriter &rewriter, Value result,
                     RankedTensorType expandedType, int64_t dim) const {
    return tensor::ExpandShapeOp::create(
        rewriter, result.getLoc(), expandedType, result,
        getReassociationForReshapeAtDim(expandedType.getRank(), dim));
  }

  LogicalResult matchAndRewrite(FromOpTy contractionOp,
                                PatternRewriter &rewriter) const override {
    if (contractionOp.hasUserDefinedMaps()) {
      return rewriter.notifyMatchFailure(
          contractionOp, "ops with user-defined maps are not supported");
    }

    auto loc = contractionOp.getLoc();
    auto inputs = contractionOp.getDpsInputs();
    auto inits = contractionOp.getDpsInits();
    if (inputs.size() != 2 || inits.size() != 1)
      return rewriter.notifyMatchFailure(contractionOp,
                                         "expected 2 inputs and 1 init");
    auto lhs = inputs[0];
    auto rhs = inputs[1];
    auto init = inits[0];
    SmallVector<Value> operands{lhs, rhs, init};

    SmallVector<int64_t> operandUnitDims;
    if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
      return rewriter.notifyMatchFailure(contractionOp,
                                         "no reducable dims found");

    SmallVector<Value> collapsedOperands =
        collapseOperands(rewriter, operands, operandUnitDims);
    Value collapsedLhs = collapsedOperands[0];
    Value collapsedRhs = collapsedOperands[1];
    Value collapsedInit = collapsedOperands[2];
    SmallVector<Type, 1> collapsedResultTy;
    if (isa<RankedTensorType>(collapsedInit.getType()))
      collapsedResultTy.push_back(collapsedInit.getType());
    auto collapsedOp = ToOpTy::create(rewriter, loc, collapsedResultTy,
                                      ValueRange{collapsedLhs, collapsedRhs},
                                      ValueRange{collapsedInit});
    for (auto attr : contractionOp->getAttrs()) {
      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
          attr.getName() == "indexing_maps")
        continue;
      collapsedOp->setAttr(attr.getName(), attr.getValue());
    }

    auto results = contractionOp.getResults();
    assert(results.size() < 2 && "expected at most one result");
    if (results.empty()) {
      rewriter.replaceOp(contractionOp, collapsedOp);
    } else {
      rewriter.replaceOp(
          contractionOp,
          expandResult(rewriter, collapsedOp.getResultTensors()[0],
                       cast<RankedTensorType>(results[0].getType()),
                       operandUnitDims[2]));
    }

    return success();
  }

  /// Populate `operandUnitDims` with the unit dimension to drop for each of
  /// the lhs, rhs and init operands.
  virtual LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
};
/// Patterns for unbatching batched contraction ops: the batch dimension must
/// have unit extent on all three operands.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Look for unit batch dims to collapse.
  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
      return failure();
    }
    const ContractionDimensions &contractionDims = maybeContractionDims.value();

    if (contractionDims.batch.size() != 1)
      return failure();
    auto batchDim = contractionDims.batch[0];
    SmallVector<std::pair<Value, unsigned>, 3> bOperands;
    op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
    if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
          return cast<ShapedType>(std::get<0>(pair).getType())
                     .getShape()[std::get<1>(pair)] != 1;
        })) {
      LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
      return failure();
    }

    operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
                                           std::get<1>(bOperands[1]),
                                           std::get<1>(bOperands[2])};
    return success();
  }
};
/// Patterns for reducing non-batch dimensions of matmul-like ops.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Helper for determining whether the lhs/init or rhs/init unit dim is
  /// dropped, based on the source and target op pairing.
  static bool constexpr reduceLeft =
      (std::is_same_v<FromOpTy, BatchMatmulOp> &&
       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
      (std::is_same_v<FromOpTy, MatmulOp> &&
       std::is_same_v<ToOpTy, VecmatOp>) ||
      (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);

  /// Look for a unit m or n dim to collapse.
  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
      return failure();
    }
    const ContractionDimensions &contractionDims = maybeContractionDims.value();

    if constexpr (reduceLeft) {
      auto m = contractionDims.m[0];
      SmallVector<std::pair<Value, unsigned>, 2> mOperands;
      op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
      if (mOperands.size() != 2)
        return failure();
      if (llvm::all_of(mOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
                                               std::get<1>(mOperands[1])};
        return success();
      }
    } else {
      auto n = contractionDims.n[0];
      SmallVector<std::pair<Value, unsigned>, 2> nOperands;
      op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
      if (nOperands.size() != 2)
        return failure();
      if (llvm::all_of(nOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
                                               std::get<1>(nOperands[1])};
        return success();
      }
    }
    LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
    return failure();
  }
};
void linalg::populateContractionOpRankReducingPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  // Unbatching patterns for unit batch size.
  patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
  patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
  patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);

  // Non-batch rank-1 reducing patterns.
  patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
  patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
  // Batch rank-1 reducing patterns.
  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);

  // Non-batch rank-0 reducing patterns.
  patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
  patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
}