32#include "llvm/Support/Debug.h"
35#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
36#include "mlir/Dialect/Linalg/Passes.h.inc"
39#define DEBUG_TYPE "linalg-drop-unit-dims"
83 LogicalResult matchAndRewrite(GenericOp genericOp,
85 if (!genericOp.hasPureTensorSemantics())
87 if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
90 auto outputOperands = genericOp.getDpsInitsMutable();
93 if (genericOp.getMatchingBlockArgument(&op).use_empty())
95 candidates.insert(&op);
98 if (candidates.empty())
102 int64_t origNumInput = genericOp.getNumDpsInputs();
106 newIndexingMaps.append(indexingMaps.begin(),
107 std::next(indexingMaps.begin(), origNumInput));
109 newInputOperands.push_back(op->get());
110 newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
112 newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
117 llvm::to_vector(genericOp.getDpsInits());
121 auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
122 auto empty = tensor::EmptyOp::create(
126 unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
127 newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
130 auto newOp = GenericOp::create(
131 rewriter, loc, genericOp.getResultTypes(), newInputOperands,
132 newOutputOperands, newIndexingMaps, genericOp.getIteratorTypesArray(),
136 Region &region = newOp.getRegion();
139 for (
auto bbarg : genericOp.getRegionInputArgs())
143 BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
148 BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
149 if (candidates.count(&op))
155 for (
auto &op : genericOp.getBody()->getOperations()) {
156 rewriter.
clone(op, mapper);
158 rewriter.
replaceOp(genericOp, newOp.getResults());
229 const llvm::SmallDenseSet<unsigned> &unitDims,
231 for (IndexOp indexOp :
232 llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
235 if (unitDims.count(indexOp.getDim()) != 0) {
239 unsigned droppedDims = llvm::count_if(
240 unitDims, [&](
unsigned dim) {
return dim < indexOp.getDim(); });
241 if (droppedDims != 0)
243 indexOp.getDim() - droppedDims);
254 auto origResultType = cast<RankedTensorType>(origDest.
getType());
255 if (origResultType.getEncoding() !=
nullptr) {
261 unsigned rank = origResultType.getRank();
262 SmallVector<OpFoldResult> offsets(rank, rewriter.
getIndexAttr(0));
263 SmallVector<OpFoldResult> sizes =
265 SmallVector<OpFoldResult> strides(rank, rewriter.
getIndexAttr(1));
267 loc,
result, origDest, offsets, sizes, strides);
272 "unknown rank reduction strategy");
273 return tensor::ExpandShapeOp::create(rewriter, loc, origResultType,
result,
279ControlDropUnitDims::collapseValue(RewriterBase &rewriter, Location loc,
280 Value operand, ArrayRef<int64_t> targetShape,
281 ArrayRef<ReassociationIndices> reassociation,
283 if (
auto memrefType = dyn_cast<MemRefType>(operand.
getType())) {
284 if (!memrefType.getLayout().isIdentity()) {
290 FailureOr<Value> rankReducingExtract =
291 memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
293 assert(succeeded(rankReducingExtract) &&
"not a unit-extent collapse");
294 return *rankReducingExtract;
300 "unknown rank reduction strategy");
301 MemRefLayoutAttrInterface layout;
302 auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
303 layout, memrefType.getMemorySpace());
304 return memref::CollapseShapeOp::create(rewriter, loc, targetType, operand,
308 if (
auto tensorType = dyn_cast<RankedTensorType>(operand.
getType())) {
309 if (tensorType.getEncoding() !=
nullptr) {
315 FailureOr<Value> rankReducingExtract =
316 tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
318 assert(succeeded(rankReducingExtract) &&
"not a unit-extent collapse");
319 return *rankReducingExtract;
325 "unknown rank reduction strategy");
327 RankedTensorType::get(targetShape, tensorType.getElementType());
328 return tensor::CollapseShapeOp::create(rewriter, loc, targetType, operand,
332 llvm_unreachable(
"unsupported operand type");
347 llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
352 AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
356 auto isUnitDim = [&](
unsigned dim) {
357 if (
auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
358 unsigned oldPosition = dimExpr.getPosition();
359 return !oldDimsToNewDimsMap.count(oldPosition) &&
360 (operandShape[dim] == 1);
364 if (operandShape[dim] == 1) {
365 auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
366 return constAffineExpr && constAffineExpr.getValue() == 0;
372 while (dim < operandShape.size() && isUnitDim(dim))
373 reassociationGroup.push_back(dim++);
374 while (dim < operandShape.size()) {
375 assert(!isUnitDim(dim) &&
"expected non unit-extent");
376 reassociationGroup.push_back(dim);
377 AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
378 newIndexExprs.push_back(newExpr);
382 while (dim < operandShape.size() && isUnitDim(dim)) {
383 reassociationGroup.push_back(dim++);
386 reassociationGroup.clear();
390 newIndexExprs, context);
394FailureOr<DropUnitDimsResult>
398 auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op.getOperation());
401 op,
"op should implement DestinationStyleOpInterface");
405 if (indexingMaps.empty())
415 "invalid indexing maps for operation");
419 for (
OpOperand &opOperand : op->getOpOperands())
420 llvm::append_range(allShapesSizes, op.getStaticOperandShape(&opOperand));
424 if (allowedUnitDims.empty()) {
426 op,
"control function returns no allowed unit dims to prune");
428 llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
429 allowedUnitDims.end());
430 llvm::SmallDenseSet<unsigned> unitDims;
431 for (
const auto &expr : enumerate(invertedMap.
getResults())) {
432 if (
AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
433 if (allShapesSizes[dimExpr.getPosition()] == 1 &&
434 unitDimsFilter.count(expr.index()))
435 unitDims.insert(expr.index());
441 llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
443 unsigned newDims = 0;
444 for (
auto index : llvm::seq<int64_t>(op.getStaticLoopRanges().size())) {
445 if (unitDims.count(
index)) {
446 dimReplacements.push_back(
449 oldDimToNewDimMap[
index] = newDims;
450 dimReplacements.push_back(
471 for (
OpOperand &opOperand : op->getOpOperands()) {
472 auto indexingMap = op.getMatchingIndexingMap(&opOperand);
473 auto replacementInfo =
475 oldDimToNewDimMap, dimReplacements);
476 reassociations.push_back(replacementInfo.reassociation);
477 newIndexingMaps.push_back(replacementInfo.indexMap);
478 targetShapes.push_back(replacementInfo.targetShape);
479 collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
480 indexingMap.getNumResults()));
485 if (newIndexingMaps == indexingMaps ||
497 for (
OpOperand &opOperand : op->getOpOperands()) {
498 int64_t idx = opOperand.getOperandNumber();
499 if (!collapsed[idx]) {
500 newOperands.push_back(opOperand.get());
503 FailureOr<Value> collapsed =
504 options.collapseFn(rewriter, loc, opOperand.get(), targetShapes[idx],
506 if (failed(collapsed)) {
510 newOperands.push_back(collapsed.value());
513 IndexingMapOpInterface replacementOp = droppedUnitDimsBuilder(
514 loc, rewriter, op, newOperands, newIndexingMaps, unitDims);
521 for (
auto [
index,
result] : llvm::enumerate(replacementOp->getResults())) {
522 unsigned opOperandIndex =
index + dpsOp.getNumDpsInputs();
523 Value origDest = dpsOp.getDpsInitOperand(
index)->get();
524 if (!collapsed[opOperandIndex]) {
525 resultReplacements.push_back(
result);
528 FailureOr<Value> expanded =
530 reassociations[opOperandIndex],
options);
531 if (failed(expanded)) {
535 resultReplacements.push_back(expanded.value());
541FailureOr<DropUnitDimsResult>
548 const llvm::SmallDenseSet<unsigned> &droppedDims)
549 -> IndexingMapOpInterface {
550 auto genericOp = cast<GenericOp>(op);
554 for (
auto [
index, attr] :
555 llvm::enumerate(genericOp.getIteratorTypesArray())) {
556 if (!droppedDims.count(
index))
557 newIteratorTypes.push_back(attr);
567 resultTypes.reserve(genericOp.getNumResults());
568 for (
unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
569 resultTypes.push_back(newOutputs[i].
getType());
570 GenericOp replacementOp =
571 GenericOp::create(
b, loc, resultTypes, newInputs, newOutputs,
572 newIndexingMaps, newIteratorTypes);
573 b.cloneRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
574 replacementOp.getRegion().begin());
580 return replacementOp;
592 LogicalResult matchAndRewrite(GenericOp genericOp,
593 PatternRewriter &rewriter)
const override {
594 FailureOr<DropUnitDimsResult>
result =
613struct DropPadUnitDims :
public OpRewritePattern<tensor::PadOp> {
614 DropPadUnitDims(MLIRContext *context, ControlDropUnitDims
options = {},
615 PatternBenefit benefit = 1)
616 : OpRewritePattern(context, benefit),
options(std::move(
options)) {}
618 LogicalResult matchAndRewrite(tensor::PadOp padOp,
619 PatternRewriter &rewriter)
const override {
621 SmallVector<unsigned> allowedUnitDims =
options.controlFn(padOp);
622 if (allowedUnitDims.empty()) {
624 padOp,
"control function returns no allowed unit dims to prune");
627 if (padOp.getSourceType().getEncoding()) {
629 padOp,
"cannot collapse dims of tensor with encoding");
636 Value paddingVal = padOp.getConstantPaddingValue();
639 padOp,
"unimplemented: non-constant padding value");
642 ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
643 ArrayRef<int64_t> resultShape = padOp.getResultType().getShape();
644 int64_t padRank = sourceShape.size();
646 auto isStaticZero = [](OpFoldResult f) {
650 llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
651 allowedUnitDims.end());
652 llvm::SmallDenseSet<unsigned> unitDims;
653 SmallVector<int64_t> newShape;
654 SmallVector<int64_t> newResultShape;
655 SmallVector<OpFoldResult> newLowPad;
656 SmallVector<OpFoldResult> newHighPad;
657 for (
const auto [dim, size, outSize, low, high] : zip_equal(
658 llvm::seq(
static_cast<int64_t
>(0), padRank), sourceShape,
659 resultShape, padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
660 if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
661 isStaticZero(high)) {
662 unitDims.insert(dim);
664 newShape.push_back(size);
665 newResultShape.push_back(outSize);
666 newLowPad.push_back(low);
667 newHighPad.push_back(high);
671 if (unitDims.empty()) {
676 SmallVector<ReassociationIndices> reassociationMap;
678 while (dim < padRank && unitDims.contains(dim))
679 reassociationGroup.push_back(dim++);
680 while (dim < padRank) {
681 assert(!unitDims.contains(dim) &&
"expected non unit-extent");
682 reassociationGroup.push_back(dim);
685 while (dim < padRank && unitDims.contains(dim))
686 reassociationGroup.push_back(dim++);
687 reassociationMap.push_back(reassociationGroup);
688 reassociationGroup.clear();
691 FailureOr<Value> collapsedSource =
692 options.collapseFn(rewriter, padOp.getLoc(), padOp.getSource(),
693 newShape, reassociationMap,
options);
694 if (
failed(collapsedSource)) {
698 auto newResultType = RankedTensorType::get(
699 newResultShape, padOp.getResultType().getElementType());
700 auto newPadOp = tensor::PadOp::create(
701 rewriter, padOp.getLoc(), newResultType,
702 collapsedSource.value(), newLowPad, newHighPad, paddingVal,
705 Value dest = padOp.getResult();
706 if (
options.rankReductionStrategy ==
707 ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
708 SmallVector<OpFoldResult> expandedSizes;
709 int64_t numUnitDims = 0;
710 for (
auto dim : llvm::seq(
static_cast<int64_t
>(0), padRank)) {
711 if (unitDims.contains(dim)) {
717 rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
719 dest = tensor::EmptyOp::create(rewriter, padOp.getLoc(), expandedSizes,
720 padOp.getResultType().getElementType());
723 FailureOr<Value> expandedValue =
724 options.expandFn(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
726 if (failed(expandedValue)) {
729 rewriter.
replaceOp(padOp, expandedValue.value());
740struct RankReducedExtractSliceOp
744 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
746 RankedTensorType resultType = sliceOp.getType();
748 for (
auto size : resultType.getShape())
751 if (!reassociation ||
752 reassociation->size() ==
static_cast<size_t>(resultType.getRank()))
758 auto rankReducedType = cast<RankedTensorType>(
759 tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
760 reassociation->size(), sliceOp.getSourceType(), sizes));
763 Value newSlice = tensor::ExtractSliceOp::create(
764 rewriter, loc, rankReducedType, sliceOp.getSource(), offsets, sizes,
767 sliceOp, resultType, newSlice, *reassociation);
774template <
typename InsertOpTy>
778 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
780 RankedTensorType sourceType = insertSliceOp.getSourceType();
782 for (
auto size : sourceType.getShape())
785 if (!reassociation ||
786 reassociation->size() ==
static_cast<size_t>(sourceType.getRank()))
789 Location loc = insertSliceOp.getLoc();
790 tensor::CollapseShapeOp reshapedSource;
796 if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
798 reshapedSource = tensor::CollapseShapeOp::create(
799 rewriter, loc, insertSliceOp.getSource(), *reassociation);
802 insertSliceOp, reshapedSource, insertSliceOp.getDest(),
803 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
804 insertSliceOp.getMixedStrides());
814 auto *context =
patterns.getContext();
821 auto *context =
patterns.getContext();
822 bool reassociativeReshape =
823 options.rankReductionStrategy ==
825 if (reassociativeReshape) {
826 patterns.add<RankReducedExtractSliceOp,
827 RankReducedInsertSliceOp<tensor::InsertSliceOp>,
828 RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
830 tensor::CollapseShapeOp::getCanonicalizationPatterns(
patterns, context);
831 tensor::ExpandShapeOp::getCanonicalizationPatterns(
patterns, context);
833 linalg::FillOp::getCanonicalizationPatterns(
patterns, context);
834 tensor::EmptyOp::getCanonicalizationPatterns(
patterns, context);
847struct LinalgFoldUnitExtentDimsPass
849 LinalgFoldUnitExtentDimsPass> {
851 LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
852 void runOnOperation()
override {
856 if (useRankReducingSlices) {
857 options.rankReductionStrategy = linalg::ControlDropUnitDims::
858 RankReductionStrategy::ExtractInsertSlice;
870 RewritePatternSet
patterns(context);
885static SmallVector<ReassociationIndices>
886getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
887 SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
888 bool lastDim = pos == rank - 1;
890 for (int64_t i = 0; i < rank - 1; i++) {
891 if (i == pos || (lastDim && i == pos - 1))
899 return reassociation;
904static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
908 auto valType = cast<ShapedType>(val.
getType());
909 SmallVector<int64_t> collapsedShape(valType.getShape());
910 collapsedShape.erase(collapsedShape.begin() + pos);
912 FailureOr<Value> collapsed = control.
collapseFn(
913 rewriter, val.
getLoc(), val, collapsedShape,
914 getReassociationForReshapeAtDim(valType.getRank(), pos), control);
915 assert(llvm::succeeded(collapsed) &&
"Collapsing the value failed");
916 return collapsed.value();
925template <
typename FromOpTy,
typename ToOpTy>
926struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
927 using OpRewritePattern<FromOpTy>::OpRewritePattern;
931 collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
932 ArrayRef<int64_t> operandCollapseDims)
const {
933 assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
934 "expected 3 operands and dims");
935 return llvm::map_to_vector(
936 llvm::zip(operands, operandCollapseDims), [&](
auto pair) {
937 return collapseSingletonDimAt(rewriter, std::get<0>(pair),
943 Value expandResult(PatternRewriter &rewriter, Value
result,
944 RankedTensorType expandedType, int64_t dim)
const {
945 return tensor::ExpandShapeOp::create(
947 getReassociationForReshapeAtDim(expandedType.getRank(), dim));
950 LogicalResult matchAndRewrite(FromOpTy contractionOp,
951 PatternRewriter &rewriter)
const override {
952 if (contractionOp.hasUserDefinedMaps()) {
954 contractionOp,
"ops with user-defined maps are not supported");
957 auto loc = contractionOp.getLoc();
958 auto inputs = contractionOp.getDpsInputs();
959 auto inits = contractionOp.getDpsInits();
960 if (inputs.size() != 2 || inits.size() != 1)
962 "expected 2 inputs and 1 init");
963 auto lhs = inputs[0];
964 auto rhs = inputs[1];
965 auto init = inits[0];
966 SmallVector<Value> operands{
lhs,
rhs, init};
968 SmallVector<int64_t> operandUnitDims;
969 if (
failed(getOperandUnitDims(contractionOp, operandUnitDims)))
971 "no reducable dims found");
973 SmallVector<Value> collapsedOperands =
974 collapseOperands(rewriter, operands, operandUnitDims);
975 Value collapsedLhs = collapsedOperands[0];
976 Value collapsedRhs = collapsedOperands[1];
977 Value collapsedInit = collapsedOperands[2];
978 SmallVector<Type, 1> collapsedResultTy;
979 if (isa<RankedTensorType>(collapsedInit.
getType()))
980 collapsedResultTy.push_back(collapsedInit.
getType());
981 auto collapsedOp = ToOpTy::create(rewriter, loc, collapsedResultTy,
984 for (
auto attr : contractionOp->getAttrs()) {
985 if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
986 attr.getName() ==
"indexing_maps")
988 collapsedOp->setAttr(attr.getName(), attr.getValue());
991 auto results = contractionOp.getResults();
992 assert(results.size() < 2 &&
"expected at most one result");
993 if (results.empty()) {
994 rewriter.
replaceOp(contractionOp, collapsedOp);
998 expandResult(rewriter, collapsedOp.getResultTensors()[0],
999 cast<RankedTensorType>(results[0].getType()),
1000 operandUnitDims[2]));
1009 virtual LogicalResult
1010 getOperandUnitDims(LinalgOp op,
1011 SmallVectorImpl<int64_t> &operandUnitDims)
const = 0;
1015template <
typename FromOpTy,
typename ToOpTy>
1016struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
1017 using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
1021 getOperandUnitDims(LinalgOp op,
1022 SmallVectorImpl<int64_t> &operandUnitDims)
const override {
1023 FailureOr<ContractionDimensions> maybeContractionDims =
1025 if (
failed(maybeContractionDims)) {
1026 LLVM_DEBUG(llvm::dbgs() <<
"could not infer contraction dims");
1029 const ContractionDimensions &contractionDims = maybeContractionDims.value();
1031 if (contractionDims.
batch.size() != 1)
1033 auto batchDim = contractionDims.
batch[0];
1034 SmallVector<std::pair<Value, unsigned>, 3> bOperands;
1035 op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
1036 if (bOperands.size() != 3 || llvm::any_of(bOperands, [](
auto pair) {
1037 return cast<ShapedType>(std::get<0>(pair).getType())
1038 .getShape()[std::get<1>(pair)] != 1;
1040 LLVM_DEBUG(llvm::dbgs() <<
"specified unit dims not found");
1044 operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
1045 std::get<1>(bOperands[1]),
1046 std::get<1>(bOperands[2])};
1052template <
typename FromOpTy,
typename ToOpTy>
1053struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
1054 using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
1057 static bool constexpr reduceLeft =
1058 (std::is_same_v<FromOpTy, BatchMatmulOp> &&
1059 std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1060 (std::is_same_v<FromOpTy, MatmulOp> &&
1061 std::is_same_v<ToOpTy, VecmatOp>) ||
1062 (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
1066 getOperandUnitDims(LinalgOp op,
1067 SmallVectorImpl<int64_t> &operandUnitDims)
const override {
1068 FailureOr<ContractionDimensions> maybeContractionDims =
1070 if (
failed(maybeContractionDims)) {
1071 LLVM_DEBUG(llvm::dbgs() <<
"could not infer contraction dims");
1074 const ContractionDimensions &contractionDims = maybeContractionDims.value();
1076 if constexpr (reduceLeft) {
1077 auto m = contractionDims.
m[0];
1078 SmallVector<std::pair<Value, unsigned>, 2> mOperands;
1079 op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
1080 if (mOperands.size() != 2)
1082 if (llvm::all_of(mOperands, [](
auto pair) {
1083 return cast<ShapedType>(std::get<0>(pair).
getType())
1084 .getShape()[std::get<1>(pair)] == 1;
1086 operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
1087 std::get<1>(mOperands[1])};
1091 auto n = contractionDims.
n[0];
1092 SmallVector<std::pair<Value, unsigned>, 2> nOperands;
1093 op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
1094 if (nOperands.size() != 2)
1096 if (llvm::all_of(nOperands, [](
auto pair) {
1097 return cast<ShapedType>(std::get<0>(pair).
getType())
1098 .getShape()[std::get<1>(pair)] == 1;
1100 operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
1101 std::get<1>(nOperands[1])};
1105 LLVM_DEBUG(llvm::dbgs() <<
"specified unit dims not found");
1116 patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
1117 patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
1118 patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
1121 patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
1122 patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
1124 patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
1125 patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
1128 patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
1129 patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
static void replaceUnitDimIndexOps(GenericOp genericOp, const llvm::SmallDenseSet< unsigned > &unitDims, RewriterBase &rewriter)
Implements a pass that canonicalizes the uses of unit-extent dimensions for broadcasting.
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(MLIRContext *context, IndexingMapOpInterface op, OpOperand *opOperand, llvm::SmallDenseMap< unsigned, unsigned > &oldDimsToNewDimsMap, ArrayRef< AffineExpr > dimReplacements)
static llvm::ManagedStatic< PassManagerOptions > options
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
ArrayRef< AffineExpr > getResults() const
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Specialization of arith.constant op that returns an integer of index type.
void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns)
A pattern that converts init operands to input operands.
std::function< IndexingMapOpInterface( Location loc, OpBuilder &, IndexingMapOpInterface, ArrayRef< Value > newOperands, ArrayRef< AffineMap > newIndexingMaps, const llvm::SmallDenseSet< unsigned > &droppedDims)> DroppedUnitDimsBuilder
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns)
Adds patterns that reduce the rank of named contraction ops that have unit dimensions in the operand(...
std::optional< SmallVector< ReassociationIndices > > getReassociationMapForFoldingUnitDims(ArrayRef< OpFoldResult > mixedSizes)
Get the reassociation maps to fold the result of an extract_slice (or source of an insert_slice) operat...
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors and memref.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< DropUnitDimsResult > dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op, const DroppedUnitDimsBuilder &droppedUnitDimsBuilder, const ControlDropUnitDims &options)
Drop unit extent dimensions from the op and its operands.
SmallVector< NamedAttribute > getPrunedAttributeList(OpTy op)
Returns an attribute list that excludes pre-defined attributes.
void populateFoldUnitExtentDimsCanonicalizationPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Populates canonicalization patterns that simplify IR after folding unit-extent dimensions.
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with its consumers.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
llvm::SetVector< T, Vector, Set, N > SetVector
const FrozenRewritePatternSet & patterns
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
SmallVector< int64_t, 2 > ReassociationIndices
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Compute the modified metadata for an operand of an operation whose unit dims are being dropped.
SmallVector< ReassociationIndices > reassociation
SmallVector< int64_t > targetShape
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
SmallVector< unsigned, 2 > batch
SmallVector< unsigned, 2 > m
SmallVector< unsigned, 2 > n
Transformation to drop unit-extent dimensions from linalg.generic operations.
RankReductionStrategy rankReductionStrategy
CollapseFnTy collapseFn
Function to control how operands are collapsed into their new target shape after dropping unit extent...