29#include "llvm/ADT/SmallVectorExtras.h"
34#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
35#include "mlir/Dialect/Linalg/Passes.h.inc"
// NOTE(review): fragment — interior lines were elided by extraction (the fused
// original-line numbers jump); code kept byte-identical, comments only.
// Computes a producer operand's indexing map expressed in the fused
// (consumer) iteration space by composing through the inverse of the
// producer-result indexing map.
61 assert(invProducerResultIndexMap &&
62 "expected producer result indexing map to be invertible");
64 LinalgOp producer = cast<LinalgOp>(producerOpOperand->
getOwner());
66 AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);
// t1 is presumably argMap composed with invProducerResultIndexMap (its
// definition is in the elided lines) — TODO confirm against full source.
74 return t1.
compose(fusedConsumerArgIndexMap);
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Collects the indexing maps of every op-operand of `ops` that is not in
// `opOperandsToIgnore`, then tests whether the concatenated maps remain
// invertible (i.e. no iteration dimension is lost if the ignored operands
// are dropped).
81 GenericOp producer, GenericOp consumer,
86 for (
auto &op : ops) {
87 for (
auto &opOperand : op->getOpOperands()) {
88 if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
91 indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
// No maps left at all: droppable only when both ops are zero-dimensional.
94 if (indexingMaps.empty()) {
97 return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
// A non-null inverse permutation means all loop dims are still covered.
105 indexingMaps, producer.getContext())) !=
AffineMap();
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Returns the set of producer result indices that must survive fusion: a
// result is preserved if its init operand is read in the payload, if dropping
// it would break indexing-map invertibility, or if it has a user other than
// the consumer being fused.
114 GenericOp producer, GenericOp consumer,
OpOperand *fusedOperand) {
115 llvm::SmallDenseSet<int> preservedProducerResults;
119 opOperandsToIgnore.emplace_back(fusedOperand);
121 for (
const auto &producerResult : llvm::enumerate(producer->getResults())) {
122 auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
// Tentatively mark this init operand ignorable; undone below if preserved.
123 opOperandsToIgnore.emplace_back(outputOperand);
124 if (producer.payloadUsesValueFromOperand(outputOperand) ||
126 opOperandsToIgnore) ||
127 llvm::any_of(producerResult.value().getUsers(), [&](
Operation *user) {
128 return user != consumer.getOperation();
130 preservedProducerResults.insert(producerResult.index());
// A preserved result's operand cannot be ignored — pop the tentative entry.
133 (
void)opOperandsToIgnore.pop_back_val();
136 return preservedProducerResults;
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Precondition checks for fusing two elementwise linalg.generic ops along
// `fusedOperand` (a consumer input whose value is a producer result).
145 auto consumer = dyn_cast<GenericOp>(fusedOperand->
getOwner());
148 if (!producer || !consumer)
// Only implemented for pure tensor semantics on ranked tensor operands.
154 if (!producer.hasPureTensorSemantics() ||
155 !isa<RankedTensorType>(fusedOperand->
get().
getType()))
// The producer must be elementwise: every loop parallel.
160 if (producer.getNumParallelLoops() != producer.getNumLoops())
165 if (!consumer.isDpsInput(fusedOperand))
// The consumer's map for the fused operand must address every producer loop.
170 AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
171 if (consumerIndexMap.
getNumResults() != producer.getNumLoops())
176 auto producerResult = cast<OpResult>(fusedOperand->
get());
178 producer.getIndexingMapMatchingResult(producerResult);
// When the consumer has reductions, every consumer loop dim must remain
// addressable after fusion; coverage is tracked per loop in a BitVector.
186 if ((consumer.getNumReductionLoops())) {
187 BitVector coveredDims(consumer.getNumLoops(),
false);
189 auto addToCoveredDims = [&](
AffineMap map) {
190 for (
auto result : map.getResults())
191 if (
auto dimExpr = dyn_cast<AffineDimExpr>(
result))
192 coveredDims[dimExpr.getPosition()] =
true;
// Dims referenced by every consumer operand except the fused one count.
196 llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
197 Value operand = std::get<0>(pair);
198 if (operand == fusedOperand->
get())
200 AffineMap operandMap = std::get<1>(pair);
201 addToCoveredDims(operandMap);
// Producer inputs contribute via their maps rewritten into consumer space.
204 for (
OpOperand *operand : producer.getDpsInputOperands()) {
207 operand, producerResultIndexMap, consumerIndexMap);
208 addToCoveredDims(newIndexingMap);
210 if (!coveredDims.all())
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Builds the payload region of the fused op: clones the producer body, remaps
// linalg.index ops into fused coordinates, splices block arguments in fused
// operand order, then clones the consumer body with the fused-operand argument
// replaced by the producer's yielded value.
222 unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
224 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
226 Block &producerBlock = producer->getRegion(0).
front();
227 Block &consumerBlock = consumer->getRegion(0).
front();
// If the producer uses linalg.index, its indices must be rewritten through
// consumerToProducerLoopsMap via affine.apply over the fused op's indices.
234 if (producer.hasIndexSemantics()) {
236 unsigned numFusedOpLoops = fusedOp.getNumLoops();
238 fusedIndices.reserve(numFusedOpLoops);
239 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
240 std::back_inserter(fusedIndices), [&](uint64_t dim) {
241 return IndexOp::create(rewriter, producer.getLoc(), dim);
243 for (IndexOp indexOp :
244 llvm::make_early_inc_range(producerBlock.
getOps<IndexOp>())) {
245 Value newIndex = affine::AffineApplyOp::create(
246 rewriter, producer.getLoc(),
247 consumerToProducerLoopsMap.
getSubMap(indexOp.getDim()), fusedIndices);
248 mapper.
map(indexOp.getResult(), newIndex);
252 assert(consumer.isDpsInput(fusedOperand) &&
253 "expected producer of input operand");
// Map consumer input args (before the fused operand), then producer input
// args, then the remaining consumer inputs, into fresh fused-block arguments.
257 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
264 producerBlock.
getArguments().take_front(producer.getNumDpsInputs()))
265 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
270 .take_front(consumer.getNumDpsInputs())
272 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
// Producer init args are added only for results that are preserved.
275 for (
const auto &bbArg : llvm::enumerate(
276 producerBlock.
getArguments().take_back(producer.getNumDpsInits()))) {
277 if (!preservedProducerResults.count(bbArg.index()))
279 mapper.
map(bbArg.value(), fusedBlock->
addArgument(bbArg.value().getType(),
280 bbArg.value().getLoc()));
285 consumerBlock.
getArguments().take_back(consumer.getNumDpsInits()))
286 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
// Clone producer payload (index ops were handled above).
291 if (!isa<IndexOp>(op))
292 rewriter.
clone(op, mapper);
296 auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.
getTerminator());
297 unsigned producerResultNumber =
298 cast<OpResult>(fusedOperand->
get()).getResultNumber();
300 mapper.
lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
// Sanity: the value yielded for the fused result must have been remapped
// into the fused block (assert-only diagnostics below).
304 if (
replacement == producerYieldOp.getOperand(producerResultNumber)) {
305 if (
auto bb = dyn_cast<BlockArgument>(
replacement))
306 assert(bb.getOwner() != &producerBlock &&
307 "yielded block argument must have been mapped");
309 assert(!producer->isAncestor(
replacement.getDefiningOp()) &&
310 "yielded value must have been mapped");
316 rewriter.
clone(op, mapper);
// Build the fused terminator: preserved producer yields then consumer yields.
320 auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.
getTerminator());
322 fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
323 consumerYieldOp.getNumOperands());
324 for (
const auto &producerYieldVal :
325 llvm::enumerate(producerYieldOp.getOperands())) {
326 if (preservedProducerResults.count(producerYieldVal.index()))
327 fusedYieldValues.push_back(
330 for (
auto consumerYieldVal : consumerYieldOp.getOperands())
332 YieldOp::create(rewriter, fusedOp.getLoc(), fusedYieldValues);
336 "Ill-formed GenericOp region");
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Top-level fusion entry point: assembles the fused op's operands, indexing
// maps and result types, creates the fused GenericOp, generates its region,
// and records result replacements for both producer and consumer.
339FailureOr<mlir::linalg::ElementwiseOpFusionResult>
343 "expected elementwise operation pre-conditions to pass");
344 auto producerResult = cast<OpResult>(fusedOperand->
get());
345 auto producer = cast<GenericOp>(producerResult.getOwner());
346 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
348 assert(consumer.isDpsInput(fusedOperand) &&
349 "expected producer of input operand");
352 llvm::SmallDenseSet<int> preservedProducerResults =
// Pre-size the operand/map/type vectors for the worst case.
360 fusedInputOperands.reserve(producer.getNumDpsInputs() +
361 consumer.getNumDpsInputs());
362 fusedOutputOperands.reserve(preservedProducerResults.size() +
363 consumer.getNumDpsInits());
364 fusedResultTypes.reserve(preservedProducerResults.size() +
365 consumer.getNumDpsInits());
366 fusedIndexMaps.reserve(producer->getNumOperands() +
367 consumer->getNumOperands());
// Consumer inputs before the fused operand keep their own maps.
370 auto consumerInputs = consumer.getDpsInputOperands();
371 auto *it = llvm::find_if(consumerInputs, [&](
OpOperand *operand) {
372 return operand == fusedOperand;
374 assert(it != consumerInputs.end() &&
"expected to find the consumer operand");
375 for (
OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
376 fusedInputOperands.push_back(opOperand->get());
377 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
// Producer inputs replace the fused operand; their maps are rewritten into
// the consumer's iteration space.
381 producer.getIndexingMapMatchingResult(producerResult);
382 for (
OpOperand *opOperand : producer.getDpsInputOperands()) {
383 fusedInputOperands.push_back(opOperand->get());
386 opOperand, producerResultIndexMap,
387 consumer.getMatchingIndexingMap(fusedOperand));
388 fusedIndexMaps.push_back(map);
// Remaining consumer inputs after the fused operand.
393 llvm::make_range(std::next(it), consumerInputs.end())) {
394 fusedInputOperands.push_back(opOperand->get());
395 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
// Preserved producer inits become fused outputs.
399 for (
const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
400 if (!preservedProducerResults.count(opOperand.index()))
403 fusedOutputOperands.push_back(opOperand.value().get());
405 &opOperand.value(), producerResultIndexMap,
406 consumer.getMatchingIndexingMap(fusedOperand));
407 fusedIndexMaps.push_back(map);
408 fusedResultTypes.push_back(opOperand.value().get().getType());
// Consumer inits follow; memref outputs produce no SSA results.
412 for (
OpOperand &opOperand : consumer.getDpsInitsMutable()) {
413 fusedOutputOperands.push_back(opOperand.get());
414 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
415 Type resultType = opOperand.get().getType();
416 if (!isa<MemRefType>(resultType))
417 fusedResultTypes.push_back(resultType);
421 auto fusedOp = GenericOp::create(
422 rewriter, consumer.getLoc(), fusedResultTypes, fusedInputOperands,
424 consumer.getIteratorTypes(),
// Reject fused ops whose loop bounds cannot be computed from operand shapes.
427 if (!fusedOp.getShapesToLoopsMap()) {
433 fusedOp,
"fused op failed loop bound computation check");
439 consumer.getMatchingIndexingMap(fusedOperand);
443 assert(invProducerResultIndexMap &&
// TODO(review): "indexig" is a typo for "indexing" in this assert message
// (message text left untouched here).
444 "expected producer result indexig map to be invertible");
447 invProducerResultIndexMap.
compose(consumerResultIndexMap);
450 rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
451 consumer.getNumLoops(), preservedProducerResults);
// Record replacements: preserved producer results first, then consumer
// results, in fused-result order.
455 for (
auto [
index, producerResult] : llvm::enumerate(producer->getResults()))
456 if (preservedProducerResults.count(
index))
457 result.replacements[producerResult] = fusedOp->getResult(resultNum++);
458 for (
auto consumerResult : consumer->getResults())
459 result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Rewrite pattern driver: for each operand of a generic op, consults the
// user-supplied control function and attempts elementwise fusion with the
// operand's producer, applying the recorded replacements on success.
470 controlFn(std::move(fun)) {}
472 LogicalResult matchAndRewrite(GenericOp genericOp,
475 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
// Control function can veto fusion on a per-operand basis.
478 if (!controlFn(&opOperand))
481 Operation *producer = opOperand.get().getDefiningOp();
484 FailureOr<ElementwiseOpFusionResult> fusionResult =
486 if (failed(fusionResult))
490 for (
auto [origVal,
replacement] : fusionResult->replacements) {
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Predicate: reshape fusion by expansion requires pure tensor semantics and
// every indexing map to be a projected permutation.
572 linalgOp.getIteratorTypesArray();
573 AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
574 return linalgOp.hasPureTensorSemantics() &&
575 llvm::all_of(linalgOp.getIndexingMaps().getValue(),
577 return cast<AffineMapAttr>(attr)
579 .isProjectedPermutation();
// NOTE(review): fragment of the ExpansionInfo helper class — interior lines
// elided by extraction; comments only. Records, per original loop dimension,
// the expanded dimensions and shapes it maps to when a reshape is folded into
// a linalg op.
593 LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
594 ArrayRef<AffineMap> reassociationMaps,
595 ArrayRef<OpFoldResult> expandedShape,
596 PatternRewriter &rewriter);
// Number of loops in the original (unexpanded) op.
597 unsigned getOrigOpNumDims()
const {
return reassociation.size(); }
598 unsigned getExpandedOpNumDims()
const {
return expandedOpNumDims; }
600 return reassociation[i];
602 ArrayRef<OpFoldResult> getExpandedShapeOfDim(
unsigned i)
const {
603 return expandedShapeMap[i];
605 ArrayRef<OpFoldResult> getOriginalShape()
const {
return originalLoopExtent; }
// reassociation[i] lists the expanded dims for original dim i;
// expandedShapeMap[i] their extents; originalLoopExtent the original extents.
610 SmallVector<ReassociationIndices> reassociation;
613 SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
615 SmallVector<OpFoldResult> originalLoopExtent;
616 unsigned expandedOpNumDims;
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Populates ExpansionInfo from the reshape's reassociation and expanded shape
// observed through `fusableOpOperand`'s indexing map.
620LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
625 if (reassociationMaps.empty())
627 AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
629 OpBuilder::InsertionGuard g(rewriter);
// Original loop extents come from the op's computed loop ranges.
631 originalLoopExtent = llvm::map_to_vector(
632 linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
633 [](Range r) { return r.size; });
635 reassociation.clear();
636 expandedShapeMap.clear();
// For each result of the operand's map, record how many expanded dims the
// corresponding original dim splits into and their extents.
639 SmallVector<unsigned> numExpandedDims(fusedIndexMap.
getNumDims(), 1);
640 expandedShapeMap.resize(fusedIndexMap.
getNumDims());
641 for (
const auto &resultExpr : llvm::enumerate(fusedIndexMap.
getResults())) {
642 unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
643 AffineMap foldedDims = reassociationMaps[resultExpr.index()];
645 ArrayRef<OpFoldResult> shape =
646 expandedShape.slice(foldedDims.
getDimPosition(0), numExpandedDims[pos]);
647 expandedShapeMap[pos].assign(shape.begin(), shape.end());
// Dims untouched by the reshape keep their original extent, unexpanded.
650 for (
unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.
getNumDims()))
651 if (expandedShapeMap[i].empty())
652 expandedShapeMap[i] = {originalLoopExtent[i]};
// Build the loop-dim reassociation as consecutive index runs.
656 reassociation.reserve(fusedIndexMap.
getNumDims());
657 for (
const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
658 auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
659 reassociation.emplace_back(seq.begin(), seq.end());
660 sum += numFoldedDim.value();
662 expandedOpNumDims = sum;
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Rewrites an indexing map into the expanded iteration space by replacing
// each dim expression with the run of expanded dim expressions it maps to.
670 const ExpansionInfo &expansionInfo) {
673 unsigned pos = cast<AffineDimExpr>(expr).getPosition();
675 expansionInfo.getExpandedDims(pos), [&](
int64_t v) {
676 return builder.getAffineDimExpr(static_cast<unsigned>(v));
678 newExprs.append(expandedExprs.begin(), expandedExprs.end());
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Computes the expanded (mixed) shape and ranked tensor type for an operand,
// concatenating the per-dim expansions in indexing-map order.
687static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
689 const ExpansionInfo &expansionInfo) {
692 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
694 expansionInfo.getExpandedShapeOfDim(dim);
695 expandedShape.append(dimExpansion.begin(), dimExpansion.end());
// Split mixed sizes into a static shape for the result type.
698 std::tie(expandedStaticShape, std::ignore) =
700 return {expandedShape, RankedTensorType::get(expandedStaticShape,
701 originalType.getElementType())};
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Builds the reassociation indices for the expand_shape/collapse_shape that
// reshapes an operand, grouping consecutive expanded dims per map result.
710static SmallVector<ReassociationIndices>
712 const ExpansionInfo &expansionInfo) {
714 unsigned numReshapeDims = 0;
716 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
717 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
719 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
720 reassociation.emplace_back(std::move(
indices));
721 numReshapeDims += numExpandedDims;
723 return reassociation;
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Rewrites linalg.index ops inside the expanded op's region: an original dim
// that expanded into several dims is reconstructed by linearizing the new
// per-dim indices (idx + acc * shape folding).
733 const ExpansionInfo &expansionInfo) {
735 for (IndexOp indexOp :
736 llvm::make_early_inc_range(fusedRegion.
front().
getOps<IndexOp>())) {
738 expansionInfo.getExpandedDims(indexOp.getDim());
739 assert(!expandedDims.empty() &&
"expected valid expansion info");
// Dim unchanged by the expansion: nothing to rewrite.
742 if (expandedDims.size() == 1 &&
743 expandedDims.front() == (
int64_t)indexOp.getDim())
750 expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
752 expandedIndices.reserve(expandedDims.size() - 1);
754 expandedDims.drop_front(), std::back_inserter(expandedIndices),
755 [&](
int64_t dim) { return IndexOp::create(rewriter, loc, dim); });
757 IndexOp::create(rewriter, loc, expandedDims.front()).getResult();
// Fold idx_i into the accumulator: new = idx + acc * shape_i.
758 for (
auto [expandedShape, expandedIndex] :
759 llvm::zip(expandedDimsShape, expandedIndices)) {
764 rewriter, indexOp.getLoc(), idx +
acc *
shape,
769 rewriter.
replaceOp(indexOp, newIndexVal);
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Creates the expanded transpose: each entry of the original permutation is
// replaced by the run of expanded dims it maps to.
791 TransposeOp transposeOp,
793 ExpansionInfo &expansionInfo) {
796 auto reassoc = expansionInfo.getExpandedDims(perm);
798 newPerm.push_back(dim);
801 return TransposeOp::create(rewriter, transposeOp.getLoc(), expandedInput,
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Creates the expanded GenericOp: iterator types are replicated across each
// dim's expanded positions, then the original region is moved/cloned over.
812 expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
814 for (
auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
815 for (
auto j : expansionInfo.getExpandedDims(i))
816 iteratorTypes[
j] = type;
818 Operation *fused = GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes,
819 expandedOpOperands, outputs,
820 expandedOpIndexingMaps, iteratorTypes);
823 Region &originalRegion = linalgOp->getRegion(0);
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Dispatches expanded-op creation on the concrete linalg op kind: transpose
// gets a rebuilt permutation, fill/copy are cloned with new operands, and the
// default case (elided) builds an expanded GenericOp.
841 ExpansionInfo &expansionInfo) {
844 .Case([&](TransposeOp transposeOp) {
846 expandedOpOperands[0], outputs[0],
849 .Case<FillOp, CopyOp>([&](
Operation *op) {
850 return clone(rewriter, linalgOp, resultTypes,
851 llvm::to_vector(llvm::concat<Value>(
852 llvm::to_vector(expandedOpOperands),
853 llvm::to_vector(outputs))));
857 expandedOpOperands, outputs,
858 expansionInfo, expandedOpIndexingMaps);
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Fuses a tensor.expand_shape/collapse_shape with a linalg op by expanding
// the op's iteration space; returns the replacement values for the original
// op's results, or nullopt on failure.
865static std::optional<SmallVector<Value>>
870 "preconditions for fuse operation failed");
// Expanding reshape: its output shape/reassociation drive the expansion.
876 if (
auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
880 rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
883 expandedShape = expandingReshapeOp.getMixedOutputShape();
884 reassociationIndices = expandingReshapeOp.getReassociationMaps();
885 src = expandingReshapeOp.getSrc();
// Otherwise it must be a collapsing reshape; its source supplies the shape.
887 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
888 if (!collapsingReshapeOp)
892 rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
893 reassociationIndices = collapsingReshapeOp.getReassociationMaps();
894 src = collapsingReshapeOp.getSrc();
897 ExpansionInfo expansionInfo;
898 if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
899 reassociationIndices, expandedShape,
904 llvm::map_to_vector<4>(linalgOp.getIndexingMapsArray(), [&](
AffineMap m) {
905 return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
// Expand each input operand; the fusable operand is replaced by the reshape
// source directly, others get a tensor.expand_shape when the type changes.
913 expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
914 for (
OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
915 if (opOperand == fusableOpOperand) {
916 expandedOpOperands.push_back(src);
919 if (
auto opOperandType =
920 dyn_cast<RankedTensorType>(opOperand->get().getType())) {
921 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
923 RankedTensorType expandedOperandType;
924 std::tie(expandedOperandShape, expandedOperandType) =
926 if (expandedOperandType != opOperand->get().getType()) {
930 if (failed(reshapeLikeShapesAreCompatible(
931 [&](
const Twine &msg) {
934 opOperandType.getShape(), expandedOperandType.getShape(),
938 expandedOpOperands.push_back(tensor::ExpandShapeOp::create(
939 rewriter, loc, expandedOperandType, opOperand->get(), reassociation,
940 expandedOperandShape));
944 expandedOpOperands.push_back(opOperand->get());
// Same expansion treatment for the init (output) operands.
948 for (
OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
949 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
950 auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
952 RankedTensorType expandedOutputType;
953 std::tie(expandedOutputShape, expandedOutputType) =
955 if (expandedOutputType != opOperand.get().getType()) {
958 if (failed(reshapeLikeShapesAreCompatible(
959 [&](
const Twine &msg) {
962 opOperandType.getShape(), expandedOutputType.getShape(),
966 outputs.push_back(tensor::ExpandShapeOp::create(
967 rewriter, loc, expandedOutputType, opOperand.get(), reassociation,
968 expandedOutputShape));
970 outputs.push_back(opOperand.get());
977 outputs, expandedOpIndexingMaps, expansionInfo);
// Collapse expanded results back to each original result type when needed.
981 for (
OpResult opResult : linalgOp->getOpResults()) {
982 int64_t resultNumber = opResult.getResultNumber();
983 if (resultTypes[resultNumber] != opResult.getType()) {
986 linalgOp.getMatchingIndexingMap(
987 linalgOp.getDpsInitOperand(resultNumber)),
989 resultVals.push_back(tensor::CollapseShapeOp::create(
990 rewriter, linalgOp.getLoc(), opResult.getType(),
991 fusedOp->
getResult(resultNumber), reassociation));
993 resultVals.push_back(fusedOp->
getResult(resultNumber));
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Pattern: folds a producing tensor.collapse_shape into a linalg op by
// expanding the op, subject to a user control function.
1005class FoldWithProducerReshapeOpByExpansion
1006 :
public OpInterfaceRewritePattern<LinalgOp> {
1008 FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
1010 PatternBenefit benefit = 1)
1011 : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
1012 controlFoldingReshapes(std::move(foldReshapes)) {}
1014 LogicalResult matchAndRewrite(LinalgOp linalgOp,
1015 PatternRewriter &rewriter)
const override {
// Try each input operand until one is produced by a collapse_shape that
// passes the fusability and control checks.
1016 for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
1017 tensor::CollapseShapeOp reshapeOp =
1018 opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
1025 (!controlFoldingReshapes(opOperand)))
1028 std::optional<SmallVector<Value>> replacementValues =
1030 if (!replacementValues)
1032 rewriter.
replaceOp(linalgOp, *replacementValues);
// NOTE(review): fragment of a POD helper (declaration header elided by
// extraction): expanded padded shape plus per-dim low/high padding amounts.
1045 SmallVector<int64_t> paddedShape;
1048 SmallVector<OpFoldResult> lowPad;
1049 SmallVector<OpFoldResult> highPad;
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Computes padding info in the expanded space for folding a reshape with a
// tensor.pad. Fails for non-constant padding values or when a dim that is
// actually padded participates in a multi-dim reassociation group.
1056static FailureOr<PadDimInfo>
1057computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
1058 ArrayRef<ReassociationIndices> reassociations,
1059 PatternRewriter &rewriter) {
1066 if (!padOp.getConstantPaddingValue())
// Padding on a dim that splits/merges under the reassociation is unsupported.
1073 ArrayRef<int64_t> low = padOp.getStaticLow();
1074 ArrayRef<int64_t> high = padOp.getStaticHigh();
1075 for (
auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
1076 if (reInd.size() != 1 && (l != 0 || h != 0))
1080 SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
1081 SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
1082 ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
// Start from zero padding everywhere, then copy the original pad amounts
// onto the single expanded dim of each 1:1 reassociation group.
1083 PadDimInfo padDimInfo;
1084 padDimInfo.paddedShape.assign(expandedShape);
1085 padDimInfo.lowPad.assign(expandedShape.size(), rewriter.
getIndexAttr(0));
1086 padDimInfo.highPad.assign(expandedShape.size(), rewriter.
getIndexAttr(0));
1087 for (
auto [idx, reInd] : llvm::enumerate(reassociations)) {
1088 if (reInd.size() == 1) {
1089 padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
1090 padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
1091 padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Pattern: tensor.pad(tensor.collapse_shape(x)) -> collapse_shape(pad(x)),
// expanding the pad to the pre-collapse shape.
1098class FoldPadWithProducerReshapeOpByExpansion
1099 :
public OpRewritePattern<tensor::PadOp> {
1101 FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
1103 PatternBenefit benefit = 1)
1104 : OpRewritePattern<tensor::PadOp>(context, benefit),
1105 controlFoldingReshapes(std::move(foldReshapes)) {}
1107 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1108 PatternRewriter &rewriter)
const override {
1109 tensor::CollapseShapeOp reshapeOp =
1110 padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
1114 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1116 "fusion blocked by control function");
// Expand the padding to the collapse-shape source's (expanded) type.
1119 RankedTensorType expandedType = reshapeOp.getSrcType();
1120 SmallVector<ReassociationIndices> reassociations =
1121 reshapeOp.getReassociationIndices();
1122 FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
1123 padOp, expandedType.getShape(), reassociations, rewriter);
1124 if (
failed(maybeExpandedPadding))
1126 PadDimInfo &expandedPadding = maybeExpandedPadding.value();
1128 Location loc = padOp->getLoc();
1129 RankedTensorType expandedPaddedType =
1130 padOp.getResultType().clone(expandedPadding.paddedShape);
1132 auto newPadOp = tensor::PadOp::create(
1133 rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
1134 expandedPadding.lowPad, expandedPadding.highPad,
1135 padOp.getConstantPaddingValue(), padOp.getNofold());
// Collapse the padded result back to the original pad result type.
1138 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Pattern: tensor.expand_shape(tensor.pad(x)) -> pad(expand_shape(x)),
// moving the expansion above the pad.
1147class FoldReshapeWithProducerPadOpByExpansion
1148 :
public OpRewritePattern<tensor::ExpandShapeOp> {
1150 FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
1152 PatternBenefit benefit = 1)
1153 : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
1154 controlFoldingReshapes(std::move(foldReshapes)) {}
1156 LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
1157 PatternRewriter &rewriter)
const override {
1158 tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
1162 if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
1164 "fusion blocked by control function");
1167 RankedTensorType expandedType = expandOp.getResultType();
1168 SmallVector<ReassociationIndices> reassociations =
1169 expandOp.getReassociationIndices();
1170 FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
1171 padOp, expandedType.getShape(), reassociations, rewriter);
1172 if (
failed(maybeExpandedPadding))
1174 PadDimInfo &expandedPadding = maybeExpandedPadding.value();
// Recompute the expand_shape's result sizes against the pad's *source*
// (unpadded) sizes on 1:1 reassociation groups.
1176 Location loc = expandOp->getLoc();
1177 SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
1178 SmallVector<int64_t> newExpandedShape(expandedType.getShape());
1180 SmallVector<OpFoldResult> padSrcSizes =
1182 for (
auto [idx, reInd] : llvm::enumerate(reassociations)) {
1185 if (reInd.size() == 1) {
1186 newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx);
1187 newExpandedSizes[reInd[0]] = padSrcSizes[idx];
1190 RankedTensorType newExpandedType = expandedType.clone(newExpandedShape);
1191 auto newExpandOp = tensor::ExpandShapeOp::create(
1192 rewriter, loc, newExpandedType, padOp.getSource(), reassociations,
1194 RankedTensorType expandedPaddedType =
1195 padOp.getResultType().clone(expandedPadding.paddedShape);
1197 auto newPadOp = tensor::PadOp::create(
1198 rewriter, loc, expandedPaddedType, newExpandOp.getResult(),
1199 expandedPadding.lowPad, expandedPadding.highPad,
1200 padOp.getConstantPaddingValue(), padOp.getNofold());
1202 rewriter.
replaceOp(expandOp, newPadOp.getResult());
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Pattern: folds a tensor.expand_shape whose source is produced by a linalg
// op into that producer, expanding the producer's iteration space.
1213struct FoldReshapeWithGenericOpByExpansion
1214 :
public OpRewritePattern<tensor::ExpandShapeOp> {
1216 FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
1218 PatternBenefit benefit = 1)
1219 : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
1220 controlFoldingReshapes(std::move(foldReshapes)) {}
1222 LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
1223 PatternRewriter &rewriter)
const override {
1225 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1226 if (!producerResult) {
1228 "source not produced by an operation");
1231 auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
1234 "producer not a generic op");
1239 producer.getDpsInitOperand(producerResult.getResultNumber()))) {
1241 reshapeOp,
"failed preconditions of fusion with producer generic op");
1244 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1246 "fusion blocked by control function");
1249 std::optional<SmallVector<Value>> replacementValues =
1251 producer, reshapeOp,
1252 producer.getDpsInitOperand(producerResult.getResultNumber()),
1254 if (!replacementValues) {
1256 "fusion by expansion failed");
// If the replacement for the reshaped result is itself a collapse_shape,
// bypass it to avoid reintroducing the reshape we just folded.
1263 Value reshapeReplacement =
1264 (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
1265 .getResultNumber()];
1266 if (
auto collapseOp =
1267 reshapeReplacement.
getDefiningOp<tensor::CollapseShapeOp>()) {
1268 reshapeReplacement = collapseOp.getSrc();
1270 rewriter.
replaceOp(reshapeOp, reshapeReplacement);
1271 rewriter.
replaceOp(producer, *replacementValues);
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Translates a reassociation over an indexing map's results (range) into a
// reassociation over its dims (domain) by reading off each result's dim
// position.
1293 "expected projected permutation");
1296 llvm::map_to_vector<4>(rangeReassociation, [&](
int64_t pos) ->
int64_t {
1297 return cast<AffineDimExpr>(indexingMap.
getResults()[pos]).getPosition();
1301 return domainReassociation;
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Checks that `dimSequence` appears contiguously and in order within the
// indexing map's results (so the dims can be collapsed together), and that no
// sequence element appears outside that contiguous run.
1309 assert(!dimSequence.empty() &&
1310 "expected non-empty list for dimension sequence");
1317 llvm::SmallDenseSet<unsigned, 4> sequenceElements;
1318 sequenceElements.insert_range(dimSequence);
1320 unsigned dimSequenceStart = dimSequence[0];
1321 for (
const auto &expr : enumerate(indexingMap.
getResults())) {
1322 unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
// Found the sequence start: the whole sequence must follow in order.
1324 if (dimInMapStart == dimSequenceStart) {
1325 if (expr.index() + dimSequence.size() > indexingMap.
getNumResults())
1328 for (
const auto &dimInSequence : enumerate(dimSequence)) {
1330 cast<AffineDimExpr>(
1331 indexingMap.
getResult(expr.index() + dimInSequence.index()))
1333 if (dimInMap != dimInSequence.value())
// A sequence element seen out of position breaks the contiguity requirement.
1344 if (sequenceElements.count(dimInMapStart))
// NOTE(review): fragment — body of the predicate elided by extraction.
// Presumably checks every map preserves every dim sequence — TODO confirm.
1353 return llvm::all_of(maps, [&](
AffineMap map) {
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Computes which groups of iteration dims of a generic op can be collapsed
// together when folding a reshape along `fusableOperand`: requires tensor
// semantics, projected-permutation maps, uniform iterator kinds per group,
// contiguous reduction runs, and sequence preservation in every map.
1410 if (!genericOp.hasPureTensorSemantics())
1413 if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](
AffineMap map) {
1414 return map.isProjectedPermutation();
1421 genericOp.getReductionDims(reductionDims);
1423 llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1424 AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1425 auto iteratorTypes = genericOp.getIteratorTypesArray();
1428 assert(!foldedRangeDims.empty() &&
"unexpected empty reassociation");
// Singleton groups collapse nothing — skip.
1431 if (foldedRangeDims.size() == 1)
// A dim can belong to at most one collapsed group.
1439 if (llvm::any_of(foldedIterationSpaceDims, [&](
int64_t dim) {
1440 return processedIterationDims.count(dim);
// All dims in a group must share one iterator type.
1445 utils::IteratorType startIteratorType =
1446 iteratorTypes[foldedIterationSpaceDims[0]];
1450 if (llvm::any_of(foldedIterationSpaceDims, [&](
int64_t dim) {
1451 return iteratorTypes[dim] != startIteratorType;
// Reduction groups must be a contiguous run of the op's reduction dims.
1460 bool isContiguous =
false;
1461 for (
const auto &startDim : llvm::enumerate(reductionDims)) {
1463 if (startDim.value() != foldedIterationSpaceDims[0])
1467 if (startDim.index() + foldedIterationSpaceDims.size() >
1468 reductionDims.size())
1471 isContiguous =
true;
1472 for (
const auto &foldedDim :
1473 llvm::enumerate(foldedIterationSpaceDims)) {
1474 if (reductionDims[foldedDim.index() + startDim.index()] !=
1475 foldedDim.value()) {
1476 isContiguous =
false;
// Every indexing map must keep the group contiguous and ordered.
1487 if (llvm::any_of(genericOp.getIndexingMapsArray(),
1489 return !isDimSequencePreserved(indexingMap,
1490 foldedIterationSpaceDims);
1494 processedIterationDims.insert_range(foldedIterationSpaceDims);
1495 iterationSpaceReassociation.emplace_back(
1496 std::move(foldedIterationSpaceDims));
1499 return iterationSpaceReassociation;
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Helper recording the bidirectional mapping between original iteration dims
// and collapsed-op iteration dims.
1504class CollapsingInfo {
1506 LogicalResult
initialize(
unsigned origNumLoops,
1507 ArrayRef<ReassociationIndices> foldedIterationDims) {
1508 llvm::SmallDenseSet<int64_t, 4> processedDims;
// Validate the folded groups: non-empty, in-range, and disjoint.
1511 if (foldedIterationDim.empty())
1515 for (
auto dim : foldedIterationDim) {
1516 if (dim >= origNumLoops)
1518 if (processedDims.count(dim))
1520 processedDims.insert(dim);
1522 collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1523 foldedIterationDim.end());
1525 if (processedDims.size() > origNumLoops)
// Dims not mentioned in any group become singleton groups.
1530 for (
auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1531 if (processedDims.count(dim))
1536 llvm::sort(collapsedOpToOrigOpIterationDim,
// Build the inverse map: orig dim -> (collapsed dim, offset in group).
1540 origOpToCollapsedOpIterationDim.resize(origNumLoops);
1541 for (
const auto &foldedDims :
1542 llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1543 for (
const auto &dim :
enumerate(foldedDims.value()))
1544 origOpToCollapsedOpIterationDim[dim.value()] =
1545 std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1552 return collapsedOpToOrigOpIterationDim;
1575 ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping()
const {
1576 return origOpToCollapsedOpIterationDim;
1580 unsigned getCollapsedOpIterationRank()
const {
1581 return collapsedOpToOrigOpIterationDim.size();
// Per collapsed dim: the list of original dims it folds.
1587 SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
// Per original dim: (collapsed dim index, position within its group).
1591 SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Iterator types of the collapsed op: each folded group takes the iterator
// type of its first original dim (groups are homogeneous by construction).
1597static SmallVector<utils::IteratorType>
1598getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
1599 const CollapsingInfo &collapsingInfo) {
1600 SmallVector<utils::IteratorType> collapsedIteratorTypes;
1602 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1603 assert(!foldedIterDims.empty() &&
1604 "reassociation indices expected to have non-empty sets");
1608 collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1610 return collapsedIteratorTypes;
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Rewrites an indexing map into the collapsed space: only the leading dim of
// each folded group contributes a result expression.
1616getCollapsedOpIndexingMap(AffineMap indexingMap,
1617 const CollapsingInfo &collapsingInfo) {
1618 MLIRContext *context = indexingMap.
getContext();
1620 "expected indexing map to be projected permutation");
1621 SmallVector<AffineExpr> resultExprs;
1622 auto origOpToCollapsedOpMapping =
1623 collapsingInfo.getOrigOpToCollapsedOpMapping();
1625 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
// Non-leading members of a folded group are dropped from the map.
1627 if (origOpToCollapsedOpMapping[dim].second != 0)
1631 resultExprs.push_back(
1634 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1635 resultExprs, context);
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Builds the reassociation used to collapse one operand: consecutive map
// results that belong to the same folded iteration group merge into a single
// operand dimension.
1640static SmallVector<ReassociationIndices>
1641getOperandReassociation(AffineMap indexingMap,
1642 const CollapsingInfo &collapsingInfo) {
1643 unsigned counter = 0;
1644 SmallVector<ReassociationIndices> operandReassociation;
1645 auto origOpToCollapsedOpMapping =
1646 collapsingInfo.getOrigOpToCollapsedOpMapping();
1647 auto collapsedOpToOrigOpMapping =
1648 collapsingInfo.getCollapsedOpToOrigOpMapping();
1651 cast<AffineDimExpr>(indexingMap.
getResult(counter)).getPosition();
// Group size = number of original dims folded into this collapsed dim.
1655 unsigned numFoldedDims =
1656 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
// Only a group leader (offset 0) opens a new operand-dim group.
1658 if (origOpToCollapsedOpMapping[dim].second == 0) {
1659 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1660 operandReassociation.emplace_back(range.begin(), range.end());
1662 counter += numFoldedDims;
1664 return operandReassociation;
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Returns the operand value to use in the collapsed op: unchanged when no
// dims actually fold, otherwise wrapped in a memref/tensor collapse_shape.
1668static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1669 OpOperand *opOperand,
1670 const CollapsingInfo &collapsingInfo,
1671 OpBuilder &builder) {
1672 AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1673 SmallVector<ReassociationIndices> operandReassociation =
1674 getOperandReassociation(indexingMap, collapsingInfo);
// One reassociation group per map result means nothing collapses here.
1679 Value operand = opOperand->
get();
1680 if (operandReassociation.size() == indexingMap.
getNumResults())
// Pick the collapse op matching the operand's type (memref vs tensor).
1684 if (isa<MemRefType>(operand.
getType())) {
1685 return memref::CollapseShapeOp::create(builder, loc, operand,
1686 operandReassociation)
1689 return tensor::CollapseShapeOp::create(builder, loc, operand,
1690 operandReassociation)
// NOTE(review): fragment — interior lines elided by extraction; comments only.
// Rewrites linalg.index ops in the collapsed op's body: a collapsed index is
// delinearized back into the original dims via repeated rem/div by each
// original loop extent (innermost first).
1696static void generateCollapsedIndexingRegion(
1697 Location loc,
Block *block,
const CollapsingInfo &collapsingInfo,
1698 ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
1699 OpBuilder::InsertionGuard g(rewriter);
// Snapshot the index ops before mutating the block.
1703 auto indexOps = llvm::to_vector(block->
getOps<linalg::IndexOp>());
1712 llvm::DenseMap<unsigned, Value> indexReplacementVals;
1713 for (
auto foldedDims :
1714 enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1717 linalg::IndexOp::create(rewriter, loc, foldedDims.index());
// Peel off each non-leading dim: rem extracts it, div advances the rest.
1718 for (
auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1721 indexReplacementVals[dim] =
1722 rewriter.
createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
1724 rewriter.
createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
1726 indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1729 for (
auto indexOp : indexOps) {
1730 auto dim = indexOp.getDim();
1731 rewriter.
replaceOp(indexOp, indexReplacementVals[dim]);
1735static void collapseOperandsAndResults(LinalgOp op,
1736 const CollapsingInfo &collapsingInfo,
1737 RewriterBase &rewriter,
1738 SmallVectorImpl<Value> &inputOperands,
1739 SmallVectorImpl<Value> &outputOperands,
1740 SmallVectorImpl<Type> &resultTypes) {
1741 Location loc = op->getLoc();
1743 llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1744 return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1749 resultTypes.reserve(op.getNumDpsInits());
1750 outputOperands.reserve(op.getNumDpsInits());
1751 for (OpOperand &output : op.getDpsInitsMutable()) {
1753 getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1754 outputOperands.push_back(newOutput);
1757 if (!op.hasPureBufferSemantics())
1758 resultTypes.push_back(newOutput.
getType());
1763template <
typename OpTy>
1764static OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1765 const CollapsingInfo &collapsingInfo) {
1772LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1773 const CollapsingInfo &collapsingInfo) {
1774 SmallVector<Value> inputOperands, outputOperands;
1775 SmallVector<Type> resultTypes;
1776 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1777 outputOperands, resultTypes);
1780 rewriter, origOp, resultTypes,
1781 llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1786GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1788 const CollapsingInfo &collapsingInfo) {
1789 SmallVector<Value> inputOperands, outputOperands;
1790 SmallVector<Type> resultTypes;
1791 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1792 outputOperands, resultTypes);
1793 SmallVector<AffineMap> indexingMaps(
1794 llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
1795 return getCollapsedOpIndexingMap(map, collapsingInfo);
1798 SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
1799 origOp.getIteratorTypesArray(), collapsingInfo));
1801 GenericOp collapsedOp = linalg::GenericOp::create(
1802 rewriter, origOp.getLoc(), resultTypes, inputOperands, outputOperands,
1803 indexingMaps, iteratorTypes,
1804 [](OpBuilder &builder, Location loc,
ValueRange args) {});
1805 Block *origOpBlock = &origOp->getRegion(0).front();
1806 Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
1807 rewriter.
mergeBlocks(origOpBlock, collapsedOpBlock,
1812static LinalgOp createCollapsedOp(LinalgOp op,
1813 const CollapsingInfo &collapsingInfo,
1814 RewriterBase &rewriter) {
1815 if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1816 return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
1818 return cloneToCollapsedOp(rewriter, op, collapsingInfo);
1823 LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
1824 RewriterBase &rewriter) {
1826 if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1828 return foldedDims.size() <= 1;
1832 CollapsingInfo collapsingInfo;
1834 collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1836 op,
"illegal to collapse specified dimensions");
1839 bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1840 if (hasPureBufferSemantics &&
1841 !llvm::all_of(op->getOpOperands(), [&](OpOperand &opOperand) ->
bool {
1842 MemRefType memRefToCollapse =
1843 dyn_cast<MemRefType>(opOperand.get().getType());
1844 if (!memRefToCollapse)
1847 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1848 SmallVector<ReassociationIndices> operandReassociation =
1849 getOperandReassociation(indexingMap, collapsingInfo);
1850 return memref::CollapseShapeOp::isGuaranteedCollapsible(
1851 memRefToCollapse, operandReassociation);
1854 "memref is not guaranteed collapsible");
1857 SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
1858 auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1859 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1860 return cast<IntegerAttr>(attr).getInt() == value;
1863 actual.getSExtValue() == value;
1865 if (!llvm::all_of(loopRanges, [&](Range range) {
1866 return opFoldIsConstantValue(range.
offset, 0) &&
1867 opFoldIsConstantValue(range.
stride, 1);
1870 op,
"expected all loop ranges to have zero start and unit stride");
1873 LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
1875 Location loc = op->getLoc();
1876 SmallVector<OpFoldResult> loopBound =
1877 llvm::map_to_vector(loopRanges, [](Range range) {
return range.
size; });
1879 if (collapsedOp.hasIndexSemantics()) {
1881 OpBuilder::InsertionGuard g(rewriter);
1883 generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
1884 collapsingInfo, loopBound, rewriter);
1889 SmallVector<Value> results;
1890 for (
const auto &originalResult : llvm::enumerate(op->getResults())) {
1891 Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1892 auto originalResultType =
1893 cast<ShapedType>(originalResult.value().getType());
1894 auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.
getType());
1895 if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1896 AffineMap indexingMap =
1897 op.getIndexingMapMatchingResult(originalResult.value());
1898 SmallVector<ReassociationIndices> reassociation =
1899 getOperandReassociation(indexingMap, collapsingInfo);
1902 "Expected indexing map to be a projected permutation for collapsing");
1903 SmallVector<OpFoldResult> resultShape =
1906 if (isa<MemRefType>(collapsedOpResult.
getType())) {
1907 result = memref::ExpandShapeOp::create(
1908 rewriter, loc, originalResultType, collapsedOpResult, reassociation,
1911 result = tensor::ExpandShapeOp::create(
1912 rewriter, loc, originalResultType, collapsedOpResult, reassociation,
1915 results.push_back(
result);
1917 results.push_back(collapsedOpResult);
1920 return CollapseResult{results, collapsedOp};
1927class FoldWithProducerReshapeOpByCollapsing
1928 :
public OpRewritePattern<GenericOp> {
1931 FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1933 PatternBenefit benefit = 1)
1934 : OpRewritePattern<GenericOp>(context, benefit),
1935 controlFoldingReshapes(std::move(foldReshapes)) {}
1937 LogicalResult matchAndRewrite(GenericOp genericOp,
1938 PatternRewriter &rewriter)
const override {
1939 for (OpOperand &opOperand : genericOp->getOpOperands()) {
1940 tensor::ExpandShapeOp reshapeOp =
1945 SmallVector<ReassociationIndices> collapsableIterationDims =
1947 reshapeOp.getReassociationIndices());
1948 if (collapsableIterationDims.empty() ||
1949 !controlFoldingReshapes(&opOperand)) {
1954 genericOp, collapsableIterationDims, rewriter);
1955 if (!collapseResult) {
1957 genericOp,
"failed to do the fusion by collapsing transformation");
1960 rewriter.
replaceOp(genericOp, collapseResult->results);
1972struct FoldReshapeWithGenericOpByCollapsing
1973 :
public OpRewritePattern<tensor::CollapseShapeOp> {
1975 FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
1977 PatternBenefit benefit = 1)
1978 : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
1979 controlFoldingReshapes(std::move(foldReshapes)) {}
1981 LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
1982 PatternRewriter &rewriter)
const override {
1985 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1986 if (!producerResult) {
1988 "source not produced by an operation");
1992 auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
1995 "producer not a generic op");
1998 SmallVector<ReassociationIndices> collapsableIterationDims =
2001 producer.getDpsInitOperand(producerResult.getResultNumber()),
2002 reshapeOp.getReassociationIndices());
2003 if (collapsableIterationDims.empty()) {
2005 reshapeOp,
"failed preconditions of fusion with producer generic op");
2008 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
2010 "fusion blocked by control function");
2016 std::optional<CollapseResult> collapseResult =
2018 if (!collapseResult) {
2020 producer,
"failed to do the fusion by collapsing transformation");
2023 rewriter.
replaceOp(producer, collapseResult->results);
2035static FailureOr<PadDimInfo>
2036computeCollapsedPadding(tensor::PadOp padOp,
2037 ArrayRef<ReassociationIndices> reassociations,
2038 PatternRewriter &rewriter) {
2045 if (!padOp.getConstantPaddingValue())
2052 ArrayRef<int64_t> low = padOp.getStaticLow();
2053 ArrayRef<int64_t> high = padOp.getStaticHigh();
2054 for (
auto [idx, reInd] : llvm::enumerate(reassociations)) {
2055 for (int64_t dim : reInd) {
2056 if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
2062 ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
2063 PadDimInfo padDimInfo;
2064 padDimInfo.lowPad.assign(reassociations.size(), rewriter.
getIndexAttr(0));
2065 padDimInfo.highPad.assign(reassociations.size(), rewriter.
getIndexAttr(0));
2069 SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
2070 SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
2071 for (
auto [idx, reInd] : llvm::enumerate(reassociations)) {
2072 if (reInd.size() == 1) {
2073 padDimInfo.lowPad[idx] = mixedLowPad[reInd[0]];
2074 padDimInfo.highPad[idx] = mixedHighPad[reInd[0]];
2077 for (int64_t dim : reInd) {
2081 padDimInfo.paddedShape.push_back(collapsedSize.
asInteger());
2087class FoldPadWithProducerReshapeOpByCollapsing
2088 :
public OpRewritePattern<tensor::PadOp> {
2090 FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
2092 PatternBenefit benefit = 1)
2093 : OpRewritePattern<tensor::PadOp>(context, benefit),
2094 controlFoldingReshapes(std::move(foldReshapes)) {}
2096 LogicalResult matchAndRewrite(tensor::PadOp padOp,
2097 PatternRewriter &rewriter)
const override {
2098 tensor::ExpandShapeOp reshapeOp =
2099 padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
2103 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
2105 "fusion blocked by control function");
2108 SmallVector<ReassociationIndices> reassociations =
2109 reshapeOp.getReassociationIndices();
2110 FailureOr<PadDimInfo> maybeCollapsedPadding =
2111 computeCollapsedPadding(padOp, reassociations, rewriter);
2112 if (
failed(maybeCollapsedPadding))
2114 PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
2116 SmallVector<OpFoldResult> expandedPaddedSizes =
2117 reshapeOp.getMixedOutputShape();
2118 AffineExpr d0, d1, d2;
2121 Location loc = reshapeOp->getLoc();
2122 for (
auto [reInd, l, h] :
2123 llvm::zip_equal(reassociations, collapsedPadding.lowPad,
2124 collapsedPadding.highPad)) {
2125 if (reInd.size() == 1) {
2126 expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply(
2127 rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
2131 RankedTensorType collapsedPaddedType =
2132 padOp.getType().clone(collapsedPadding.paddedShape);
2133 auto newPadOp = tensor::PadOp::create(
2134 rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(),
2135 collapsedPadding.lowPad, collapsedPadding.highPad,
2136 padOp.getConstantPaddingValue(), padOp.getNofold());
2139 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
2140 expandedPaddedSizes);
2149class FoldReshapeWithProducerPadOpByCollapsing
2150 :
public OpRewritePattern<tensor::CollapseShapeOp> {
2152 FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
2154 PatternBenefit benefit = 1)
2155 : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
2156 controlFoldingReshapes(std::move(foldReshapes)) {}
2158 LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
2159 PatternRewriter &rewriter)
const override {
2160 tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
2164 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
2166 "fusion blocked by control function");
2169 SmallVector<ReassociationIndices> reassociations =
2170 reshapeOp.getReassociationIndices();
2171 RankedTensorType collapsedPaddedType = reshapeOp.getResultType();
2172 FailureOr<PadDimInfo> maybeCollapsedPadding =
2173 computeCollapsedPadding(padOp, reassociations, rewriter);
2174 if (
failed(maybeCollapsedPadding))
2176 PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
2178 Location loc = reshapeOp->getLoc();
2179 auto newCollapseOp = tensor::CollapseShapeOp::create(
2180 rewriter, loc, padOp.getSource(), reassociations);
2182 auto newPadOp = tensor::PadOp::create(
2183 rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(),
2184 collapsedPadding.lowPad, collapsedPadding.highPad,
2185 padOp.getConstantPaddingValue(), padOp.getNofold());
2187 rewriter.
replaceOp(reshapeOp, newPadOp.getResult());
2196template <
typename LinalgType>
2197class CollapseLinalgDimensions :
public OpRewritePattern<LinalgType> {
2199 CollapseLinalgDimensions(MLIRContext *context,
2201 PatternBenefit benefit = 1)
2202 : OpRewritePattern<LinalgType>(context, benefit),
2203 controlCollapseDimension(std::move(collapseDimensions)) {}
2205 LogicalResult matchAndRewrite(LinalgType op,
2206 PatternRewriter &rewriter)
const override {
2207 SmallVector<ReassociationIndices> collapsableIterationDims =
2208 controlCollapseDimension(op);
2209 if (collapsableIterationDims.empty())
2214 collapsableIterationDims)) {
2216 op,
"specified dimensions cannot be collapsed");
2219 std::optional<CollapseResult> collapseResult =
2221 if (!collapseResult) {
2224 rewriter.
replaceOp(op, collapseResult->results);
2241class FoldScalarOrSplatConstant :
public OpRewritePattern<GenericOp> {
2243 FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
2244 : OpRewritePattern<GenericOp>(context, benefit) {}
2246 LogicalResult matchAndRewrite(GenericOp genericOp,
2247 PatternRewriter &rewriter)
const override {
2248 if (!genericOp.hasPureTensorSemantics())
2250 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2252 TypedAttr constantAttr;
2253 auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) ->
bool {
2255 DenseElementsAttr splatAttr;
2258 splatAttr.
getType().getElementType().isIntOrFloat()) {
2264 IntegerAttr intAttr;
2266 constantAttr = intAttr;
2271 FloatAttr floatAttr;
2273 constantAttr = floatAttr;
2280 auto resultValue = dyn_cast<OpResult>(opOperand->
get());
2281 if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
2287 SmallVector<AffineMap> fusedIndexMaps;
2288 SmallVector<Value> fusedOperands;
2289 SmallVector<Location> fusedLocs{genericOp.getLoc()};
2290 fusedIndexMaps.reserve(genericOp->getNumOperands());
2291 fusedOperands.reserve(genericOp.getNumDpsInputs());
2292 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
2293 for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
2294 if (inputOperand == opOperand)
2296 Value inputValue = inputOperand->get();
2297 fusedIndexMaps.push_back(
2298 genericOp.getMatchingIndexingMap(inputOperand));
2299 fusedOperands.push_back(inputValue);
2300 fusedLocs.push_back(inputValue.
getLoc());
2302 for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
2303 fusedIndexMaps.push_back(
2304 genericOp.getMatchingIndexingMap(&outputOperand));
2310 genericOp,
"fused op loop bound computation failed");
2314 Value scalarConstant =
2315 arith::ConstantOp::create(rewriter, def->
getLoc(), constantAttr);
2317 SmallVector<Value> outputOperands = genericOp.getOutputs();
2319 GenericOp::create(rewriter, rewriter.
getFusedLoc(fusedLocs),
2320 genericOp->getResultTypes(),
2324 genericOp.getIteratorTypes(),
2330 Region ®ion = genericOp->getRegion(0);
2335 Region &fusedRegion = fusedOp->getRegion(0);
2338 rewriter.
replaceOp(genericOp, fusedOp->getResults());
2356struct RemoveOutsDependency :
public OpRewritePattern<GenericOp> {
2357 using OpRewritePattern<GenericOp>::OpRewritePattern;
2359 LogicalResult matchAndRewrite(GenericOp op,
2360 PatternRewriter &rewriter)
const override {
2362 bool modifiedOutput =
false;
2363 Location loc = op.getLoc();
2364 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2365 if (!op.payloadUsesValueFromOperand(&opOperand)) {
2366 Value operandVal = opOperand.
get();
2367 auto operandType = dyn_cast<RankedTensorType>(operandVal.
getType());
2376 auto definingOp = operandVal.
getDefiningOp<tensor::EmptyOp>();
2379 modifiedOutput =
true;
2380 SmallVector<OpFoldResult> mixedSizes =
2382 Value emptyTensor = tensor::EmptyOp::create(
2383 rewriter, loc, mixedSizes, operandType.getElementType());
2387 if (!modifiedOutput) {
2397struct FoldFillWithGenericOp :
public OpRewritePattern<GenericOp> {
2398 using OpRewritePattern<GenericOp>::OpRewritePattern;
2400 LogicalResult matchAndRewrite(GenericOp genericOp,
2401 PatternRewriter &rewriter)
const override {
2402 if (!genericOp.hasPureTensorSemantics())
2404 bool fillFound =
false;
2405 Block &payload = genericOp.getRegion().front();
2406 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2407 if (!genericOp.payloadUsesValueFromOperand(opOperand))
2413 Value fillVal = fillOp.value();
2415 cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2416 Value convertedVal =
2430 patterns.
add<FoldReshapeWithGenericOpByExpansion>(patterns.
getContext(),
2431 controlFoldingReshapes);
2432 patterns.
add<FoldPadWithProducerReshapeOpByExpansion>(patterns.
getContext(),
2433 controlFoldingReshapes);
2434 patterns.
add<FoldReshapeWithProducerPadOpByExpansion>(patterns.
getContext(),
2435 controlFoldingReshapes);
2436 patterns.
add<FoldWithProducerReshapeOpByExpansion>(patterns.
getContext(),
2437 controlFoldingReshapes);
2443 patterns.
add<FoldWithProducerReshapeOpByCollapsing>(patterns.
getContext(),
2444 controlFoldingReshapes);
2445 patterns.
add<FoldPadWithProducerReshapeOpByCollapsing>(
2446 patterns.
getContext(), controlFoldingReshapes);
2447 patterns.
add<FoldReshapeWithProducerPadOpByCollapsing>(
2448 patterns.
getContext(), controlFoldingReshapes);
2449 patterns.
add<FoldReshapeWithGenericOpByCollapsing>(patterns.
getContext(),
2450 controlFoldingReshapes);
2457 patterns.
add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2458 patterns.
add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2459 RemoveOutsDependency>(context);
2467 patterns.
add<CollapseLinalgDimensions<linalg::GenericOp>,
2468 CollapseLinalgDimensions<linalg::CopyOp>>(
2469 patterns.
getContext(), controlCollapseDimensions);
2484struct LinalgElementwiseOpFusionPass
2485 :
public impl::LinalgElementwiseOpFusionPassBase<
2486 LinalgElementwiseOpFusionPass> {
2487 using impl::LinalgElementwiseOpFusionPassBase<
2488 LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2489 void runOnOperation()
override {
2496 Operation *producer = fusedOperand->get().getDefiningOp();
2497 return producer && producer->
hasOneUse();
2506 affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
2507 GenericOp::getCanonicalizationPatterns(patterns, context);
2508 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
2509 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(GenericOp producer, GenericOp consumer, ArrayRef< OpOperand * > opOperandsToIgnore)
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(OpOperand *producerOpOperand, AffineMap producerResultIndexMap, AffineMap fusedConsumerArgIndexMap)
Append to fusedOpIndexingMapAttrs the indexing maps for the operands of the producer to use in the fu...
static SmallVector< ReassociationIndices > getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, ArrayRef< ReassociationIndices > reassociation)
ArrayRef< ReassociationIndices > getCollapsedOpToOrigOpMapping() const
Return mapping from collapsed loop domain to original loop domain.
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static std::tuple< SmallVector< OpFoldResult >, RankedTensorType > getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the shape and type of the operand/result to use in the expanded op given the type in the origi...
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, Location loc, Region &fusedRegion, const ExpansionInfo &expansionInfo)
Update the body of an expanded linalg operation having index semantics.
static Operation * createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp, Value expandedInput, Value output, ExpansionInfo &expansionInfo)
static SmallVector< ReassociationIndices > getReassociationForExpansion(AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Returns the reassociation maps to use in the tensor.expand_shape operation to convert the operands of...
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp, OpOperand *fusableOpOperand)
Conditions for folding a structured linalg operation with a reshape op by expanding the iteration spa...
static Operation * createExpandedGenericOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > &expandedOpOperands, ArrayRef< Value > outputs, ExpansionInfo &expansionInfo, ArrayRef< AffineMap > expandedOpIndexingMaps)
static Operation * createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > expandedOpOperands, ArrayRef< Value > outputs, ArrayRef< AffineMap > expandedOpIndexingMaps, ExpansionInfo &expansionInfo)
static ReassociationIndices getDomainReassociation(AffineMap indexingMap, ReassociationIndicesRef rangeReassociation)
For a given list of indices in the range of the indexingMap that are folded, return the indices of th...
static void generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand, unsigned nloops, llvm::SmallDenseSet< int > &preservedProducerResults)
Generate the region of the fused tensor operation.
static std::optional< SmallVector< Value > > fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, OpOperand *fusableOpOperand, PatternRewriter &rewriter)
Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op and a generic op as expl...
static AffineMap getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the indexing map to use in the expanded op for a given the indexingMap of the original operati...
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext()
Return the context this operation is associated with.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool areDimSequencesPreserved(ArrayRef< AffineMap > maps, ArrayRef< ReassociationIndices > dimSequences)
Return true if all sequences of dimensions specified in dimSequences are contiguous in all the ranges...
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence)
Return true if a given sequence of dimensions are contiguous in the range of the specified indexing map.
void populateFoldReshapeOpsByCollapsingPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding tensor.expand_shape operation with its producer generic operation by collapsing the dimensions of the generic op.
FailureOr< ElementwiseOpFusionResult > fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand)
This transformation is intended to be used with a top-down traversal (from producer to consumer).
llvm::SmallDenseSet< int > getPreservedProducerResults(GenericOp producer, GenericOp consumer, OpOperand *fusedOperand)
Returns a set of indices of the producer's results which would be preserved after the fusion.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
std::function< SmallVector< ReassociationIndices >(linalg::LinalgOp)> GetCollapsableDimensionsFn
Function type to control generic op dimension collapsing.
void populateCollapseDimensions(RewritePatternSet &patterns, const GetCollapsableDimensionsFn &controlCollapseDimensions)
Pattern to collapse dimensions in a linalg.generic op.
bool areElementwiseOpsFusable(OpOperand *fusedOperand)
Return true if two linalg.generic operations with producer/consumer relationship through fusedOperand can be fused using elementwise op fusion.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding (collapsing) tensor_reshape operation with its producer (consumer) generic operation by expanding the dimensionality of the loop in the generic op.
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void populateElementwiseOpsFusionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion)
Patterns for fusing linalg operation on tensors.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor.collapse_shape ops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
ArrayRef< int64_t > ReassociationIndicesRef
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)).
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)).
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
SmallVector< int64_t, 2 > ReassociationIndices
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, Operation *insertionPoint, DominanceInfo &dominance)
Move definitions of values (and their transitive dependencies) before insertionPoint.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
static SaturatedInteger wrap(int64_t v)
Fuse two linalg.generic operations that have a producer-consumer relationship captured through fusedO...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.