29#include "llvm/ADT/SmallVectorExtras.h"
34#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
35#include "mlir/Dialect/Linalg/Passes.h.inc"
61 assert(invProducerResultIndexMap &&
62 "expected producer result indexing map to be invertible");
64 LinalgOp producer = cast<LinalgOp>(producerOpOperand->
getOwner());
66 AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);
74 return t1.
compose(fusedConsumerArgIndexMap);
81 GenericOp producer, GenericOp consumer,
86 for (
auto &op : ops) {
87 for (
auto &opOperand : op->getOpOperands()) {
88 if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
91 indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
94 if (indexingMaps.empty()) {
97 return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
105 indexingMaps, producer.getContext())) !=
AffineMap();
114 GenericOp producer, GenericOp consumer,
OpOperand *fusedOperand) {
115 llvm::SmallDenseSet<int> preservedProducerResults;
119 opOperandsToIgnore.emplace_back(fusedOperand);
121 for (
const auto &producerResult : llvm::enumerate(producer->getResults())) {
122 auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
123 opOperandsToIgnore.emplace_back(outputOperand);
124 if (producer.payloadUsesValueFromOperand(outputOperand) ||
126 opOperandsToIgnore) ||
127 llvm::any_of(producerResult.value().getUsers(), [&](
Operation *user) {
128 return user != consumer.getOperation();
130 preservedProducerResults.insert(producerResult.index());
133 (
void)opOperandsToIgnore.pop_back_val();
136 return preservedProducerResults;
145 auto consumer = dyn_cast<GenericOp>(fusedOperand->
getOwner());
148 if (!producer || !consumer)
154 if (!producer.hasPureTensorSemantics() ||
155 !isa<RankedTensorType>(fusedOperand->
get().
getType()))
160 if (producer.getNumParallelLoops() != producer.getNumLoops())
165 if (!consumer.isDpsInput(fusedOperand))
170 AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
171 if (consumerIndexMap.
getNumResults() != producer.getNumLoops())
176 auto producerResult = cast<OpResult>(fusedOperand->
get());
178 producer.getIndexingMapMatchingResult(producerResult);
186 if ((consumer.getNumReductionLoops())) {
187 BitVector coveredDims(consumer.getNumLoops(),
false);
189 auto addToCoveredDims = [&](
AffineMap map) {
190 for (
auto result : map.getResults())
191 if (
auto dimExpr = dyn_cast<AffineDimExpr>(
result))
192 coveredDims[dimExpr.getPosition()] =
true;
196 llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
197 Value operand = std::get<0>(pair);
198 if (operand == fusedOperand->
get())
200 AffineMap operandMap = std::get<1>(pair);
201 addToCoveredDims(operandMap);
204 for (
OpOperand *operand : producer.getDpsInputOperands()) {
207 operand, producerResultIndexMap, consumerIndexMap);
208 addToCoveredDims(newIndexingMap);
210 if (!coveredDims.all())
222 unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
224 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
226 Block &producerBlock = producer->getRegion(0).
front();
227 Block &consumerBlock = consumer->getRegion(0).
front();
234 if (producer.hasIndexSemantics()) {
236 unsigned numFusedOpLoops = fusedOp.getNumLoops();
238 fusedIndices.reserve(numFusedOpLoops);
239 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
240 std::back_inserter(fusedIndices), [&](uint64_t dim) {
241 return IndexOp::create(rewriter, producer.getLoc(), dim);
243 for (IndexOp indexOp :
244 llvm::make_early_inc_range(producerBlock.
getOps<IndexOp>())) {
245 Value newIndex = affine::AffineApplyOp::create(
246 rewriter, producer.getLoc(),
247 consumerToProducerLoopsMap.
getSubMap(indexOp.getDim()), fusedIndices);
248 mapper.
map(indexOp.getResult(), newIndex);
252 assert(consumer.isDpsInput(fusedOperand) &&
253 "expected producer of input operand");
257 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
264 producerBlock.
getArguments().take_front(producer.getNumDpsInputs()))
265 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
270 .take_front(consumer.getNumDpsInputs())
272 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
275 for (
const auto &bbArg : llvm::enumerate(
276 producerBlock.
getArguments().take_back(producer.getNumDpsInits()))) {
277 if (!preservedProducerResults.count(bbArg.index()))
279 mapper.
map(bbArg.value(), fusedBlock->
addArgument(bbArg.value().getType(),
280 bbArg.value().getLoc()));
285 consumerBlock.
getArguments().take_back(consumer.getNumDpsInits()))
286 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
291 if (!isa<IndexOp>(op))
292 rewriter.
clone(op, mapper);
296 auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.
getTerminator());
297 unsigned producerResultNumber =
298 cast<OpResult>(fusedOperand->
get()).getResultNumber();
300 mapper.
lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
304 if (
replacement == producerYieldOp.getOperand(producerResultNumber)) {
305 if (
auto bb = dyn_cast<BlockArgument>(
replacement))
306 assert(bb.getOwner() != &producerBlock &&
307 "yielded block argument must have been mapped");
309 assert(!producer->isAncestor(
replacement.getDefiningOp()) &&
310 "yielded value must have been mapped");
316 rewriter.
clone(op, mapper);
320 auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.
getTerminator());
322 fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
323 consumerYieldOp.getNumOperands());
324 for (
const auto &producerYieldVal :
325 llvm::enumerate(producerYieldOp.getOperands())) {
326 if (preservedProducerResults.count(producerYieldVal.index()))
327 fusedYieldValues.push_back(
330 for (
auto consumerYieldVal : consumerYieldOp.getOperands())
332 YieldOp::create(rewriter, fusedOp.getLoc(), fusedYieldValues);
336 "Ill-formed GenericOp region");
339FailureOr<mlir::linalg::ElementwiseOpFusionResult>
343 "expected elementwise operation pre-conditions to pass");
344 auto producerResult = cast<OpResult>(fusedOperand->
get());
345 auto producer = cast<GenericOp>(producerResult.getOwner());
346 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
348 assert(consumer.isDpsInput(fusedOperand) &&
349 "expected producer of input operand");
352 llvm::SmallDenseSet<int> preservedProducerResults =
360 fusedInputOperands.reserve(producer.getNumDpsInputs() +
361 consumer.getNumDpsInputs());
362 fusedOutputOperands.reserve(preservedProducerResults.size() +
363 consumer.getNumDpsInits());
364 fusedResultTypes.reserve(preservedProducerResults.size() +
365 consumer.getNumDpsInits());
366 fusedIndexMaps.reserve(producer->getNumOperands() +
367 consumer->getNumOperands());
370 auto consumerInputs = consumer.getDpsInputOperands();
371 auto *it = llvm::find_if(consumerInputs, [&](
OpOperand *operand) {
372 return operand == fusedOperand;
374 assert(it != consumerInputs.end() &&
"expected to find the consumer operand");
375 for (
OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
376 fusedInputOperands.push_back(opOperand->get());
377 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
381 producer.getIndexingMapMatchingResult(producerResult);
382 for (
OpOperand *opOperand : producer.getDpsInputOperands()) {
383 fusedInputOperands.push_back(opOperand->get());
386 opOperand, producerResultIndexMap,
387 consumer.getMatchingIndexingMap(fusedOperand));
388 fusedIndexMaps.push_back(map);
393 llvm::make_range(std::next(it), consumerInputs.end())) {
394 fusedInputOperands.push_back(opOperand->get());
395 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
399 for (
const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
400 if (!preservedProducerResults.count(opOperand.index()))
403 fusedOutputOperands.push_back(opOperand.value().get());
405 &opOperand.value(), producerResultIndexMap,
406 consumer.getMatchingIndexingMap(fusedOperand));
407 fusedIndexMaps.push_back(map);
408 fusedResultTypes.push_back(opOperand.value().get().getType());
412 for (
OpOperand &opOperand : consumer.getDpsInitsMutable()) {
413 fusedOutputOperands.push_back(opOperand.get());
414 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
415 Type resultType = opOperand.get().getType();
416 if (!isa<MemRefType>(resultType))
417 fusedResultTypes.push_back(resultType);
421 auto fusedOp = GenericOp::create(
422 rewriter, consumer.getLoc(), fusedResultTypes, fusedInputOperands,
424 consumer.getIteratorTypes(),
427 if (!fusedOp.getShapesToLoopsMap()) {
433 fusedOp,
"fused op failed loop bound computation check");
439 consumer.getMatchingIndexingMap(fusedOperand);
443 assert(invProducerResultIndexMap &&
444 "expected producer result indexig map to be invertible");
447 invProducerResultIndexMap.
compose(consumerResultIndexMap);
450 rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
451 consumer.getNumLoops(), preservedProducerResults);
455 for (
auto [
index, producerResult] : llvm::enumerate(producer->getResults()))
456 if (preservedProducerResults.count(
index))
457 result.replacements[producerResult] = fusedOp->getResult(resultNum++);
458 for (
auto consumerResult : consumer->getResults())
459 result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
470 controlFn(std::move(fun)) {}
472 LogicalResult matchAndRewrite(GenericOp genericOp,
475 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
478 if (!controlFn(&opOperand))
481 Operation *producer = opOperand.get().getDefiningOp();
484 FailureOr<ElementwiseOpFusionResult> fusionResult =
486 if (failed(fusionResult))
490 for (
auto [origVal,
replacement] : fusionResult->replacements) {
572 linalgOp.getIteratorTypesArray();
573 AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
574 return linalgOp.hasPureTensorSemantics() &&
575 llvm::all_of(linalgOp.getIndexingMaps().getValue(),
577 return cast<AffineMapAttr>(attr)
579 .isProjectedPermutation();
593 LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
594 ArrayRef<AffineMap> reassociationMaps,
595 ArrayRef<OpFoldResult> expandedShape,
596 PatternRewriter &rewriter);
597 unsigned getOrigOpNumDims()
const {
return reassociation.size(); }
598 unsigned getExpandedOpNumDims()
const {
return expandedOpNumDims; }
600 return reassociation[i];
602 ArrayRef<OpFoldResult> getExpandedShapeOfDim(
unsigned i)
const {
603 return expandedShapeMap[i];
605 ArrayRef<OpFoldResult> getOriginalShape()
const {
return originalLoopExtent; }
610 SmallVector<ReassociationIndices> reassociation;
613 SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
615 SmallVector<OpFoldResult> originalLoopExtent;
616 unsigned expandedOpNumDims;
620LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
625 if (reassociationMaps.empty())
627 AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
629 OpBuilder::InsertionGuard g(rewriter);
631 originalLoopExtent = llvm::map_to_vector(
632 linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
633 [](Range r) { return r.size; });
635 reassociation.clear();
636 expandedShapeMap.clear();
639 SmallVector<unsigned> numExpandedDims(fusedIndexMap.
getNumDims(), 1);
640 expandedShapeMap.resize(fusedIndexMap.
getNumDims());
641 for (
const auto &resultExpr : llvm::enumerate(fusedIndexMap.
getResults())) {
642 unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
643 AffineMap foldedDims = reassociationMaps[resultExpr.index()];
645 ArrayRef<OpFoldResult> shape =
646 expandedShape.slice(foldedDims.
getDimPosition(0), numExpandedDims[pos]);
647 expandedShapeMap[pos].assign(shape.begin(), shape.end());
650 for (
unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.
getNumDims()))
651 if (expandedShapeMap[i].empty())
652 expandedShapeMap[i] = {originalLoopExtent[i]};
656 reassociation.reserve(fusedIndexMap.
getNumDims());
657 for (
const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
658 auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
659 reassociation.emplace_back(seq.begin(), seq.end());
660 sum += numFoldedDim.value();
662 expandedOpNumDims = sum;
670 const ExpansionInfo &expansionInfo) {
673 unsigned pos = cast<AffineDimExpr>(expr).getPosition();
675 expansionInfo.getExpandedDims(pos), [&](
int64_t v) {
676 return builder.getAffineDimExpr(static_cast<unsigned>(v));
678 newExprs.append(expandedExprs.begin(), expandedExprs.end());
687static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
689 const ExpansionInfo &expansionInfo) {
692 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
694 expansionInfo.getExpandedShapeOfDim(dim);
695 expandedShape.append(dimExpansion.begin(), dimExpansion.end());
698 std::tie(expandedStaticShape, std::ignore) =
700 return {expandedShape, RankedTensorType::get(expandedStaticShape,
701 originalType.getElementType())};
710static SmallVector<ReassociationIndices>
712 const ExpansionInfo &expansionInfo) {
714 unsigned numReshapeDims = 0;
716 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
717 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
719 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
720 reassociation.emplace_back(std::move(
indices));
721 numReshapeDims += numExpandedDims;
723 return reassociation;
733 const ExpansionInfo &expansionInfo) {
735 for (IndexOp indexOp :
736 llvm::make_early_inc_range(fusedRegion.
front().
getOps<IndexOp>())) {
738 expansionInfo.getExpandedDims(indexOp.getDim());
739 assert(!expandedDims.empty() &&
"expected valid expansion info");
742 if (expandedDims.size() == 1 &&
743 expandedDims.front() == (
int64_t)indexOp.getDim())
750 expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
752 expandedIndices.reserve(expandedDims.size() - 1);
754 expandedDims.drop_front(), std::back_inserter(expandedIndices),
755 [&](
int64_t dim) { return IndexOp::create(rewriter, loc, dim); });
757 IndexOp::create(rewriter, loc, expandedDims.front()).getResult();
758 for (
auto [expandedShape, expandedIndex] :
759 llvm::zip(expandedDimsShape, expandedIndices)) {
764 rewriter, indexOp.getLoc(), idx +
acc *
shape,
769 rewriter.
replaceOp(indexOp, newIndexVal);
791 TransposeOp transposeOp,
793 ExpansionInfo &expansionInfo) {
796 auto reassoc = expansionInfo.getExpandedDims(perm);
798 newPerm.push_back(dim);
801 return TransposeOp::create(rewriter, transposeOp.getLoc(), expandedInput,
812 expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
814 for (
auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
815 for (
auto j : expansionInfo.getExpandedDims(i))
816 iteratorTypes[
j] = type;
818 Operation *fused = GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes,
819 expandedOpOperands, outputs,
820 expandedOpIndexingMaps, iteratorTypes);
823 Region &originalRegion = linalgOp->getRegion(0);
841 ExpansionInfo &expansionInfo) {
844 .Case([&](TransposeOp transposeOp) {
846 expandedOpOperands[0], outputs[0],
849 .Case<FillOp, CopyOp>([&](
Operation *op) {
850 return clone(rewriter, linalgOp, resultTypes,
851 llvm::to_vector(llvm::concat<Value>(
852 llvm::to_vector(expandedOpOperands),
853 llvm::to_vector(outputs))));
857 expandedOpOperands, outputs,
858 expansionInfo, expandedOpIndexingMaps);
865static std::optional<SmallVector<Value>>
870 "preconditions for fuse operation failed");
876 if (
auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
880 rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
883 expandedShape = expandingReshapeOp.getMixedOutputShape();
884 reassociationIndices = expandingReshapeOp.getReassociationMaps();
885 src = expandingReshapeOp.getSrc();
887 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
888 if (!collapsingReshapeOp)
892 rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
893 reassociationIndices = collapsingReshapeOp.getReassociationMaps();
894 src = collapsingReshapeOp.getSrc();
897 ExpansionInfo expansionInfo;
898 if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
899 reassociationIndices, expandedShape,
904 llvm::map_to_vector<4>(linalgOp.getIndexingMapsArray(), [&](
AffineMap m) {
905 return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
913 expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
914 for (
OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
915 if (opOperand == fusableOpOperand) {
916 expandedOpOperands.push_back(src);
919 if (
auto opOperandType =
920 dyn_cast<RankedTensorType>(opOperand->get().getType())) {
921 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
923 RankedTensorType expandedOperandType;
924 std::tie(expandedOperandShape, expandedOperandType) =
926 if (expandedOperandType != opOperand->get().getType()) {
930 if (failed(reshapeLikeShapesAreCompatible(
931 [&](
const Twine &msg) {
934 opOperandType.getShape(), expandedOperandType.getShape(),
938 expandedOpOperands.push_back(tensor::ExpandShapeOp::create(
939 rewriter, loc, expandedOperandType, opOperand->get(), reassociation,
940 expandedOperandShape));
944 expandedOpOperands.push_back(opOperand->get());
948 for (
OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
949 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
950 auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
952 RankedTensorType expandedOutputType;
953 std::tie(expandedOutputShape, expandedOutputType) =
955 if (expandedOutputType != opOperand.get().getType()) {
958 if (failed(reshapeLikeShapesAreCompatible(
959 [&](
const Twine &msg) {
962 opOperandType.getShape(), expandedOutputType.getShape(),
966 outputs.push_back(tensor::ExpandShapeOp::create(
967 rewriter, loc, expandedOutputType, opOperand.get(), reassociation,
968 expandedOutputShape));
970 outputs.push_back(opOperand.get());
977 outputs, expandedOpIndexingMaps, expansionInfo);
981 for (
OpResult opResult : linalgOp->getOpResults()) {
982 int64_t resultNumber = opResult.getResultNumber();
983 if (resultTypes[resultNumber] != opResult.getType()) {
986 linalgOp.getMatchingIndexingMap(
987 linalgOp.getDpsInitOperand(resultNumber)),
989 resultVals.push_back(tensor::CollapseShapeOp::create(
990 rewriter, linalgOp.getLoc(), opResult.getType(),
991 fusedOp->
getResult(resultNumber), reassociation));
993 resultVals.push_back(fusedOp->
getResult(resultNumber));
1005class FoldWithProducerReshapeOpByExpansion
1006 :
public OpInterfaceRewritePattern<LinalgOp> {
1008 FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
1010 PatternBenefit benefit = 1)
1011 : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
1012 controlFoldingReshapes(std::move(foldReshapes)) {}
1014 LogicalResult matchAndRewrite(LinalgOp linalgOp,
1015 PatternRewriter &rewriter)
const override {
1016 for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
1017 tensor::CollapseShapeOp reshapeOp =
1018 opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
1025 (!controlFoldingReshapes(opOperand)))
1028 std::optional<SmallVector<Value>> replacementValues =
1030 if (!replacementValues)
1032 rewriter.
replaceOp(linalgOp, *replacementValues);
1045 SmallVector<int64_t> paddedShape;
1048 SmallVector<OpFoldResult> lowPad;
1049 SmallVector<OpFoldResult> highPad;
1056static FailureOr<PadDimInfo>
1057computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
1058 ArrayRef<ReassociationIndices> reassociations,
1059 PatternRewriter &rewriter) {
1066 if (!padOp.getConstantPaddingValue())
1073 ArrayRef<int64_t> low = padOp.getStaticLow();
1074 ArrayRef<int64_t> high = padOp.getStaticHigh();
1075 for (
auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
1076 if (reInd.size() != 1 && (l != 0 || h != 0))
1080 SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
1081 SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
1082 ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
1083 PadDimInfo padDimInfo;
1084 padDimInfo.paddedShape.assign(expandedShape);
1085 padDimInfo.lowPad.assign(expandedShape.size(), rewriter.
getIndexAttr(0));
1086 padDimInfo.highPad.assign(expandedShape.size(), rewriter.
getIndexAttr(0));
1087 for (
auto [idx, reInd] : llvm::enumerate(reassociations)) {
1088 if (reInd.size() == 1) {
1089 padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
1090 padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
1091 padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
1098class FoldPadWithProducerReshapeOpByExpansion
1099 :
public OpRewritePattern<tensor::PadOp> {
1101 FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
1103 PatternBenefit benefit = 1)
1104 : OpRewritePattern<tensor::PadOp>(context, benefit),
1105 controlFoldingReshapes(std::move(foldReshapes)) {}
1107 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1108 PatternRewriter &rewriter)
const override {
1109 tensor::CollapseShapeOp reshapeOp =
1110 padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
1114 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1116 "fusion blocked by control function");
1119 RankedTensorType expandedType = reshapeOp.getSrcType();
1120 SmallVector<ReassociationIndices> reassociations =
1121 reshapeOp.getReassociationIndices();
1122 FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
1123 padOp, expandedType.getShape(), reassociations, rewriter);
1124 if (
failed(maybeExpandedPadding))
1126 PadDimInfo &expandedPadding = maybeExpandedPadding.value();
1128 Location loc = padOp->getLoc();
1129 RankedTensorType expandedPaddedType =
1130 padOp.getResultType().clone(expandedPadding.paddedShape);
1132 auto newPadOp = tensor::PadOp::create(
1133 rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
1134 expandedPadding.lowPad, expandedPadding.highPad,
1135 padOp.getConstantPaddingValue(), padOp.getNofold());
1138 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
1147class FoldReshapeWithProducerPadOpByExpansion
1148 :
public OpRewritePattern<tensor::ExpandShapeOp> {
1150 FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
1152 PatternBenefit benefit = 1)
1153 : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
1154 controlFoldingReshapes(std::move(foldReshapes)) {}
1156 LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
1157 PatternRewriter &rewriter)
const override {
1158 tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
1162 if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
1164 "fusion blocked by control function");
1167 RankedTensorType expandedType = expandOp.getResultType();
1168 SmallVector<ReassociationIndices> reassociations =
1169 expandOp.getReassociationIndices();
1170 FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
1171 padOp, expandedType.getShape(), reassociations, rewriter);
1172 if (
failed(maybeExpandedPadding))
1174 PadDimInfo &expandedPadding = maybeExpandedPadding.value();
1176 Location loc = expandOp->getLoc();
1177 SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
1178 SmallVector<int64_t> newExpandedShape(expandedType.getShape());
1180 SmallVector<OpFoldResult> padSrcSizes =
1182 for (
auto [idx, reInd] : llvm::enumerate(reassociations)) {
1185 if (reInd.size() == 1) {
1186 newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx);
1187 newExpandedSizes[reInd[0]] = padSrcSizes[idx];
1190 RankedTensorType newExpandedType = expandedType.clone(newExpandedShape);
1191 auto newExpandOp = tensor::ExpandShapeOp::create(
1192 rewriter, loc, newExpandedType, padOp.getSource(), reassociations,
1194 RankedTensorType expandedPaddedType =
1195 padOp.getResultType().clone(expandedPadding.paddedShape);
1197 auto newPadOp = tensor::PadOp::create(
1198 rewriter, loc, expandedPaddedType, newExpandOp.getResult(),
1199 expandedPadding.lowPad, expandedPadding.highPad,
1200 padOp.getConstantPaddingValue(), padOp.getNofold());
1202 rewriter.
replaceOp(expandOp, newPadOp.getResult());
1213struct FoldReshapeWithGenericOpByExpansion
1214 :
public OpRewritePattern<tensor::ExpandShapeOp> {
1216 FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
1218 PatternBenefit benefit = 1)
1219 : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
1220 controlFoldingReshapes(std::move(foldReshapes)) {}
1222 LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
1223 PatternRewriter &rewriter)
const override {
1225 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1226 if (!producerResult) {
1228 "source not produced by an operation");
1231 auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
1234 "producer not a generic op");
1239 producer.getDpsInitOperand(producerResult.getResultNumber()))) {
1241 reshapeOp,
"failed preconditions of fusion with producer generic op");
1244 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1246 "fusion blocked by control function");
1249 std::optional<SmallVector<Value>> replacementValues =
1251 producer, reshapeOp,
1252 producer.getDpsInitOperand(producerResult.getResultNumber()),
1254 if (!replacementValues) {
1256 "fusion by expansion failed");
1263 Value reshapeReplacement =
1264 (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
1265 .getResultNumber()];
1266 if (
auto collapseOp =
1267 reshapeReplacement.
getDefiningOp<tensor::CollapseShapeOp>()) {
1268 reshapeReplacement = collapseOp.getSrc();
1270 rewriter.
replaceOp(reshapeOp, reshapeReplacement);
1271 rewriter.
replaceOp(producer, *replacementValues);
1293 "expected projected permutation");
1296 llvm::map_to_vector<4>(rangeReassociation, [&](
int64_t pos) ->
int64_t {
1297 return cast<AffineDimExpr>(indexingMap.
getResults()[pos]).getPosition();
1301 return domainReassociation;
1309 assert(!dimSequence.empty() &&
1310 "expected non-empty list for dimension sequence");
1312 "expected indexing map to be projected permutation");
1314 llvm::SmallDenseSet<unsigned, 4> sequenceElements;
1315 sequenceElements.insert_range(dimSequence);
1317 unsigned dimSequenceStart = dimSequence[0];
1318 for (
const auto &expr : enumerate(indexingMap.
getResults())) {
1319 unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
1321 if (dimInMapStart == dimSequenceStart) {
1322 if (expr.index() + dimSequence.size() > indexingMap.
getNumResults())
1325 for (
const auto &dimInSequence : enumerate(dimSequence)) {
1327 cast<AffineDimExpr>(
1328 indexingMap.
getResult(expr.index() + dimInSequence.index()))
1330 if (dimInMap != dimInSequence.value())
1341 if (sequenceElements.count(dimInMapStart))
1350 return llvm::all_of(maps, [&](
AffineMap map) {
1407 if (!genericOp.hasPureTensorSemantics())
1410 if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](
AffineMap map) {
1411 return map.isProjectedPermutation();
1418 genericOp.getReductionDims(reductionDims);
1420 llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1421 AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1422 auto iteratorTypes = genericOp.getIteratorTypesArray();
1425 assert(!foldedRangeDims.empty() &&
"unexpected empty reassociation");
1428 if (foldedRangeDims.size() == 1)
1436 if (llvm::any_of(foldedIterationSpaceDims, [&](
int64_t dim) {
1437 return processedIterationDims.count(dim);
1442 utils::IteratorType startIteratorType =
1443 iteratorTypes[foldedIterationSpaceDims[0]];
1447 if (llvm::any_of(foldedIterationSpaceDims, [&](
int64_t dim) {
1448 return iteratorTypes[dim] != startIteratorType;
1457 bool isContiguous =
false;
1458 for (
const auto &startDim : llvm::enumerate(reductionDims)) {
1460 if (startDim.value() != foldedIterationSpaceDims[0])
1464 if (startDim.index() + foldedIterationSpaceDims.size() >
1465 reductionDims.size())
1468 isContiguous =
true;
1469 for (
const auto &foldedDim :
1470 llvm::enumerate(foldedIterationSpaceDims)) {
1471 if (reductionDims[foldedDim.index() + startDim.index()] !=
1472 foldedDim.value()) {
1473 isContiguous =
false;
1484 if (llvm::any_of(genericOp.getIndexingMapsArray(),
1486 return !isDimSequencePreserved(indexingMap,
1487 foldedIterationSpaceDims);
1491 processedIterationDims.insert_range(foldedIterationSpaceDims);
1492 iterationSpaceReassociation.emplace_back(
1493 std::move(foldedIterationSpaceDims));
1496 return iterationSpaceReassociation;
1501class CollapsingInfo {
1503 LogicalResult
initialize(
unsigned origNumLoops,
1504 ArrayRef<ReassociationIndices> foldedIterationDims) {
1505 llvm::SmallDenseSet<int64_t, 4> processedDims;
1508 if (foldedIterationDim.empty())
1512 for (
auto dim : foldedIterationDim) {
1513 if (dim >= origNumLoops)
1515 if (processedDims.count(dim))
1517 processedDims.insert(dim);
1519 collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1520 foldedIterationDim.end());
1522 if (processedDims.size() > origNumLoops)
1527 for (
auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1528 if (processedDims.count(dim))
1533 llvm::sort(collapsedOpToOrigOpIterationDim,
1537 origOpToCollapsedOpIterationDim.resize(origNumLoops);
1538 for (
const auto &foldedDims :
1539 llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1540 for (
const auto &dim :
enumerate(foldedDims.value()))
1541 origOpToCollapsedOpIterationDim[dim.value()] =
1542 std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1549 return collapsedOpToOrigOpIterationDim;
1572 ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping()
const {
1573 return origOpToCollapsedOpIterationDim;
1577 unsigned getCollapsedOpIterationRank()
const {
1578 return collapsedOpToOrigOpIterationDim.size();
1584 SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1588 SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1594static SmallVector<utils::IteratorType>
1595getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
1596 const CollapsingInfo &collapsingInfo) {
1597 SmallVector<utils::IteratorType> collapsedIteratorTypes;
1599 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1600 assert(!foldedIterDims.empty() &&
1601 "reassociation indices expected to have non-empty sets");
1605 collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1607 return collapsedIteratorTypes;
1613getCollapsedOpIndexingMap(AffineMap indexingMap,
1614 const CollapsingInfo &collapsingInfo) {
1615 MLIRContext *context = indexingMap.
getContext();
1617 "expected indexing map to be projected permutation");
1618 SmallVector<AffineExpr> resultExprs;
1619 auto origOpToCollapsedOpMapping =
1620 collapsingInfo.getOrigOpToCollapsedOpMapping();
1622 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1624 if (origOpToCollapsedOpMapping[dim].second != 0)
1628 resultExprs.push_back(
1631 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1632 resultExprs, context);
1637static SmallVector<ReassociationIndices>
1638getOperandReassociation(AffineMap indexingMap,
1639 const CollapsingInfo &collapsingInfo) {
1640 unsigned counter = 0;
1641 SmallVector<ReassociationIndices> operandReassociation;
1642 auto origOpToCollapsedOpMapping =
1643 collapsingInfo.getOrigOpToCollapsedOpMapping();
1644 auto collapsedOpToOrigOpMapping =
1645 collapsingInfo.getCollapsedOpToOrigOpMapping();
1648 cast<AffineDimExpr>(indexingMap.
getResult(counter)).getPosition();
1652 unsigned numFoldedDims =
1653 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1655 if (origOpToCollapsedOpMapping[dim].second == 0) {
1656 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1657 operandReassociation.emplace_back(range.begin(), range.end());
1659 counter += numFoldedDims;
1661 return operandReassociation;
1665static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1666 OpOperand *opOperand,
1667 const CollapsingInfo &collapsingInfo,
1668 OpBuilder &builder) {
1669 AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1670 SmallVector<ReassociationIndices> operandReassociation =
1671 getOperandReassociation(indexingMap, collapsingInfo);
1676 Value operand = opOperand->
get();
1677 if (operandReassociation.size() == indexingMap.
getNumResults())
1681 if (isa<MemRefType>(operand.
getType())) {
1682 return memref::CollapseShapeOp::create(builder, loc, operand,
1683 operandReassociation)
1686 return tensor::CollapseShapeOp::create(builder, loc, operand,
1687 operandReassociation)
1693static void generateCollapsedIndexingRegion(
1694 Location loc,
Block *block,
const CollapsingInfo &collapsingInfo,
1695 ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
1696 OpBuilder::InsertionGuard g(rewriter);
1700 auto indexOps = llvm::to_vector(block->
getOps<linalg::IndexOp>());
1709 llvm::DenseMap<unsigned, Value> indexReplacementVals;
1710 for (
auto foldedDims :
1711 enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1714 linalg::IndexOp::create(rewriter, loc, foldedDims.index());
1715 for (
auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1718 indexReplacementVals[dim] =
1719 rewriter.
createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
1721 rewriter.
createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
1723 indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1726 for (
auto indexOp : indexOps) {
1727 auto dim = indexOp.getDim();
1728 rewriter.
replaceOp(indexOp, indexReplacementVals[dim]);
1732static void collapseOperandsAndResults(LinalgOp op,
1733 const CollapsingInfo &collapsingInfo,
1734 RewriterBase &rewriter,
1735 SmallVectorImpl<Value> &inputOperands,
1736 SmallVectorImpl<Value> &outputOperands,
1737 SmallVectorImpl<Type> &resultTypes) {
1738 Location loc = op->getLoc();
1740 llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1741 return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1746 resultTypes.reserve(op.getNumDpsInits());
1747 outputOperands.reserve(op.getNumDpsInits());
1748 for (OpOperand &output : op.getDpsInitsMutable()) {
1750 getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1751 outputOperands.push_back(newOutput);
1754 if (!op.hasPureBufferSemantics())
1755 resultTypes.push_back(newOutput.
getType());
1760template <
typename OpTy>
1761static OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1762 const CollapsingInfo &collapsingInfo) {
1769LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1770 const CollapsingInfo &collapsingInfo) {
1771 SmallVector<Value> inputOperands, outputOperands;
1772 SmallVector<Type> resultTypes;
1773 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1774 outputOperands, resultTypes);
1777 rewriter, origOp, resultTypes,
1778 llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1783GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1785 const CollapsingInfo &collapsingInfo) {
1786 SmallVector<Value> inputOperands, outputOperands;
1787 SmallVector<Type> resultTypes;
1788 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1789 outputOperands, resultTypes);
1790 SmallVector<AffineMap> indexingMaps(
1791 llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
1792 return getCollapsedOpIndexingMap(map, collapsingInfo);
1795 SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
1796 origOp.getIteratorTypesArray(), collapsingInfo));
1798 GenericOp collapsedOp = linalg::GenericOp::create(
1799 rewriter, origOp.getLoc(), resultTypes, inputOperands, outputOperands,
1800 indexingMaps, iteratorTypes,
1801 [](OpBuilder &builder, Location loc,
ValueRange args) {});
1802 Block *origOpBlock = &origOp->getRegion(0).front();
1803 Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
1804 rewriter.
mergeBlocks(origOpBlock, collapsedOpBlock,
1809static LinalgOp createCollapsedOp(LinalgOp op,
1810 const CollapsingInfo &collapsingInfo,
1811 RewriterBase &rewriter) {
1812 if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1813 return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
1815 return cloneToCollapsedOp(rewriter, op, collapsingInfo);
1820 LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
1821 RewriterBase &rewriter) {
1823 if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1825 return foldedDims.size() <= 1;
1829 CollapsingInfo collapsingInfo;
1831 collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1833 op,
"illegal to collapse specified dimensions");
1836 bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1837 if (hasPureBufferSemantics &&
1838 !llvm::all_of(op->getOpOperands(), [&](OpOperand &opOperand) ->
bool {
1839 MemRefType memRefToCollapse =
1840 dyn_cast<MemRefType>(opOperand.get().getType());
1841 if (!memRefToCollapse)
1844 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1845 SmallVector<ReassociationIndices> operandReassociation =
1846 getOperandReassociation(indexingMap, collapsingInfo);
1847 return memref::CollapseShapeOp::isGuaranteedCollapsible(
1848 memRefToCollapse, operandReassociation);
1851 "memref is not guaranteed collapsible");
1854 SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
1855 auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1856 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1857 return cast<IntegerAttr>(attr).getInt() == value;
1860 actual.getSExtValue() == value;
1862 if (!llvm::all_of(loopRanges, [&](Range range) {
1863 return opFoldIsConstantValue(range.
offset, 0) &&
1864 opFoldIsConstantValue(range.
stride, 1);
1867 op,
"expected all loop ranges to have zero start and unit stride");
1870 LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
1872 Location loc = op->getLoc();
1873 SmallVector<OpFoldResult> loopBound =
1874 llvm::map_to_vector(loopRanges, [](Range range) {
return range.
size; });
1876 if (collapsedOp.hasIndexSemantics()) {
1878 OpBuilder::InsertionGuard g(rewriter);
1880 generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
1881 collapsingInfo, loopBound, rewriter);
1886 SmallVector<Value> results;
1887 for (
const auto &originalResult : llvm::enumerate(op->getResults())) {
1888 Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1889 auto originalResultType =
1890 cast<ShapedType>(originalResult.value().getType());
1891 auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.
getType());
1892 if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1893 AffineMap indexingMap =
1894 op.getIndexingMapMatchingResult(originalResult.value());
1895 SmallVector<ReassociationIndices> reassociation =
1896 getOperandReassociation(indexingMap, collapsingInfo);
1899 "Expected indexing map to be a projected permutation for collapsing");
1900 SmallVector<OpFoldResult> resultShape =
1903 if (isa<MemRefType>(collapsedOpResult.
getType())) {
1904 result = memref::ExpandShapeOp::create(
1905 rewriter, loc, originalResultType, collapsedOpResult, reassociation,
1908 result = tensor::ExpandShapeOp::create(
1909 rewriter, loc, originalResultType, collapsedOpResult, reassociation,
1912 results.push_back(
result);
1914 results.push_back(collapsedOpResult);
1917 return CollapseResult{results, collapsedOp};
1924class FoldWithProducerReshapeOpByCollapsing
1925 :
public OpRewritePattern<GenericOp> {
1928 FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1930 PatternBenefit benefit = 1)
1931 : OpRewritePattern<GenericOp>(context, benefit),
1932 controlFoldingReshapes(std::move(foldReshapes)) {}
1934 LogicalResult matchAndRewrite(GenericOp genericOp,
1935 PatternRewriter &rewriter)
const override {
1936 for (OpOperand &opOperand : genericOp->getOpOperands()) {
1937 tensor::ExpandShapeOp reshapeOp =
1942 SmallVector<ReassociationIndices> collapsableIterationDims =
1944 reshapeOp.getReassociationIndices());
1945 if (collapsableIterationDims.empty() ||
1946 !controlFoldingReshapes(&opOperand)) {
1951 genericOp, collapsableIterationDims, rewriter);
1952 if (!collapseResult) {
1954 genericOp,
"failed to do the fusion by collapsing transformation");
1957 rewriter.
replaceOp(genericOp, collapseResult->results);
1969struct FoldReshapeWithGenericOpByCollapsing
1970 :
public OpRewritePattern<tensor::CollapseShapeOp> {
1972 FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
1974 PatternBenefit benefit = 1)
1975 : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
1976 controlFoldingReshapes(std::move(foldReshapes)) {}
1978 LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
1979 PatternRewriter &rewriter)
const override {
1982 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1983 if (!producerResult) {
1985 "source not produced by an operation");
1989 auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
1992 "producer not a generic op");
1995 SmallVector<ReassociationIndices> collapsableIterationDims =
1998 producer.getDpsInitOperand(producerResult.getResultNumber()),
1999 reshapeOp.getReassociationIndices());
2000 if (collapsableIterationDims.empty()) {
2002 reshapeOp,
"failed preconditions of fusion with producer generic op");
2005 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
2007 "fusion blocked by control function");
2013 std::optional<CollapseResult> collapseResult =
2015 if (!collapseResult) {
2017 producer,
"failed to do the fusion by collapsing transformation");
2020 rewriter.
replaceOp(producer, collapseResult->results);
2032static FailureOr<PadDimInfo>
2033computeCollapsedPadding(tensor::PadOp padOp,
2034 ArrayRef<ReassociationIndices> reassociations,
2035 PatternRewriter &rewriter) {
2042 if (!padOp.getConstantPaddingValue())
2049 ArrayRef<int64_t> low = padOp.getStaticLow();
2050 ArrayRef<int64_t> high = padOp.getStaticHigh();
2051 for (
auto [idx, reInd] : llvm::enumerate(reassociations)) {
2052 for (int64_t dim : reInd) {
2053 if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
2059 ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
2060 PadDimInfo padDimInfo;
2061 padDimInfo.lowPad.assign(reassociations.size(), rewriter.
getIndexAttr(0));
2062 padDimInfo.highPad.assign(reassociations.size(), rewriter.
getIndexAttr(0));
2066 SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
2067 SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
2068 for (
auto [idx, reInd] : llvm::enumerate(reassociations)) {
2069 if (reInd.size() == 1) {
2070 padDimInfo.lowPad[idx] = mixedLowPad[reInd[0]];
2071 padDimInfo.highPad[idx] = mixedHighPad[reInd[0]];
2074 for (int64_t dim : reInd) {
2078 padDimInfo.paddedShape.push_back(collapsedSize.
asInteger());
2084class FoldPadWithProducerReshapeOpByCollapsing
2085 :
public OpRewritePattern<tensor::PadOp> {
2087 FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
2089 PatternBenefit benefit = 1)
2090 : OpRewritePattern<tensor::PadOp>(context, benefit),
2091 controlFoldingReshapes(std::move(foldReshapes)) {}
2093 LogicalResult matchAndRewrite(tensor::PadOp padOp,
2094 PatternRewriter &rewriter)
const override {
2095 tensor::ExpandShapeOp reshapeOp =
2096 padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
2100 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
2102 "fusion blocked by control function");
2105 SmallVector<ReassociationIndices> reassociations =
2106 reshapeOp.getReassociationIndices();
2107 FailureOr<PadDimInfo> maybeCollapsedPadding =
2108 computeCollapsedPadding(padOp, reassociations, rewriter);
2109 if (
failed(maybeCollapsedPadding))
2111 PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
2113 SmallVector<OpFoldResult> expandedPaddedSizes =
2114 reshapeOp.getMixedOutputShape();
2115 AffineExpr d0, d1, d2;
2118 Location loc = reshapeOp->getLoc();
2119 for (
auto [reInd, l, h] :
2120 llvm::zip_equal(reassociations, collapsedPadding.lowPad,
2121 collapsedPadding.highPad)) {
2122 if (reInd.size() == 1) {
2124 rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
2128 RankedTensorType collapsedPaddedType =
2129 padOp.getType().clone(collapsedPadding.paddedShape);
2130 auto newPadOp = tensor::PadOp::create(
2131 rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(),
2132 collapsedPadding.lowPad, collapsedPadding.highPad,
2133 padOp.getConstantPaddingValue(), padOp.getNofold());
2136 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
2137 expandedPaddedSizes);
2146class FoldReshapeWithProducerPadOpByCollapsing
2147 :
public OpRewritePattern<tensor::CollapseShapeOp> {
2149 FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
2151 PatternBenefit benefit = 1)
2152 : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
2153 controlFoldingReshapes(std::move(foldReshapes)) {}
2155 LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
2156 PatternRewriter &rewriter)
const override {
2157 tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
2161 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
2163 "fusion blocked by control function");
2166 SmallVector<ReassociationIndices> reassociations =
2167 reshapeOp.getReassociationIndices();
2168 RankedTensorType collapsedPaddedType = reshapeOp.getResultType();
2169 FailureOr<PadDimInfo> maybeCollapsedPadding =
2170 computeCollapsedPadding(padOp, reassociations, rewriter);
2171 if (
failed(maybeCollapsedPadding))
2173 PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
2175 Location loc = reshapeOp->getLoc();
2176 auto newCollapseOp = tensor::CollapseShapeOp::create(
2177 rewriter, loc, padOp.getSource(), reassociations);
2179 auto newPadOp = tensor::PadOp::create(
2180 rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(),
2181 collapsedPadding.lowPad, collapsedPadding.highPad,
2182 padOp.getConstantPaddingValue(), padOp.getNofold());
2184 rewriter.
replaceOp(reshapeOp, newPadOp.getResult());
2193template <
typename LinalgType>
2194class CollapseLinalgDimensions :
public OpRewritePattern<LinalgType> {
2196 CollapseLinalgDimensions(MLIRContext *context,
2198 PatternBenefit benefit = 1)
2199 : OpRewritePattern<LinalgType>(context, benefit),
2200 controlCollapseDimension(std::move(collapseDimensions)) {}
2202 LogicalResult matchAndRewrite(LinalgType op,
2203 PatternRewriter &rewriter)
const override {
2204 SmallVector<ReassociationIndices> collapsableIterationDims =
2205 controlCollapseDimension(op);
2206 if (collapsableIterationDims.empty())
2211 collapsableIterationDims)) {
2213 op,
"specified dimensions cannot be collapsed");
2216 std::optional<CollapseResult> collapseResult =
2218 if (!collapseResult) {
2221 rewriter.
replaceOp(op, collapseResult->results);
2238class FoldScalarOrSplatConstant :
public OpRewritePattern<GenericOp> {
2240 FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
2241 : OpRewritePattern<GenericOp>(context, benefit) {}
2243 LogicalResult matchAndRewrite(GenericOp genericOp,
2244 PatternRewriter &rewriter)
const override {
2245 if (!genericOp.hasPureTensorSemantics())
2247 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2249 TypedAttr constantAttr;
2250 auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) ->
bool {
2252 DenseElementsAttr splatAttr;
2255 splatAttr.
getType().getElementType().isIntOrFloat()) {
2261 IntegerAttr intAttr;
2263 constantAttr = intAttr;
2268 FloatAttr floatAttr;
2270 constantAttr = floatAttr;
2277 auto resultValue = dyn_cast<OpResult>(opOperand->
get());
2278 if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
2284 SmallVector<AffineMap> fusedIndexMaps;
2285 SmallVector<Value> fusedOperands;
2286 SmallVector<Location> fusedLocs{genericOp.getLoc()};
2287 fusedIndexMaps.reserve(genericOp->getNumOperands());
2288 fusedOperands.reserve(genericOp.getNumDpsInputs());
2289 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
2290 for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
2291 if (inputOperand == opOperand)
2293 Value inputValue = inputOperand->get();
2294 fusedIndexMaps.push_back(
2295 genericOp.getMatchingIndexingMap(inputOperand));
2296 fusedOperands.push_back(inputValue);
2297 fusedLocs.push_back(inputValue.
getLoc());
2299 for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
2300 fusedIndexMaps.push_back(
2301 genericOp.getMatchingIndexingMap(&outputOperand));
2307 genericOp,
"fused op loop bound computation failed");
2311 Value scalarConstant =
2312 arith::ConstantOp::create(rewriter, def->
getLoc(), constantAttr);
2314 SmallVector<Value> outputOperands = genericOp.getOutputs();
2316 GenericOp::create(rewriter, rewriter.
getFusedLoc(fusedLocs),
2317 genericOp->getResultTypes(),
2321 genericOp.getIteratorTypes(),
2327 Region ®ion = genericOp->getRegion(0);
2332 Region &fusedRegion = fusedOp->getRegion(0);
2335 rewriter.
replaceOp(genericOp, fusedOp->getResults());
2353struct RemoveOutsDependency :
public OpRewritePattern<GenericOp> {
2354 using OpRewritePattern<GenericOp>::OpRewritePattern;
2356 LogicalResult matchAndRewrite(GenericOp op,
2357 PatternRewriter &rewriter)
const override {
2359 bool modifiedOutput =
false;
2360 Location loc = op.getLoc();
2361 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2362 if (!op.payloadUsesValueFromOperand(&opOperand)) {
2363 Value operandVal = opOperand.
get();
2364 auto operandType = dyn_cast<RankedTensorType>(operandVal.
getType());
2373 auto definingOp = operandVal.
getDefiningOp<tensor::EmptyOp>();
2376 modifiedOutput =
true;
2377 SmallVector<OpFoldResult> mixedSizes =
2379 Value emptyTensor = tensor::EmptyOp::create(
2380 rewriter, loc, mixedSizes, operandType.getElementType());
2384 if (!modifiedOutput) {
2394struct FoldFillWithGenericOp :
public OpRewritePattern<GenericOp> {
2395 using OpRewritePattern<GenericOp>::OpRewritePattern;
2397 LogicalResult matchAndRewrite(GenericOp genericOp,
2398 PatternRewriter &rewriter)
const override {
2399 if (!genericOp.hasPureTensorSemantics())
2401 bool fillFound =
false;
2402 Block &payload = genericOp.getRegion().front();
2403 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2404 if (!genericOp.payloadUsesValueFromOperand(opOperand))
2410 Value fillVal = fillOp.value();
2412 cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2413 Value convertedVal =
2428 controlFoldingReshapes);
2429 patterns.add<FoldPadWithProducerReshapeOpByExpansion>(
patterns.getContext(),
2430 controlFoldingReshapes);
2431 patterns.add<FoldReshapeWithProducerPadOpByExpansion>(
patterns.getContext(),
2432 controlFoldingReshapes);
2434 controlFoldingReshapes);
2441 controlFoldingReshapes);
2442 patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2443 patterns.getContext(), controlFoldingReshapes);
2444 patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
2445 patterns.getContext(), controlFoldingReshapes);
2447 controlFoldingReshapes);
2453 auto *context =
patterns.getContext();
2454 patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2455 patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2456 RemoveOutsDependency>(context);
2464 patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2465 CollapseLinalgDimensions<linalg::CopyOp>>(
2466 patterns.getContext(), controlCollapseDimensions);
2481struct LinalgElementwiseOpFusionPass
2483 LinalgElementwiseOpFusionPass> {
2485 LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2486 void runOnOperation()
override {
2493 Operation *producer = fusedOperand->get().getDefiningOp();
2494 return producer && producer->
hasOneUse();
2503 affine::AffineApplyOp::getCanonicalizationPatterns(
patterns, context);
2504 GenericOp::getCanonicalizationPatterns(
patterns, context);
2505 tensor::ExpandShapeOp::getCanonicalizationPatterns(
patterns, context);
2506 tensor::CollapseShapeOp::getCanonicalizationPatterns(
patterns, context);
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(GenericOp producer, GenericOp consumer, ArrayRef< OpOperand * > opOperandsToIgnore)
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(OpOperand *producerOpOperand, AffineMap producerResultIndexMap, AffineMap fusedConsumerArgIndexMap)
Append to fusedOpIndexingMapAttrs the indexing maps for the operands of the producer to use in the fu...
static SmallVector< ReassociationIndices > getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, ArrayRef< ReassociationIndices > reassociation)
ArrayRef< ReassociationIndices > getCollapsedOpToOrigOpMapping() const
Return mapping from collapsed loop domain to original loop domain.
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static std::tuple< SmallVector< OpFoldResult >, RankedTensorType > getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the shape and type of the operand/result to use in the expanded op given the type in the origi...
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, Location loc, Region &fusedRegion, const ExpansionInfo &expansionInfo)
Update the body of an expanded linalg operation having index semantics.
static Operation * createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp, Value expandedInput, Value output, ExpansionInfo &expansionInfo)
static SmallVector< ReassociationIndices > getReassociationForExpansion(AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Returns the reassociation maps to use in the tensor.expand_shape operation to convert the operands of...
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp, OpOperand *fusableOpOperand)
Conditions for folding a structured linalg operation with a reshape op by expanding the iteration spa...
static Operation * createExpandedGenericOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > &expandedOpOperands, ArrayRef< Value > outputs, ExpansionInfo &expansionInfo, ArrayRef< AffineMap > expandedOpIndexingMaps)
static Operation * createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > expandedOpOperands, ArrayRef< Value > outputs, ArrayRef< AffineMap > expandedOpIndexingMaps, ExpansionInfo &expansionInfo)
static ReassociationIndices getDomainReassociation(AffineMap indexingMap, ReassociationIndicesRef rangeReassociation)
For a given list of indices in the range of the indexingMap that are folded, return the indices of th...
static void generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand, unsigned nloops, llvm::SmallDenseSet< int > &preservedProducerResults)
Generate the region of the fused tensor operation.
static std::optional< SmallVector< Value > > fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, OpOperand *fusableOpOperand, PatternRewriter &rewriter)
Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op and a generic op as expl...
static AffineMap getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the indexing map to use in the expanded op for a given the indexingMap of the original operati...
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext()
Return the context this operation is associated with.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool areDimSequencesPreserved(ArrayRef< AffineMap > maps, ArrayRef< ReassociationIndices > dimSequences)
Return true if all sequences of dimensions specified in dimSequences are contiguous in all the ranges...
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence)
Return true if a given sequence of dimensions are contiguous in the range of the specified indexing m...
void populateFoldReshapeOpsByCollapsingPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding tensor.expand_shape operation with its producer generic operation by co...
FailureOr< ElementwiseOpFusionResult > fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand)
This transformation is intended to be used with a top-down traversal (from producer to consumer).
llvm::SmallDenseSet< int > getPreservedProducerResults(GenericOp producer, GenericOp consumer, OpOperand *fusedOperand)
Returns a set of indices of the producer's results which would be preserved after the fusion.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
std::function< SmallVector< ReassociationIndices >(linalg::LinalgOp)> GetCollapsableDimensionsFn
Function type to control generic op dimension collapsing.
void populateCollapseDimensions(RewritePatternSet &patterns, const GetCollapsableDimensionsFn &controlCollapseDimensions)
Pattern to collapse dimensions in a linalg.generic op.
bool areElementwiseOpsFusable(OpOperand *fusedOperand)
Return true if two linalg.generic operations with producer/consumer relationship through fusedOperand...
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding (collapsing) tensor_reshape operation with its producer (consumer) gene...
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void populateElementwiseOpsFusionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion)
Patterns for fusing linalg operation on tensors.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
ArrayRef< int64_t > ReassociationIndicesRef
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. numDims - 1].
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. numSymbols - 1].
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
SmallVector< int64_t, 2 > ReassociationIndices
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, Operation *insertionPoint, DominanceInfo &dominance)
Move definitions of values (and their transitive dependencies) before insertionPoint.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply the inverse of a permutation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
static SaturatedInteger wrap(int64_t v)
Fuse two linalg.generic operations that have a producer-consumer relationship captured through fusedO...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.