60 assert(invProducerResultIndexMap &&
61 "expected producer result indexing map to be invertible");
63 LinalgOp producer = cast<LinalgOp>(producerOpOperand->
getOwner());
65 AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);
73 return t1.
compose(fusedConsumerArgIndexMap);
80 GenericOp producer, GenericOp consumer,
85 for (
auto &op : ops) {
86 for (
auto &opOperand : op->getOpOperands()) {
87 if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
90 indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
93 if (indexingMaps.empty()) {
96 return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
104 indexingMaps, producer.getContext())) !=
AffineMap();
113 GenericOp producer, GenericOp consumer,
OpOperand *fusedOperand) {
114 llvm::SmallDenseSet<int> preservedProducerResults;
118 opOperandsToIgnore.emplace_back(fusedOperand);
120 for (
const auto &producerResult : llvm::enumerate(producer->getResults())) {
121 auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
122 opOperandsToIgnore.emplace_back(outputOperand);
123 if (producer.payloadUsesValueFromOperand(outputOperand) ||
125 opOperandsToIgnore) ||
126 llvm::any_of(producerResult.value().getUsers(), [&](
Operation *user) {
127 return user != consumer.getOperation();
129 preservedProducerResults.insert(producerResult.index());
132 (
void)opOperandsToIgnore.pop_back_val();
135 return preservedProducerResults;
144 auto consumer = dyn_cast<GenericOp>(fusedOperand->
getOwner());
147 if (!producer || !consumer)
153 if (!producer.hasPureTensorSemantics() ||
154 !isa<RankedTensorType>(fusedOperand->
get().
getType()))
159 if (producer.getNumParallelLoops() != producer.getNumLoops())
164 if (!consumer.isDpsInput(fusedOperand))
169 AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
170 if (consumerIndexMap.
getNumResults() != producer.getNumLoops())
175 auto producerResult = cast<OpResult>(fusedOperand->
get());
177 producer.getIndexingMapMatchingResult(producerResult);
185 if ((consumer.getNumReductionLoops())) {
186 BitVector coveredDims(consumer.getNumLoops(),
false);
188 auto addToCoveredDims = [&](
AffineMap map) {
189 for (
auto result : map.getResults())
190 if (
auto dimExpr = dyn_cast<AffineDimExpr>(
result))
191 coveredDims[dimExpr.getPosition()] =
true;
195 llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
196 Value operand = std::get<0>(pair);
197 if (operand == fusedOperand->
get())
199 AffineMap operandMap = std::get<1>(pair);
200 addToCoveredDims(operandMap);
203 for (
OpOperand *operand : producer.getDpsInputOperands()) {
206 operand, producerResultIndexMap, consumerIndexMap);
207 addToCoveredDims(newIndexingMap);
209 if (!coveredDims.all())
221 unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
223 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
225 Block &producerBlock = producer->getRegion(0).
front();
226 Block &consumerBlock = consumer->getRegion(0).
front();
233 if (producer.hasIndexSemantics()) {
235 unsigned numFusedOpLoops = fusedOp.getNumLoops();
237 fusedIndices.reserve(numFusedOpLoops);
238 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
239 std::back_inserter(fusedIndices), [&](uint64_t dim) {
240 return IndexOp::create(rewriter, producer.getLoc(), dim);
242 for (IndexOp indexOp :
243 llvm::make_early_inc_range(producerBlock.
getOps<IndexOp>())) {
244 Value newIndex = affine::AffineApplyOp::create(
245 rewriter, producer.getLoc(),
246 consumerToProducerLoopsMap.
getSubMap(indexOp.getDim()), fusedIndices);
247 mapper.
map(indexOp.getResult(), newIndex);
251 assert(consumer.isDpsInput(fusedOperand) &&
252 "expected producer of input operand");
256 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
263 producerBlock.
getArguments().take_front(producer.getNumDpsInputs()))
264 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
269 .take_front(consumer.getNumDpsInputs())
271 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
274 for (
const auto &bbArg : llvm::enumerate(
275 producerBlock.
getArguments().take_back(producer.getNumDpsInits()))) {
276 if (!preservedProducerResults.count(bbArg.index()))
278 mapper.
map(bbArg.value(), fusedBlock->
addArgument(bbArg.value().getType(),
279 bbArg.value().getLoc()));
284 consumerBlock.
getArguments().take_back(consumer.getNumDpsInits()))
285 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
290 if (!isa<IndexOp>(op))
291 rewriter.
clone(op, mapper);
295 auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.
getTerminator());
296 unsigned producerResultNumber =
297 cast<OpResult>(fusedOperand->
get()).getResultNumber();
299 mapper.
lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
303 if (
replacement == producerYieldOp.getOperand(producerResultNumber)) {
304 if (
auto bb = dyn_cast<BlockArgument>(
replacement))
305 assert(bb.getOwner() != &producerBlock &&
306 "yielded block argument must have been mapped");
308 assert(!producer->isAncestor(
replacement.getDefiningOp()) &&
309 "yielded value must have been mapped");
315 rewriter.
clone(op, mapper);
319 auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.
getTerminator());
321 fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
322 consumerYieldOp.getNumOperands());
323 for (
const auto &producerYieldVal :
324 llvm::enumerate(producerYieldOp.getOperands())) {
325 if (preservedProducerResults.count(producerYieldVal.index()))
326 fusedYieldValues.push_back(
329 for (
auto consumerYieldVal : consumerYieldOp.getOperands())
331 YieldOp::create(rewriter, fusedOp.getLoc(), fusedYieldValues);
335 "Ill-formed GenericOp region");
338FailureOr<mlir::linalg::ElementwiseOpFusionResult>
342 "expected elementwise operation pre-conditions to pass");
343 auto producerResult = cast<OpResult>(fusedOperand->
get());
344 auto producer = cast<GenericOp>(producerResult.getOwner());
345 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
347 assert(consumer.isDpsInput(fusedOperand) &&
348 "expected producer of input operand");
351 llvm::SmallDenseSet<int> preservedProducerResults =
359 fusedInputOperands.reserve(producer.getNumDpsInputs() +
360 consumer.getNumDpsInputs());
361 fusedOutputOperands.reserve(preservedProducerResults.size() +
362 consumer.getNumDpsInits());
363 fusedResultTypes.reserve(preservedProducerResults.size() +
364 consumer.getNumDpsInits());
365 fusedIndexMaps.reserve(producer->getNumOperands() +
366 consumer->getNumOperands());
369 auto consumerInputs = consumer.getDpsInputOperands();
370 auto *it = llvm::find_if(consumerInputs, [&](
OpOperand *operand) {
371 return operand == fusedOperand;
373 assert(it != consumerInputs.end() &&
"expected to find the consumer operand");
374 for (
OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
375 fusedInputOperands.push_back(opOperand->get());
376 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
380 producer.getIndexingMapMatchingResult(producerResult);
381 for (
OpOperand *opOperand : producer.getDpsInputOperands()) {
382 fusedInputOperands.push_back(opOperand->get());
385 opOperand, producerResultIndexMap,
386 consumer.getMatchingIndexingMap(fusedOperand));
387 fusedIndexMaps.push_back(map);
392 llvm::make_range(std::next(it), consumerInputs.end())) {
393 fusedInputOperands.push_back(opOperand->get());
394 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
398 for (
const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
399 if (!preservedProducerResults.count(opOperand.index()))
402 fusedOutputOperands.push_back(opOperand.value().get());
404 &opOperand.value(), producerResultIndexMap,
405 consumer.getMatchingIndexingMap(fusedOperand));
406 fusedIndexMaps.push_back(map);
407 fusedResultTypes.push_back(opOperand.value().get().getType());
411 for (
OpOperand &opOperand : consumer.getDpsInitsMutable()) {
412 fusedOutputOperands.push_back(opOperand.get());
413 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
414 Type resultType = opOperand.get().getType();
415 if (!isa<MemRefType>(resultType))
416 fusedResultTypes.push_back(resultType);
420 auto fusedOp = GenericOp::create(
421 rewriter, consumer.getLoc(), fusedResultTypes, fusedInputOperands,
423 consumer.getIteratorTypes(),
426 if (!fusedOp.getShapesToLoopsMap()) {
432 fusedOp,
"fused op failed loop bound computation check");
438 consumer.getMatchingIndexingMap(fusedOperand);
442 assert(invProducerResultIndexMap &&
443 "expected producer result indexig map to be invertible");
446 invProducerResultIndexMap.
compose(consumerResultIndexMap);
449 rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
450 consumer.getNumLoops(), preservedProducerResults);
454 for (
auto [
index, producerResult] : llvm::enumerate(producer->getResults()))
455 if (preservedProducerResults.count(
index))
456 result.replacements[producerResult] = fusedOp->getResult(resultNum++);
457 for (
auto consumerResult : consumer->getResults())
458 result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
469 controlFn(std::move(fun)) {}
471 LogicalResult matchAndRewrite(GenericOp genericOp,
474 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
477 if (!controlFn(&opOperand))
480 Operation *producer = opOperand.get().getDefiningOp();
483 FailureOr<ElementwiseOpFusionResult> fusionResult =
485 if (failed(fusionResult))
489 for (
auto [origVal,
replacement] : fusionResult->replacements) {
571 linalgOp.getIteratorTypesArray();
572 AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
573 return linalgOp.hasPureTensorSemantics() &&
574 llvm::all_of(linalgOp.getIndexingMaps().getValue(),
576 return cast<AffineMapAttr>(attr)
578 .isProjectedPermutation();
592 LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
593 ArrayRef<AffineMap> reassociationMaps,
594 ArrayRef<OpFoldResult> expandedShape,
595 PatternRewriter &rewriter);
596 unsigned getOrigOpNumDims()
const {
return reassociation.size(); }
597 unsigned getExpandedOpNumDims()
const {
return expandedOpNumDims; }
599 return reassociation[i];
601 ArrayRef<OpFoldResult> getExpandedShapeOfDim(
unsigned i)
const {
602 return expandedShapeMap[i];
604 ArrayRef<OpFoldResult> getOriginalShape()
const {
return originalLoopExtent; }
609 SmallVector<ReassociationIndices> reassociation;
612 SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
614 SmallVector<OpFoldResult> originalLoopExtent;
615 unsigned expandedOpNumDims;
619LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
624 if (reassociationMaps.empty())
626 AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
628 OpBuilder::InsertionGuard g(rewriter);
630 originalLoopExtent = llvm::map_to_vector(
631 linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
632 [](Range r) { return r.size; });
634 reassociation.clear();
635 expandedShapeMap.clear();
638 SmallVector<unsigned> numExpandedDims(fusedIndexMap.
getNumDims(), 1);
639 expandedShapeMap.resize(fusedIndexMap.
getNumDims());
640 for (
const auto &resultExpr : llvm::enumerate(fusedIndexMap.
getResults())) {
641 unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
642 AffineMap foldedDims = reassociationMaps[resultExpr.index()];
644 ArrayRef<OpFoldResult> shape =
645 expandedShape.slice(foldedDims.
getDimPosition(0), numExpandedDims[pos]);
646 expandedShapeMap[pos].assign(shape.begin(), shape.end());
649 for (
unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.
getNumDims()))
650 if (expandedShapeMap[i].empty())
651 expandedShapeMap[i] = {originalLoopExtent[i]};
655 reassociation.reserve(fusedIndexMap.
getNumDims());
656 for (
const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
657 auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
658 reassociation.emplace_back(seq.begin(), seq.end());
659 sum += numFoldedDim.value();
661 expandedOpNumDims = sum;
669 const ExpansionInfo &expansionInfo) {
672 unsigned pos = cast<AffineDimExpr>(expr).getPosition();
674 llvm::map_range(expansionInfo.getExpandedDims(pos), [&](
int64_t v) {
675 return builder.getAffineDimExpr(static_cast<unsigned>(v));
677 newExprs.append(expandedExprs.begin(), expandedExprs.end());
686static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
688 const ExpansionInfo &expansionInfo) {
691 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
693 expansionInfo.getExpandedShapeOfDim(dim);
694 expandedShape.append(dimExpansion.begin(), dimExpansion.end());
697 std::tie(expandedStaticShape, std::ignore) =
699 return {expandedShape, RankedTensorType::get(expandedStaticShape,
700 originalType.getElementType())};
709static SmallVector<ReassociationIndices>
711 const ExpansionInfo &expansionInfo) {
713 unsigned numReshapeDims = 0;
715 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
716 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
718 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
719 reassociation.emplace_back(std::move(
indices));
720 numReshapeDims += numExpandedDims;
722 return reassociation;
732 const ExpansionInfo &expansionInfo) {
734 for (IndexOp indexOp :
735 llvm::make_early_inc_range(fusedRegion.
front().
getOps<IndexOp>())) {
737 expansionInfo.getExpandedDims(indexOp.getDim());
738 assert(!expandedDims.empty() &&
"expected valid expansion info");
741 if (expandedDims.size() == 1 &&
742 expandedDims.front() == (
int64_t)indexOp.getDim())
749 expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
751 expandedIndices.reserve(expandedDims.size() - 1);
753 expandedDims.drop_front(), std::back_inserter(expandedIndices),
754 [&](
int64_t dim) { return IndexOp::create(rewriter, loc, dim); });
756 IndexOp::create(rewriter, loc, expandedDims.front()).getResult();
757 for (
auto [expandedShape, expandedIndex] :
758 llvm::zip(expandedDimsShape, expandedIndices)) {
763 rewriter, indexOp.getLoc(), idx +
acc *
shape,
768 rewriter.
replaceOp(indexOp, newIndexVal);
790 TransposeOp transposeOp,
792 ExpansionInfo &expansionInfo) {
795 auto reassoc = expansionInfo.getExpandedDims(perm);
797 newPerm.push_back(dim);
800 return TransposeOp::create(rewriter, transposeOp.getLoc(), expandedInput,
811 expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
813 for (
auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
814 for (
auto j : expansionInfo.getExpandedDims(i))
815 iteratorTypes[
j] = type;
817 Operation *fused = GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes,
818 expandedOpOperands, outputs,
819 expandedOpIndexingMaps, iteratorTypes);
822 Region &originalRegion = linalgOp->getRegion(0);
840 ExpansionInfo &expansionInfo) {
843 .Case<TransposeOp>([&](TransposeOp transposeOp) {
845 expandedOpOperands[0], outputs[0],
848 .Case<FillOp, CopyOp>([&](
Operation *op) {
849 return clone(rewriter, linalgOp, resultTypes,
850 llvm::to_vector(llvm::concat<Value>(
851 llvm::to_vector(expandedOpOperands),
852 llvm::to_vector(outputs))));
856 expandedOpOperands, outputs,
857 expansionInfo, expandedOpIndexingMaps);
864static std::optional<SmallVector<Value>>
869 "preconditions for fuse operation failed");
875 if (
auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
879 rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
882 expandedShape = expandingReshapeOp.getMixedOutputShape();
883 reassociationIndices = expandingReshapeOp.getReassociationMaps();
884 src = expandingReshapeOp.getSrc();
886 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
887 if (!collapsingReshapeOp)
891 rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
892 reassociationIndices = collapsingReshapeOp.getReassociationMaps();
893 src = collapsingReshapeOp.getSrc();
896 ExpansionInfo expansionInfo;
897 if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
898 reassociationIndices, expandedShape,
903 llvm::map_range(linalgOp.getIndexingMapsArray(), [&](
AffineMap m) {
904 return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
912 expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
913 for (
OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
914 if (opOperand == fusableOpOperand) {
915 expandedOpOperands.push_back(src);
918 if (
auto opOperandType =
919 dyn_cast<RankedTensorType>(opOperand->get().getType())) {
920 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
922 RankedTensorType expandedOperandType;
923 std::tie(expandedOperandShape, expandedOperandType) =
925 if (expandedOperandType != opOperand->get().getType()) {
929 if (failed(reshapeLikeShapesAreCompatible(
930 [&](
const Twine &msg) {
933 opOperandType.getShape(), expandedOperandType.getShape(),
937 expandedOpOperands.push_back(tensor::ExpandShapeOp::create(
938 rewriter, loc, expandedOperandType, opOperand->get(), reassociation,
939 expandedOperandShape));
943 expandedOpOperands.push_back(opOperand->get());
947 for (
OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
948 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
949 auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
951 RankedTensorType expandedOutputType;
952 std::tie(expandedOutputShape, expandedOutputType) =
954 if (expandedOutputType != opOperand.get().getType()) {
957 if (failed(reshapeLikeShapesAreCompatible(
958 [&](
const Twine &msg) {
961 opOperandType.getShape(), expandedOutputType.getShape(),
965 outputs.push_back(tensor::ExpandShapeOp::create(
966 rewriter, loc, expandedOutputType, opOperand.get(), reassociation,
967 expandedOutputShape));
969 outputs.push_back(opOperand.get());
976 outputs, expandedOpIndexingMaps, expansionInfo);
980 for (
OpResult opResult : linalgOp->getOpResults()) {
981 int64_t resultNumber = opResult.getResultNumber();
982 if (resultTypes[resultNumber] != opResult.getType()) {
985 linalgOp.getMatchingIndexingMap(
986 linalgOp.getDpsInitOperand(resultNumber)),
988 resultVals.push_back(tensor::CollapseShapeOp::create(
989 rewriter, linalgOp.getLoc(), opResult.getType(),
990 fusedOp->
getResult(resultNumber), reassociation));
992 resultVals.push_back(fusedOp->
getResult(resultNumber));
1004class FoldWithProducerReshapeOpByExpansion
1005 :
public OpInterfaceRewritePattern<LinalgOp> {
1007 FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
1009 PatternBenefit benefit = 1)
1010 : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
1011 controlFoldingReshapes(std::move(foldReshapes)) {}
1013 LogicalResult matchAndRewrite(LinalgOp linalgOp,
1014 PatternRewriter &rewriter)
const override {
1015 for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
1016 tensor::CollapseShapeOp reshapeOp =
1017 opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
1024 (!controlFoldingReshapes(opOperand)))
1027 std::optional<SmallVector<Value>> replacementValues =
1029 if (!replacementValues)
1031 rewriter.
replaceOp(linalgOp, *replacementValues);
1044 SmallVector<int64_t> paddedShape;
1047 SmallVector<OpFoldResult> lowPad;
1048 SmallVector<OpFoldResult> highPad;
1055static FailureOr<PadDimInfo>
1056computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
1057 ArrayRef<ReassociationIndices> reassociations,
1058 PatternRewriter &rewriter) {
1065 if (!padOp.getConstantPaddingValue())
1072 ArrayRef<int64_t> low = padOp.getStaticLow();
1073 ArrayRef<int64_t> high = padOp.getStaticHigh();
1074 for (
auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
1075 if (reInd.size() != 1 && (l != 0 || h != 0))
1079 SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
1080 SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
1081 ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
1082 PadDimInfo padDimInfo;
1083 padDimInfo.paddedShape.assign(expandedShape);
1084 padDimInfo.lowPad.assign(expandedShape.size(), rewriter.
getIndexAttr(0));
1085 padDimInfo.highPad.assign(expandedShape.size(), rewriter.
getIndexAttr(0));
1086 for (
auto [idx, reInd] : llvm::enumerate(reassociations)) {
1087 if (reInd.size() == 1) {
1088 padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
1089 padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
1090 padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
1097class FoldPadWithProducerReshapeOpByExpansion
1098 :
public OpRewritePattern<tensor::PadOp> {
1100 FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
1102 PatternBenefit benefit = 1)
1103 : OpRewritePattern<tensor::PadOp>(context, benefit),
1104 controlFoldingReshapes(std::move(foldReshapes)) {}
1106 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1107 PatternRewriter &rewriter)
const override {
1108 tensor::CollapseShapeOp reshapeOp =
1109 padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
1113 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1115 "fusion blocked by control function");
1118 RankedTensorType expandedType = reshapeOp.getSrcType();
1119 SmallVector<ReassociationIndices> reassociations =
1120 reshapeOp.getReassociationIndices();
1121 FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
1122 padOp, expandedType.getShape(), reassociations, rewriter);
1123 if (
failed(maybeExpandedPadding))
1125 PadDimInfo &expandedPadding = maybeExpandedPadding.value();
1127 Location loc = padOp->getLoc();
1128 RankedTensorType expandedPaddedType =
1129 padOp.getResultType().clone(expandedPadding.paddedShape);
1131 auto newPadOp = tensor::PadOp::create(
1132 rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
1133 expandedPadding.lowPad, expandedPadding.highPad,
1134 padOp.getConstantPaddingValue(), padOp.getNofold());
1137 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
1146class FoldReshapeWithProducerPadOpByExpansion
1147 :
public OpRewritePattern<tensor::ExpandShapeOp> {
1149 FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
1151 PatternBenefit benefit = 1)
1152 : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
1153 controlFoldingReshapes(std::move(foldReshapes)) {}
1155 LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
1156 PatternRewriter &rewriter)
const override {
1157 tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
1161 if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
1163 "fusion blocked by control function");
1166 RankedTensorType expandedType = expandOp.getResultType();
1167 SmallVector<ReassociationIndices> reassociations =
1168 expandOp.getReassociationIndices();
1169 FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
1170 padOp, expandedType.getShape(), reassociations, rewriter);
1171 if (
failed(maybeExpandedPadding))
1173 PadDimInfo &expandedPadding = maybeExpandedPadding.value();
1175 Location loc = expandOp->getLoc();
1176 SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
1177 SmallVector<int64_t> newExpandedShape(expandedType.getShape());
1179 SmallVector<OpFoldResult> padSrcSizes =
1181 for (
auto [idx, reInd] : llvm::enumerate(reassociations)) {
1184 if (reInd.size() == 1) {
1185 newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx);
1186 newExpandedSizes[reInd[0]] = padSrcSizes[idx];
1189 RankedTensorType newExpandedType = expandedType.clone(newExpandedShape);
1190 auto newExpandOp = tensor::ExpandShapeOp::create(
1191 rewriter, loc, newExpandedType, padOp.getSource(), reassociations,
1193 RankedTensorType expandedPaddedType =
1194 padOp.getResultType().clone(expandedPadding.paddedShape);
1196 auto newPadOp = tensor::PadOp::create(
1197 rewriter, loc, expandedPaddedType, newExpandOp.getResult(),
1198 expandedPadding.lowPad, expandedPadding.highPad,
1199 padOp.getConstantPaddingValue(), padOp.getNofold());
1201 rewriter.
replaceOp(expandOp, newPadOp.getResult());
1212struct FoldReshapeWithGenericOpByExpansion
1213 :
public OpRewritePattern<tensor::ExpandShapeOp> {
1215 FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
1217 PatternBenefit benefit = 1)
1218 : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
1219 controlFoldingReshapes(std::move(foldReshapes)) {}
1221 LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
1222 PatternRewriter &rewriter)
const override {
1224 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1225 if (!producerResult) {
1227 "source not produced by an operation");
1230 auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
1233 "producer not a generic op");
1238 producer.getDpsInitOperand(producerResult.getResultNumber()))) {
1240 reshapeOp,
"failed preconditions of fusion with producer generic op");
1243 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1245 "fusion blocked by control function");
1248 std::optional<SmallVector<Value>> replacementValues =
1250 producer, reshapeOp,
1251 producer.getDpsInitOperand(producerResult.getResultNumber()),
1253 if (!replacementValues) {
1255 "fusion by expansion failed");
1262 Value reshapeReplacement =
1263 (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
1264 .getResultNumber()];
1265 if (
auto collapseOp =
1266 reshapeReplacement.
getDefiningOp<tensor::CollapseShapeOp>()) {
1267 reshapeReplacement = collapseOp.getSrc();
1269 rewriter.
replaceOp(reshapeOp, reshapeReplacement);
1270 rewriter.
replaceOp(producer, *replacementValues);
1292 "expected projected permutation");
1295 llvm::map_range(rangeReassociation, [&](
int64_t pos) ->
int64_t {
1296 return cast<AffineDimExpr>(indexingMap.
getResults()[pos]).getPosition();
1300 return domainReassociation;
1308 assert(!dimSequence.empty() &&
1309 "expected non-empty list for dimension sequence");
1311 "expected indexing map to be projected permutation");
1313 llvm::SmallDenseSet<unsigned, 4> sequenceElements;
1314 sequenceElements.insert_range(dimSequence);
1316 unsigned dimSequenceStart = dimSequence[0];
1317 for (
const auto &expr : enumerate(indexingMap.
getResults())) {
1318 unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
1320 if (dimInMapStart == dimSequenceStart) {
1321 if (expr.index() + dimSequence.size() > indexingMap.
getNumResults())
1324 for (
const auto &dimInSequence : enumerate(dimSequence)) {
1326 cast<AffineDimExpr>(
1327 indexingMap.
getResult(expr.index() + dimInSequence.index()))
1329 if (dimInMap != dimInSequence.value())
1340 if (sequenceElements.count(dimInMapStart))
1349 return llvm::all_of(maps, [&](
AffineMap map) {
1406 if (!genericOp.hasPureTensorSemantics())
1409 if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](
AffineMap map) {
1410 return map.isProjectedPermutation();
1417 genericOp.getReductionDims(reductionDims);
1419 llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1420 AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1421 auto iteratorTypes = genericOp.getIteratorTypesArray();
1424 assert(!foldedRangeDims.empty() &&
"unexpected empty reassociation");
1427 if (foldedRangeDims.size() == 1)
1435 if (llvm::any_of(foldedIterationSpaceDims, [&](
int64_t dim) {
1436 return processedIterationDims.count(dim);
1441 utils::IteratorType startIteratorType =
1442 iteratorTypes[foldedIterationSpaceDims[0]];
1446 if (llvm::any_of(foldedIterationSpaceDims, [&](
int64_t dim) {
1447 return iteratorTypes[dim] != startIteratorType;
1456 bool isContiguous =
false;
1457 for (
const auto &startDim : llvm::enumerate(reductionDims)) {
1459 if (startDim.value() != foldedIterationSpaceDims[0])
1463 if (startDim.index() + foldedIterationSpaceDims.size() >
1464 reductionDims.size())
1467 isContiguous =
true;
1468 for (
const auto &foldedDim :
1469 llvm::enumerate(foldedIterationSpaceDims)) {
1470 if (reductionDims[foldedDim.index() + startDim.index()] !=
1471 foldedDim.value()) {
1472 isContiguous =
false;
1483 if (llvm::any_of(genericOp.getIndexingMapsArray(),
1485 return !isDimSequencePreserved(indexingMap,
1486 foldedIterationSpaceDims);
1490 processedIterationDims.insert_range(foldedIterationSpaceDims);
1491 iterationSpaceReassociation.emplace_back(
1492 std::move(foldedIterationSpaceDims));
1495 return iterationSpaceReassociation;
1500class CollapsingInfo {
1502 LogicalResult
initialize(
unsigned origNumLoops,
1503 ArrayRef<ReassociationIndices> foldedIterationDims) {
1504 llvm::SmallDenseSet<int64_t, 4> processedDims;
1507 if (foldedIterationDim.empty())
1511 for (
auto dim : foldedIterationDim) {
1512 if (dim >= origNumLoops)
1514 if (processedDims.count(dim))
1516 processedDims.insert(dim);
1518 collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1519 foldedIterationDim.end());
1521 if (processedDims.size() > origNumLoops)
1526 for (
auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1527 if (processedDims.count(dim))
1532 llvm::sort(collapsedOpToOrigOpIterationDim,
1536 origOpToCollapsedOpIterationDim.resize(origNumLoops);
1537 for (
const auto &foldedDims :
1538 llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1539 for (
const auto &dim :
enumerate(foldedDims.value()))
1540 origOpToCollapsedOpIterationDim[dim.value()] =
1541 std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1548 return collapsedOpToOrigOpIterationDim;
1571 ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping()
const {
1572 return origOpToCollapsedOpIterationDim;
1576 unsigned getCollapsedOpIterationRank()
const {
1577 return collapsedOpToOrigOpIterationDim.size();
1583 SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1587 SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1593static SmallVector<utils::IteratorType>
1594getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
1595 const CollapsingInfo &collapsingInfo) {
1596 SmallVector<utils::IteratorType> collapsedIteratorTypes;
1598 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1599 assert(!foldedIterDims.empty() &&
1600 "reassociation indices expected to have non-empty sets");
1604 collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1606 return collapsedIteratorTypes;
1612getCollapsedOpIndexingMap(AffineMap indexingMap,
1613 const CollapsingInfo &collapsingInfo) {
1614 MLIRContext *context = indexingMap.
getContext();
1616 "expected indexing map to be projected permutation");
1617 SmallVector<AffineExpr> resultExprs;
1618 auto origOpToCollapsedOpMapping =
1619 collapsingInfo.getOrigOpToCollapsedOpMapping();
1621 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1623 if (origOpToCollapsedOpMapping[dim].second != 0)
1627 resultExprs.push_back(
1630 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1631 resultExprs, context);
1636static SmallVector<ReassociationIndices>
1637getOperandReassociation(AffineMap indexingMap,
1638 const CollapsingInfo &collapsingInfo) {
1639 unsigned counter = 0;
1640 SmallVector<ReassociationIndices> operandReassociation;
1641 auto origOpToCollapsedOpMapping =
1642 collapsingInfo.getOrigOpToCollapsedOpMapping();
1643 auto collapsedOpToOrigOpMapping =
1644 collapsingInfo.getCollapsedOpToOrigOpMapping();
1647 cast<AffineDimExpr>(indexingMap.
getResult(counter)).getPosition();
1651 unsigned numFoldedDims =
1652 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1654 if (origOpToCollapsedOpMapping[dim].second == 0) {
1655 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1656 operandReassociation.emplace_back(range.begin(), range.end());
1658 counter += numFoldedDims;
1660 return operandReassociation;
1664static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1665 OpOperand *opOperand,
1666 const CollapsingInfo &collapsingInfo,
1667 OpBuilder &builder) {
1668 AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1669 SmallVector<ReassociationIndices> operandReassociation =
1670 getOperandReassociation(indexingMap, collapsingInfo);
1675 Value operand = opOperand->
get();
1676 if (operandReassociation.size() == indexingMap.
getNumResults())
1680 if (isa<MemRefType>(operand.
getType())) {
1681 return memref::CollapseShapeOp::create(builder, loc, operand,
1682 operandReassociation)
1685 return tensor::CollapseShapeOp::create(builder, loc, operand,
1686 operandReassociation)
// Rewrites linalg.index ops inside the collapsed op's body: each collapsed
// dimension's single index is de-linearized back into the original
// dimensions' indices using rem/div by the original loop sizes, and the old
// linalg.index ops are replaced with the recovered values.
1692static void generateCollapsedIndexingRegion(
1693 Location loc,
Block *block,
const CollapsingInfo &collapsingInfo,
1694 ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
1695 OpBuilder::InsertionGuard g(rewriter);
// Snapshot the index ops first; they are replaced below while iterating.
1699 auto indexOps = llvm::to_vector(block->
getOps<linalg::IndexOp>());
1708 llvm::DenseMap<unsigned, Value> indexReplacementVals;
1709 for (
auto foldedDims :
1710 enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
// One linalg.index per collapsed dimension is the starting linearized value.
1713 linalg::IndexOp::create(rewriter, loc, foldedDims.index());
// Peel original dims from innermost to outermost: rem extracts the current
// dim's index, div strips it off for the next iteration.
1714 for (
auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1717 indexReplacementVals[dim] =
1718 rewriter.
createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
1720 rewriter.
createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
// Whatever remains is the outermost (front) dim's index.
1722 indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1725 for (
auto indexOp : indexOps) {
1726 auto dim = indexOp.getDim();
1727 rewriter.
replaceOp(indexOp, indexReplacementVals[dim]);
// Builds the collapsed-rank input operands, init operands, and result types
// for the collapsed op. Result types are only collected for tensor semantics
// (a pure-buffer op has no results).
1731static void collapseOperandsAndResults(LinalgOp op,
1732 const CollapsingInfo &collapsingInfo,
1733 RewriterBase &rewriter,
1734 SmallVectorImpl<Value> &inputOperands,
1735 SmallVectorImpl<Value> &outputOperands,
1736 SmallVectorImpl<Type> &resultTypes) {
1737 Location loc = op->getLoc();
// Collapse every DPS input via getCollapsedOpOperand.
1739 llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1740 return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1745 resultTypes.reserve(op.getNumDpsInits());
1746 outputOperands.reserve(op.getNumDpsInits());
1747 for (OpOperand &output : op.getDpsInitsMutable()) {
1749 getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1750 outputOperands.push_back(newOutput);
// Only tensor-semantics ops carry result types (buffers yield none).
1753 if (!op.hasPureBufferSemantics())
1754 resultTypes.push_back(newOutput.
getType());
// Clones `origOp` with collapsed operands/results. The primary template's
// body is not visible in this extraction; the LinalgOp specialization below
// collapses operands/results and re-creates the op with the concatenated
// (inputs ++ inits) operand list — presumably via clone; confirm against the
// full source.
1759template <
typename OpTy>
1760static OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1761 const CollapsingInfo &collapsingInfo) {
// Specialization for the generic LinalgOp interface (named structured ops).
1768LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1769 const CollapsingInfo &collapsingInfo) {
1770 SmallVector<Value> inputOperands, outputOperands;
1771 SmallVector<Type> resultTypes;
1772 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1773 outputOperands, resultTypes);
1776 rewriter, origOp, resultTypes,
1777 llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
// Specialization for linalg.generic: besides collapsing operands/results,
// the indexing maps and iterator types must themselves be collapsed, and the
// original payload region is merged into the new op's (empty) body.
1782GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1784 const CollapsingInfo &collapsingInfo) {
1785 SmallVector<Value> inputOperands, outputOperands;
1786 SmallVector<Type> resultTypes;
1787 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1788 outputOperands, resultTypes);
// Rewrite each indexing map into the collapsed iteration space.
1789 SmallVector<AffineMap> indexingMaps(
1790 llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
1791 return getCollapsedOpIndexingMap(map, collapsingInfo);
1794 SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
1795 origOp.getIteratorTypesArray(), collapsingInfo));
// Create with an intentionally empty body builder; the original block is
// merged in below so the payload is reused verbatim.
1797 GenericOp collapsedOp = linalg::GenericOp::create(
1798 rewriter, origOp.getLoc(), resultTypes, inputOperands, outputOperands,
1799 indexingMaps, iteratorTypes,
1800 [](OpBuilder &builder, Location loc,
ValueRange args) {});
1801 Block *origOpBlock = &origOp->getRegion(0).front();
1802 Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
1803 rewriter.
mergeBlocks(origOpBlock, collapsedOpBlock,
// Dispatches to the right cloneToCollapsedOp specialization: linalg.generic
// gets the GenericOp path (maps/iterators/region handled explicitly); any
// other structured op takes the plain LinalgOp path.
1808static LinalgOp createCollapsedOp(LinalgOp op,
1809 const CollapsingInfo &collapsingInfo,
1810 RewriterBase &rewriter) {
1811 if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1812 return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
1814 return cloneToCollapsedOp(rewriter, op, collapsingInfo);
// Entry point for collapsing iteration dimensions of a structured op (the
// signature's first line is lost in this extraction — presumably
// FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims; confirm
// against the full source). Validates the request, creates the collapsed op,
// fixes up index semantics, and expands results back to the original rank.
1819 LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
1820 RewriterBase &rewriter) {
// Bail out when there is nothing to collapse (<=1 loop, or every group is a
// singleton).
1822 if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1824 return foldedDims.size() <= 1;
1828 CollapsingInfo collapsingInfo;
1830 collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1832 op,
"illegal to collapse specified dimensions");
// For buffer semantics, every memref operand must be statically known to be
// collapsible with the computed reassociation.
1835 bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1836 if (hasPureBufferSemantics &&
1837 !llvm::all_of(op->getOpOperands(), [&](OpOperand &opOperand) ->
bool {
1838 MemRefType memRefToCollapse =
1839 dyn_cast<MemRefType>(opOperand.get().getType());
1840 if (!memRefToCollapse)
1843 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1844 SmallVector<ReassociationIndices> operandReassociation =
1845 getOperandReassociation(indexingMap, collapsingInfo);
1846 return memref::CollapseShapeOp::isGuaranteedCollapsible(
1847 memRefToCollapse, operandReassociation);
1850 "memref is not guaranteed collapsible");
// Collapsing assumes normalized loops: zero offset and unit stride.
1853 SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
1854 auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1855 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1856 return cast<IntegerAttr>(attr).getInt() == value;
1859 actual.getSExtValue() == value;
1861 if (!llvm::all_of(loopRanges, [&](Range range) {
1862 return opFoldIsConstantValue(range.
offset, 0) &&
1863 opFoldIsConstantValue(range.
stride, 1);
1866 op,
"expected all loop ranges to have zero start and unit stride");
1869 LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
1871 Location loc = op->getLoc();
1872 SmallVector<OpFoldResult> loopBound =
1873 llvm::map_to_vector(loopRanges, [](Range range) {
return range.
size; });
// linalg.index ops in the body must be de-linearized to the original dims.
1875 if (collapsedOp.hasIndexSemantics()) {
1877 OpBuilder::InsertionGuard g(rewriter);
1879 generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
1880 collapsingInfo, loopBound, rewriter);
// Expand each collapsed result back to the original rank where the ranks
// differ; otherwise pass the result through unchanged.
1885 SmallVector<Value> results;
1886 for (
const auto &originalResult : llvm::enumerate(op->getResults())) {
1887 Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1888 auto originalResultType =
1889 cast<ShapedType>(originalResult.value().getType());
1890 auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.
getType());
1891 if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1892 AffineMap indexingMap =
1893 op.getIndexingMapMatchingResult(originalResult.value());
1894 SmallVector<ReassociationIndices> reassociation =
1895 getOperandReassociation(indexingMap, collapsingInfo);
1898 "Expected indexing map to be a projected permutation for collapsing");
1899 SmallVector<OpFoldResult> resultShape =
// Buffer results expand via memref.expand_shape, tensors via
// tensor.expand_shape.
1902 if (isa<MemRefType>(collapsedOpResult.
getType())) {
1903 MemRefType expandShapeResultType = MemRefType::get(
1904 originalResultType.getShape(), originalResultType.getElementType());
1905 result = memref::ExpandShapeOp::create(
1906 rewriter, loc, expandShapeResultType, collapsedOpResult,
1907 reassociation, resultShape);
1909 result = tensor::ExpandShapeOp::create(
1910 rewriter, loc, originalResultType, collapsedOpResult, reassociation,
1913 results.push_back(
result);
1915 results.push_back(collapsedOpResult);
1918 return CollapseResult{results, collapsedOp};
// Pattern: fold a tensor.expand_shape producer into a consuming
// linalg.generic by collapsing the generic's iteration dimensions, subject to
// the user-supplied control function.
1925class FoldWithProducerReshapeOpByCollapsing
1926 :
public OpRewritePattern<GenericOp> {
1929 FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1931 PatternBenefit benefit = 1)
1932 : OpRewritePattern<GenericOp>(context, benefit),
1933 controlFoldingReshapes(std::move(foldReshapes)) {}
1935 LogicalResult matchAndRewrite(GenericOp genericOp,
1936 PatternRewriter &rewriter)
const override {
// Scan each operand for a defining expand_shape to absorb.
1937 for (OpOperand &opOperand : genericOp->getOpOperands()) {
1938 tensor::ExpandShapeOp reshapeOp =
1943 SmallVector<ReassociationIndices> collapsableIterationDims =
1945 reshapeOp.getReassociationIndices());
// Skip if nothing is collapsible or the control callback rejects fusion.
1946 if (collapsableIterationDims.empty() ||
1947 !controlFoldingReshapes(&opOperand)) {
1952 genericOp, collapsableIterationDims, rewriter);
1953 if (!collapseResult) {
1955 genericOp,
"failed to do the fusion by collapsing transformation");
1958 rewriter.
replaceOp(genericOp, collapseResult->results);
// Pattern: fold a tensor.collapse_shape whose source is a linalg.generic
// result into the producer, by collapsing the producer's iteration
// dimensions and replacing the producer with the collapsed results.
1970struct FoldReshapeWithGenericOpByCollapsing
1971 :
public OpRewritePattern<tensor::CollapseShapeOp> {
1973 FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
1975 PatternBenefit benefit = 1)
1976 : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
1977 controlFoldingReshapes(std::move(foldReshapes)) {}
1979 LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
1980 PatternRewriter &rewriter)
const override {
// The reshape source must be an op result (not a block argument).
1983 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1984 if (!producerResult) {
1986 "source not produced by an operation");
1990 auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
1993 "producer not a generic op");
// Determine which producer iteration dims the reshape allows collapsing,
// keyed off the init operand that matches this result.
1996 SmallVector<ReassociationIndices> collapsableIterationDims =
1999 producer.getDpsInitOperand(producerResult.getResultNumber()),
2000 reshapeOp.getReassociationIndices());
2001 if (collapsableIterationDims.empty()) {
2003 reshapeOp,
"failed preconditions of fusion with producer generic op");
2006 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
2008 "fusion blocked by control function");
2014 std::optional<CollapseResult> collapseResult =
2016 if (!collapseResult) {
2018 producer,
"failed to do the fusion by collapsing transformation");
// Replacing the producer also rewrites this reshape's source uses.
2021 rewriter.
replaceOp(producer, collapseResult->results);
// Computes the low/high padding and padded shape of a tensor.pad after its
// dimensions are grouped by `reassociations`. Fails (visible precondition)
// unless the pad value is a constant; padding on a dim that belongs to a
// multi-dim group is rejected in the loop below.
2033static FailureOr<PadDimInfo>
2034computeCollapsedPadding(tensor::PadOp padOp,
2035 ArrayRef<ReassociationIndices> reassociations,
2036 PatternRewriter &rewriter) {
2043 if (!padOp.getConstantPaddingValue())
2050 ArrayRef<int64_t> low = padOp.getStaticLow();
2051 ArrayRef<int64_t> high = padOp.getStaticHigh();
// A padded dim inside a group of size > 1 cannot be collapsed soundly.
2052 for (
auto [idx, reInd] : llvm::enumerate(reassociations)) {
2053 for (int64_t dim : reInd) {
2054 if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
2060 ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
2061 PadDimInfo padDimInfo;
// Default every collapsed dim to zero padding; singleton groups inherit the
// original mixed padding below.
2062 padDimInfo.lowPad.assign(reassociations.size(), rewriter.
getIndexAttr(0));
2063 padDimInfo.highPad.assign(reassociations.size(), rewriter.
getIndexAttr(0));
2067 SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
2068 SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
2069 for (
auto [idx, reInd] : llvm::enumerate(reassociations)) {
2070 if (reInd.size() == 1) {
2071 padDimInfo.lowPad[idx] = mixedLowPad[reInd[0]];
2072 padDimInfo.highPad[idx] = mixedHighPad[reInd[0]];
// Multi-dim group: accumulate the product of member sizes (accumulation
// lines lost in extraction — presumably multiplies expandedPaddedShape dims).
2075 for (int64_t dim : reInd) {
2079 padDimInfo.paddedShape.push_back(collapsedSize.
asInteger());
// Pattern: tensor.expand_shape -> tensor.pad becomes tensor.pad (collapsed)
// -> tensor.expand_shape, so the pad operates on the smaller collapsed
// tensor and the expand is re-created with padded output sizes.
2085class FoldPadWithProducerReshapeOpByCollapsing
2086 :
public OpRewritePattern<tensor::PadOp> {
2088 FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
2090 PatternBenefit benefit = 1)
2091 : OpRewritePattern<tensor::PadOp>(context, benefit),
2092 controlFoldingReshapes(std::move(foldReshapes)) {}
2094 LogicalResult matchAndRewrite(tensor::PadOp padOp,
2095 PatternRewriter &rewriter)
const override {
2096 tensor::ExpandShapeOp reshapeOp =
2097 padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
2101 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
2103 "fusion blocked by control function");
2106 SmallVector<ReassociationIndices> reassociations =
2107 reshapeOp.getReassociationIndices();
2108 FailureOr<PadDimInfo> maybeCollapsedPadding =
2109 computeCollapsedPadding(padOp, reassociations, rewriter);
2110 if (
failed(maybeCollapsedPadding))
2112 PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
// Start from the expand's output sizes, then add low+high padding to each
// singleton group via an affine apply (addMap built from d0/d1/d2 —
// construction line lost in extraction).
2114 SmallVector<OpFoldResult> expandedPaddedSizes =
2115 reshapeOp.getMixedOutputShape();
2116 AffineExpr d0, d1, d2;
2119 Location loc = reshapeOp->getLoc();
2120 for (
auto [reInd, l, h] :
2121 llvm::zip_equal(reassociations, collapsedPadding.lowPad,
2122 collapsedPadding.highPad)) {
2123 if (reInd.size() == 1) {
2124 expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply(
2125 rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
2129 RankedTensorType collapsedPaddedType =
2130 padOp.getType().clone(collapsedPadding.paddedShape);
// Pad the collapsed source, preserving constant pad value and nofold.
2131 auto newPadOp = tensor::PadOp::create(
2132 rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(),
2133 collapsedPadding.lowPad, collapsedPadding.highPad,
2134 padOp.getConstantPaddingValue(), padOp.getNofold());
2137 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
2138 expandedPaddedSizes);
// Pattern: tensor.pad -> tensor.collapse_shape becomes tensor.collapse_shape
// -> tensor.pad, pushing the collapse above the pad so padding happens at
// the collapsed rank.
2147class FoldReshapeWithProducerPadOpByCollapsing
2148 :
public OpRewritePattern<tensor::CollapseShapeOp> {
2150 FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
2152 PatternBenefit benefit = 1)
2153 : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
2154 controlFoldingReshapes(std::move(foldReshapes)) {}
2156 LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
2157 PatternRewriter &rewriter)
const override {
2158 tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
2162 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
2164 "fusion blocked by control function");
2167 SmallVector<ReassociationIndices> reassociations =
2168 reshapeOp.getReassociationIndices();
// The final type is exactly the collapse's result type.
2169 RankedTensorType collapsedPaddedType = reshapeOp.getResultType();
2170 FailureOr<PadDimInfo> maybeCollapsedPadding =
2171 computeCollapsedPadding(padOp, reassociations, rewriter);
2172 if (
failed(maybeCollapsedPadding))
2174 PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
// Collapse the pre-pad source first, then pad at the collapsed rank.
2176 Location loc = reshapeOp->getLoc();
2177 auto newCollapseOp = tensor::CollapseShapeOp::create(
2178 rewriter, loc, padOp.getSource(), reassociations);
2180 auto newPadOp = tensor::PadOp::create(
2181 rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(),
2182 collapsedPadding.lowPad, collapsedPadding.highPad,
2183 padOp.getConstantPaddingValue(), padOp.getNofold());
2185 rewriter.
replaceOp(reshapeOp, newPadOp.getResult());
// Pattern (templated over the linalg op type): collapse the iteration
// dimensions selected by the user-supplied `controlCollapseDimension`
// callback and replace the op with the collapsed results.
2194template <
typename LinalgType>
2195class CollapseLinalgDimensions :
public OpRewritePattern<LinalgType> {
2197 CollapseLinalgDimensions(MLIRContext *context,
2199 PatternBenefit benefit = 1)
2200 : OpRewritePattern<LinalgType>(context, benefit),
2201 controlCollapseDimension(std::move(collapseDimensions)) {}
2203 LogicalResult matchAndRewrite(LinalgType op,
2204 PatternRewriter &rewriter)
const override {
// Control callback chooses which dims to collapse; empty means no-op.
2205 SmallVector<ReassociationIndices> collapsableIterationDims =
2206 controlCollapseDimension(op);
2207 if (collapsableIterationDims.empty())
// Validity check of the requested groups (check call partly lost in
// extraction).
2212 collapsableIterationDims)) {
2214 op,
"specified dimensions cannot be collapsed");
2217 std::optional<CollapseResult> collapseResult =
2219 if (!collapseResult) {
2222 rewriter.
replaceOp(op, collapseResult->results);
// Pattern: inline a scalar constant or splat dense-constant input of a
// linalg.generic as an arith.constant captured inside a rebuilt generic,
// dropping the constant tensor operand (and its indexing map) entirely.
2239class FoldScalarOrSplatConstant :
public OpRewritePattern<GenericOp> {
2241 FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
2242 : OpRewritePattern<GenericOp>(context, benefit) {}
2244 LogicalResult matchAndRewrite(GenericOp genericOp,
2245 PatternRewriter &rewriter)
const override {
2246 if (!genericOp.hasPureTensorSemantics())
2248 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
// `constantAttr` receives the scalar element of whichever constant form
// matches: splat dense elements, integer, or float.
2250 TypedAttr constantAttr;
2251 auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) ->
bool {
2253 DenseElementsAttr splatAttr;
2256 splatAttr.
getType().getElementType().isIntOrFloat()) {
2262 IntegerAttr intAttr;
2264 constantAttr = intAttr;
2269 FloatAttr floatAttr;
2271 constantAttr = floatAttr;
2278 auto resultValue = dyn_cast<OpResult>(opOperand->
get());
2279 if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
// Rebuild operand/map/location lists with the constant operand removed.
2285 SmallVector<AffineMap> fusedIndexMaps;
2286 SmallVector<Value> fusedOperands;
2287 SmallVector<Location> fusedLocs{genericOp.getLoc()};
2288 fusedIndexMaps.reserve(genericOp->getNumOperands());
2289 fusedOperands.reserve(genericOp.getNumDpsInputs());
2290 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
2291 for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
2292 if (inputOperand == opOperand)
2294 Value inputValue = inputOperand->get();
2295 fusedIndexMaps.push_back(
2296 genericOp.getMatchingIndexingMap(inputOperand));
2297 fusedOperands.push_back(inputValue);
2298 fusedLocs.push_back(inputValue.
getLoc());
2300 for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
2301 fusedIndexMaps.push_back(
2302 genericOp.getMatchingIndexingMap(&outputOperand));
// Sanity check on the remaining maps (condition lost in extraction).
2308 genericOp,
"fused op loop bound computation failed");
// Materialize the constant as a scalar to be used inside the new payload.
2312 Value scalarConstant =
2313 arith::ConstantOp::create(rewriter, def->
getLoc(), constantAttr);
2315 SmallVector<Value> outputOperands = genericOp.getOutputs();
2317 GenericOp::create(rewriter, rewriter.
getFusedLoc(fusedLocs),
2318 genericOp->getResultTypes(),
2322 genericOp.getIteratorTypes(),
// Move/clone the original region into the new op (details lost in
// extraction), then replace all uses.
2328 Region &region = genericOp->getRegion(0);
2333 Region &fusedRegion = fusedOp->getRegion(0);
2336 rewriter.
replaceOp(genericOp, fusedOp->getResults());
// Pattern: when a generic's payload never reads an init operand, replace
// that init with a fresh tensor.empty of the same shape/element type,
// breaking a false dependency on the previous producer.
2354struct RemoveOutsDependency :
public OpRewritePattern<GenericOp> {
2355 using OpRewritePattern<GenericOp>::OpRewritePattern;
2357 LogicalResult matchAndRewrite(GenericOp op,
2358 PatternRewriter &rewriter)
const override {
2360 bool modifiedOutput =
false;
2361 Location loc = op.getLoc();
2362 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
// Only inits the payload does not read can be replaced.
2363 if (!op.payloadUsesValueFromOperand(&opOperand)) {
2364 Value operandVal = opOperand.
get();
2365 auto operandType = dyn_cast<RankedTensorType>(operandVal.
getType());
// Already a tensor.empty — nothing to do for this operand.
2374 auto definingOp = operandVal.
getDefiningOp<tensor::EmptyOp>();
2377 modifiedOutput =
true;
2378 SmallVector<OpFoldResult> mixedSizes =
2380 Value emptyTensor = tensor::EmptyOp::create(
2381 rewriter, loc, mixedSizes, operandType.getElementType());
// Report match failure when no init was replaced.
2385 if (!modifiedOutput) {
// Pattern: fold a linalg.fill feeding a generic's input by substituting the
// fill value (converted to the payload's element type) directly into the
// payload. Body is truncated in this extraction — the rewrite tail is not
// visible.
2395struct FoldFillWithGenericOp :
public OpRewritePattern<GenericOp> {
2396 using OpRewritePattern<GenericOp>::OpRewritePattern;
2398 LogicalResult matchAndRewrite(GenericOp genericOp,
2399 PatternRewriter &rewriter)
const override {
2400 if (!genericOp.hasPureTensorSemantics())
2402 bool fillFound =
false;
2403 Block &payload = genericOp.getRegion().front();
2404 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
// Inputs the payload never reads cannot benefit from the fold.
2405 if (!genericOp.payloadUsesValueFromOperand(opOperand))
2411 Value fillVal = fillOp.value();
2413 cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2414 Value convertedVal =
// Fragments of the pattern-population helpers (function signatures lost in
// extraction): the first group registers the expansion-based reshape-fusion
// patterns, the second registers the collapsing-based ones, each forwarding
// the caller's controlFoldingReshapes callback.
2429 controlFoldingReshapes);
2430 patterns.add<FoldPadWithProducerReshapeOpByExpansion>(
patterns.getContext(),
2431 controlFoldingReshapes);
2432 patterns.add<FoldReshapeWithProducerPadOpByExpansion>(
patterns.getContext(),
2433 controlFoldingReshapes);
2435 controlFoldingReshapes);
// Collapsing-based reshape fusion patterns (defined earlier in this file).
2442 controlFoldingReshapes);
2443 patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2444 patterns.getContext(), controlFoldingReshapes);
2445 patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
2446 patterns.getContext(), controlFoldingReshapes);
2448 controlFoldingReshapes);
// Fragments of two more population helpers (signatures lost in extraction):
// elementwise-fusion pattern registration, then the collapse-dimensions
// patterns for linalg.generic and linalg.copy with the caller's control
// callback.
2454 auto *context =
patterns.getContext();
2455 patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2456 patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2457 RemoveOutsDependency>(context);
// CollapseLinalgDimensions is instantiated for both GenericOp and CopyOp.
2465 patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2466 CollapseLinalgDimensions<linalg::CopyOp>>(
2467 patterns.getContext(), controlCollapseDimensions);
2482struct LinalgElementwiseOpFusionPass
2483 :
public impl::LinalgElementwiseOpFusionPassBase<
2484 LinalgElementwiseOpFusionPass> {
2485 using impl::LinalgElementwiseOpFusionPassBase<
2486 LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2487 void runOnOperation()
override {
2494 Operation *producer = fusedOperand->get().getDefiningOp();
2495 return producer && producer->
hasOneUse();
2504 affine::AffineApplyOp::getCanonicalizationPatterns(
patterns, context);
2505 GenericOp::getCanonicalizationPatterns(
patterns, context);
2506 tensor::ExpandShapeOp::getCanonicalizationPatterns(
patterns, context);
2507 tensor::CollapseShapeOp::getCanonicalizationPatterns(
patterns, context);