  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  return t1.compose(fusedConsumerArgIndexMap);
}
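// Returns true if the given operands can be dropped from the fused op: the
// indexing maps of all remaining operands, concatenated, must stay invertible
// so that the fused op's loop bounds remain computable.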
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {
  SmallVector<AffineMap> indexingMaps;

  SmallVector<GenericOp> ops = {producer, consumer};
  for (auto &op : ops) {
    for (auto &opOperand : op->getOpOperands()) {
      if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
        continue;
      }
      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
    }
  }
  if (indexingMaps.empty()) {
    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
  }

  return inversePermutation(concatAffineMaps(
             indexingMaps, producer.getContext())) != AffineMap();
}
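// Returns the set of producer result indices that must be preserved by the
// fusion: a result survives if its init is read by the payload, if dropping
// its operand would break loop-bound computation, or if it has users other
// than the consumer.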
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  llvm::SmallVector<OpOperand *> opOperandsToIgnore;

  opOperandsToIgnore.emplace_back(fusedOperand);

  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    opOperandsToIgnore.emplace_back(outputOperand);
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
                                                  opOperandsToIgnore) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());
      (void)opOperandsToIgnore.pop_back_val();
    }
  }
  return preservedProducerResults;
}
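// Check preconditions for elementwise fusion: both ops are linalg.generic on
// ranked tensors, the producer is all-parallel, the fused operand is a
// consumer input, and the consumer's indexing map for it spans all producer
// loops.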
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  if (!producer || !consumer)
    return false;

  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  if (!consumer.isDpsInput(fusedOperand))
    return false;

  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  auto producerResult = cast<OpResult>(fusedOperand->get());
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);

  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}
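// Generate the region of the fused op by splicing the producer's payload into
// the consumer's: block arguments are remapped operand group by operand
// group, producer linalg.index ops are rewritten through
// `consumerToProducerLoopsMap`, and a combined linalg.yield is emitted.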
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());

  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  if (producer.hasIndexSemantics()) {
    unsigned numFusedOpLoops = fusedOp.getNumLoops();
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return IndexOp::create(rewriter, producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = affine::AffineApplyOp::create(
          rewriter, producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");

  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }

  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);

  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  YieldOp::create(rewriter, fusedOp.getLoc(), fusedYieldValues);

  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}
FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");

  llvm::SmallDenseSet<int> preservedProducerResults =
      mlir::linalg::getPreservedProducerResults(producer, consumer,
                                                fusedOperand);

  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());

  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }

  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;
    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  auto fusedOp = GenericOp::create(
      rewriter, consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);

  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}
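// Rewrite pattern driving the fusion: for each operand of a generic op whose
// producer the control function accepts, attempt fuseElementwiseOps and
// replace both ops on success.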
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      Operation *producer = opOperand.get().getDefiningOp();

      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");

      for (auto [origVal, replacement] : fusionResult->replacements) {
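//===---------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------===//

// Conditions for folding a reshape by dimension expansion: the op must have
// pure tensor semantics, every indexing map must be a projected permutation,
// and the map for the reshaped operand must not be rank-0.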
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         operandMap.getNumResults() > 0;
}

class ExpansionInfo {
public:
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<OpFoldResult> expandedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }

private:
  SmallVector<ReassociationIndices> reassociation;
  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
  SmallVector<OpFoldResult> originalLoopExtent;
  unsigned expandedOpNumDims;
};
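// Compute the expansion info: for every original loop dimension, which
// expanded dimensions it maps to and their sizes. Dimensions not touched by
// the reshape expand to themselves with their original loop extent.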
LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<OpFoldResult> expandedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  OpBuilder::InsertionGuard g(rewriter);
  originalLoopExtent = llvm::map_to_vector(
      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
      [](Range r) { return r.size; });

  reassociation.clear();
  expandedShapeMap.clear();

  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<OpFoldResult> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}

static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
                        const ExpansionInfo &expansionInfo) {
  SmallVector<OpFoldResult> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    ArrayRef<OpFoldResult> dimExpansion =
        expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  SmallVector<int64_t> expandedStaticShape;
  std::tie(expandedStaticShape, std::ignore) =
      decomposeMixedValues(expandedShape);
  return {expandedShape, RankedTensorType::get(expandedStaticShape,
                                               originalType.getElementType())};
}

static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}
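// Update linalg.index ops in the expanded region: an index that expanded into
// sub-indices (i0, ..., in) with expanded shapes (s0, ..., sn) is recomputed
// by row-major linearization, folding idx + acc * shape one dimension at a
// time: ((i0 * s1 + i1) * s2 + i2) * ...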
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    ArrayRef<OpFoldResult> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return IndexOp::create(rewriter, loc, dim); });
    Value newIndexVal =
        IndexOp::create(rewriter, loc, expandedDims.front()).getResult();
    for (auto [expandedShape, expandedIndex] :
         llvm::zip(expandedDimsShape, expandedIndices)) {
          rewriter, indexOp.getLoc(), idx + acc * shape,
    }
    rewriter.replaceOp(indexOp, newIndexVal);
  }
}

static Operation *createExpandedTransposeOp(PatternRewriter &rewriter,
                                            TransposeOp transposeOp,
                                            Value expandedInput, Value output,
                                            ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> newPerm;
  for (int64_t perm : transposeOp.getPermutation()) {
    auto reassoc = expansionInfo.getExpandedDims(perm);
    for (int64_t dim : reassoc)
      newPerm.push_back(dim);
  }
  return TransposeOp::create(rewriter, transposeOp.getLoc(), expandedInput,
                             output, newPerm);
}

static Operation *createExpandedGenericOp(
    PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
    ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
    ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);

  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
    for (auto j : expansionInfo.getExpandedDims(i))
      iteratorTypes[j] = type;

  Operation *fused = GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes,
                                       expandedOpOperands, outputs,
                                       expandedOpIndexingMaps, iteratorTypes);

  Region &fusedRegion = fused->getRegion(0);
  Region &originalRegion = linalgOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
                                expansionInfo);
  return fused;
}

static Operation *createExpandedOp(
    PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
    ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
    ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo) {
  return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
      .Case<TransposeOp>([&](TransposeOp transposeOp) {
        return createExpandedTransposeOp(rewriter, transposeOp,
                                         expandedOpOperands[0], outputs[0],
                                         expansionInfo);
      })
      .Case<FillOp, CopyOp>([&](Operation *op) {
        return clone(rewriter, linalgOp, resultTypes,
                     llvm::to_vector(llvm::concat<Value>(
                         llvm::to_vector(expandedOpOperands),
                         llvm::to_vector(outputs))));
      })
      .Default([&](Operation *op) {
        return createExpandedGenericOp(rewriter, linalgOp, resultTypes,
                                       expandedOpOperands, outputs,
                                       expansionInfo, expandedOpIndexingMaps);
      });
}
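// Implement the fusion of a tensor.collapse_shape or tensor.expand_shape op
// with its producer/consumer generic op by expanding the iteration space.
// Hand-written sketch (not from a test):
//
//   %c = tensor.collapse_shape %a [[0, 1]] : tensor<?x4xf32> into tensor<?xf32>
//   %r = linalg.generic ... ins(%c : tensor<?xf32>) ...
//
// The generic op is rewritten to iterate over the (?, 4) space and consume %a
// directly; results are collapsed back where needed.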
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");

  Location loc = linalgOp.getLoc();
  SmallVector<OpFoldResult> expandedShape;
  SmallVector<AffineMap, 4> reassociationIndices;
  Value src;
  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
    if (failed(moveValueDefinitions(
            rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
      return std::nullopt;

    expandedShape = expandingReshapeOp.getMixedOutputShape();
    reassociationIndices = expandingReshapeOp.getReassociationMaps();
    src = expandingReshapeOp.getSrc();
  } else {
    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
    if (!collapsingReshapeOp)
      return std::nullopt;

    expandedShape = tensor::getMixedSizes(
        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
    src = collapsingReshapeOp.getSrc();
  }

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
                                   reassociationIndices, expandedShape,
                                   rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(src);
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
      SmallVector<OpFoldResult> expandedOperandShape;
      RankedTensorType expandedOperandType;
      std::tie(expandedOperandShape, expandedOperandType) =
          getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(linalgOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(tensor::ExpandShapeOp::create(
            rewriter, loc, expandedOperandType, opOperand->get(), reassociation,
            expandedOperandShape));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  SmallVector<Value> outputs;
  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
    SmallVector<OpFoldResult> expandedOutputShape;
    RankedTensorType expandedOutputType;
    std::tie(expandedOutputShape, expandedOutputType) =
        getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(linalgOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(tensor::ExpandShapeOp::create(
          rewriter, loc, expandedOutputType, opOperand.get(), reassociation,
          expandedOutputShape));
    } else {
      outputs.push_back(opOperand.get());
    }
  }

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  Operation *fusedOp =
      createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
                       outputs, expandedOpIndexingMaps, expansionInfo);

  SmallVector<Value> resultVals;
  for (OpResult opResult : linalgOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              linalgOp.getMatchingIndexingMap(
                  linalgOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(tensor::CollapseShapeOp::create(
          rewriter, linalgOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  return resultVals;
}

class FoldWithProducerReshapeOpByExpansion
    : public OpInterfaceRewritePattern<LinalgOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(linalgOp, *replacementValues);
      return success();
    }
    return failure();
  }
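// Fold a tensor.pad whose source is a tensor.collapse_shape by padding the
// expanded source instead and collapsing the padded result. Legal only when
// each collapsed group is either a single dimension or entirely unpadded.
// Hand-written sketch (not from a test):
//
//   %c = tensor.collapse_shape %a [[0], [1, 2]]
//       : tensor<?x4x8xf32> into tensor<?x32xf32>
//   %p = tensor.pad %c low[1, 0] high[2, 0] { ... }
//
// becomes a pad of %a along dimension 0 followed by the same collapse_shape.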
class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::CollapseShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
      if (reInd.size() != 1 && (l != 0 || h != 0))
        return failure();
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType expandedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      if (reInd.size() == 1) {
        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
      }
      for (size_t i = 0; i < reInd.size(); ++i) {
        newLow.push_back(padOp.getMixedLowPad()[idx]);
        newHigh.push_back(padOp.getMixedHighPad()[idx]);
      }
    }

    Location loc = padOp->getLoc();
    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
    auto newPadOp = tensor::PadOp::create(
        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
    return success();
  }

struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {
  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp, "producer not a generic op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
      }));

  return domainReassociation;
}
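// For a given `dimSequence`, check if the sequence is conserved in the
// `indexingMap`: its dimensions must appear contiguously and in order, and
// must not be interleaved with other dimensions of the sequence elsewhere in
// the map.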
static bool isDimSequencePreserved(AffineMap indexingMap,
                                   ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert_range(dimSequence);

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            cast<AffineDimExpr>(
                indexingMap.getResult(expr.index() + dimInSequence.index()))
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      return true;
    }
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  return true;
}

  return llvm::all_of(maps, [&](AffineMap map) {

static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
  if (!genericOp.hasPureTensorSemantics())
    return {};

  if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
        return map.isProjectedPermutation();
      }))
    return {};

  SmallVector<unsigned> reductionDims;
  genericOp.getReductionDims(reductionDims);

  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
  auto iteratorTypes = genericOp.getIteratorTypesArray();
  SmallVector<ReassociationIndices> iterationSpaceReassociation;
  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");

    if (foldedRangeDims.size() == 1)
      continue;

    ReassociationIndices foldedIterationSpaceDims =
        getDomainReassociation(indexingMap, foldedRangeDims);

    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return processedIterationDims.count(dim);
        }))
      continue;

    utils::IteratorType startIteratorType =
        iteratorTypes[foldedIterationSpaceDims[0]];
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return iteratorTypes[dim] != startIteratorType;
        }))
      continue;

    if (startIteratorType == utils::IteratorType::reduction) {
      bool isContiguous = false;
      for (const auto &startDim : llvm::enumerate(reductionDims)) {
        if (startDim.value() != foldedIterationSpaceDims[0])
          continue;
        if (startDim.index() + foldedIterationSpaceDims.size() >
            reductionDims.size())
          break;
        isContiguous = true;
        for (const auto &foldedDim :
             llvm::enumerate(foldedIterationSpaceDims)) {
          if (reductionDims[foldedDim.index() + startDim.index()] !=
              foldedDim.value()) {
            isContiguous = false;
            break;
          }
        }
        break;
      }
      if (!isContiguous)
        continue;
    }

    if (llvm::any_of(genericOp.getIndexingMapsArray(),
                     [&](AffineMap indexingMap) {
                       return !isDimSequencePreserved(indexingMap,
                                                      foldedIterationSpaceDims);
                     }))
      continue;

    processedIterationDims.insert_range(foldedIterationSpaceDims);
    iterationSpaceReassociation.emplace_back(
        std::move(foldedIterationSpaceDims));
  }
  return iterationSpaceReassociation;
}
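// Bookkeeping for the collapse: `collapsedOpToOrigOpIterationDim` maps each
// collapsed loop to the ordered group of original loops it covers, while
// `origOpToCollapsedOpIterationDim` maps each original loop to (collapsed
// loop, position within its group).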
class CollapsingInfo {
public:
  LogicalResult initialize(unsigned origNumLoops,
                           ArrayRef<ReassociationIndices> foldedIterationDims) {
    llvm::SmallDenseSet<int64_t, 4> processedDims;
    for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
      if (foldedIterationDim.empty())
        continue;
      for (auto dim : foldedIterationDim) {
        if (dim >= origNumLoops)
          return failure();
        if (processedDims.count(dim))
          return failure();
        processedDims.insert(dim);
      }
      collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
                                                   foldedIterationDim.end());
    }
    if (processedDims.size() > origNumLoops)
      return failure();

    for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
      if (processedDims.count(dim))
        continue;
      collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
    }

    llvm::sort(collapsedOpToOrigOpIterationDim,
               [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
                 return lhs[0] < rhs[0];
               });
    origOpToCollapsedOpIterationDim.resize(origNumLoops);
    for (const auto &foldedDims :
         llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
      for (const auto &dim : enumerate(foldedDims.value()))
        origOpToCollapsedOpIterationDim[dim.value()] =
            std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
    }
    return success();
  }

  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
    return collapsedOpToOrigOpIterationDim;
  }

  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
    return origOpToCollapsedOpIterationDim;
  }

  unsigned getCollapsedOpIterationRank() const {
    return collapsedOpToOrigOpIterationDim.size();
  }

private:
  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
};

static SmallVector<utils::IteratorType>
getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
                            const CollapsingInfo &collapsingInfo) {
  SmallVector<utils::IteratorType> collapsedIteratorTypes;
  for (ReassociationIndicesRef foldedIterDims :
       collapsingInfo.getCollapsedOpToOrigOpMapping()) {
    assert(!foldedIterDims.empty() &&
           "reassociation indices expected to have non-empty sets");
    collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
  }
  return collapsedIteratorTypes;
}

static AffineMap
getCollapsedOpIndexingMap(AffineMap indexingMap,
                          const CollapsingInfo &collapsingInfo) {
  MLIRContext *context = indexingMap.getContext();
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");
  SmallVector<AffineExpr> resultExprs;
  auto origOpToCollapsedOpMapping =
      collapsingInfo.getOrigOpToCollapsedOpMapping();
  for (auto expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    if (origOpToCollapsedOpMapping[dim].second != 0)
      continue;
    resultExprs.push_back(
        getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
  }
  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
                        resultExprs, context);
}

static SmallVector<ReassociationIndices>
getOperandReassociation(AffineMap indexingMap,
                        const CollapsingInfo &collapsingInfo) {
  unsigned counter = 0;
  SmallVector<ReassociationIndices> operandReassociation;
  auto origOpToCollapsedOpMapping =
      collapsingInfo.getOrigOpToCollapsedOpMapping();
  auto collapsedOpToOrigOpMapping =
      collapsingInfo.getCollapsedOpToOrigOpMapping();
  while (counter < indexingMap.getNumResults()) {
    unsigned dim =
        cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
    unsigned numFoldedDims =
        collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
            .size();
    if (origOpToCollapsedOpMapping[dim].second == 0) {
      auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
      operandReassociation.emplace_back(range.begin(), range.end());
    }
    counter += numFoldedDims;
  }
  return operandReassociation;
}
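// Get the operand value to use in the collapsed op: if none of the operand's
// dimensions are folded (one reassociation entry per indexing-map result) it
// is used as is; otherwise a memref.collapse_shape or tensor.collapse_shape
// is materialized.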
static Value getCollapsedOpOperand(Location loc, LinalgOp op,
                                   OpOperand *opOperand,
                                   const CollapsingInfo &collapsingInfo,
                                   OpBuilder &builder) {
  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
  SmallVector<ReassociationIndices> operandReassociation =
      getOperandReassociation(indexingMap, collapsingInfo);

  Value operand = opOperand->get();
  if (operandReassociation.size() == indexingMap.getNumResults())
    return operand;

  if (isa<MemRefType>(operand.getType())) {
    return memref::CollapseShapeOp::create(builder, loc, operand,
                                           operandReassociation)
        .getResult();
  }
  return tensor::CollapseShapeOp::create(builder, loc, operand,
                                         operandReassociation)
      .getResult();
}
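// Rewrite linalg.index ops of the collapsed op: each collapsed index is
// de-linearized back into the original per-dimension indices with rem/div,
// e.g. for a group (d0, d1, d2) with loop sizes (s0, s1, s2) and collapsed
// index i:  d2 = i % s2,  d1 = (i / s2) % s1,  d0 = (i / s2) / s1.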
static void generateCollapsedIndexingRegion(
    Location loc, Block *block, const CollapsingInfo &collapsingInfo,
    ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointToStart(block);

  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());

  llvm::DenseMap<unsigned, Value> indexReplacementVals;
  for (auto foldedDims :
       enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
    ReassociationIndicesRef foldedDimsRef(foldedDims.value());
    Value newIndexVal =
        linalg::IndexOp::create(rewriter, loc, foldedDims.index());
    for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
      Value loopDim =
          getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
      indexReplacementVals[dim] =
          rewriter.createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
      newIndexVal =
          rewriter.createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
    }
    indexReplacementVals[foldedDims.value().front()] = newIndexVal;
  }

  for (auto indexOp : indexOps) {
    auto dim = indexOp.getDim();
    rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
  }
}

static void collapseOperandsAndResults(LinalgOp op,
                                       const CollapsingInfo &collapsingInfo,
                                       RewriterBase &rewriter,
                                       SmallVectorImpl<Value> &inputOperands,
                                       SmallVectorImpl<Value> &outputOperands,
                                       SmallVectorImpl<Type> &resultTypes) {
  Location loc = op->getLoc();
  inputOperands =
      llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
        return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
                                     rewriter);
      });

  resultTypes.reserve(op.getNumDpsInits());
  outputOperands.reserve(op.getNumDpsInits());
  for (OpOperand &output : op.getDpsInitsMutable()) {
    Value newOutput =
        getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
    outputOperands.push_back(newOutput);
    if (!op.hasPureBufferSemantics())
      resultTypes.push_back(newOutput.getType());
  }
}

template <typename OpTy>
static OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
                               const CollapsingInfo &collapsingInfo) {
  return nullptr;
}

template <>
LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
                                      const CollapsingInfo &collapsingInfo) {
  SmallVector<Value> inputOperands, outputOperands;
  SmallVector<Type> resultTypes;
  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
                             outputOperands, resultTypes);
  return clone(
      rewriter, origOp, resultTypes,
      llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
}

template <>
GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
                                        GenericOp origOp,
                                        const CollapsingInfo &collapsingInfo) {
  SmallVector<Value> inputOperands, outputOperands;
  SmallVector<Type> resultTypes;
  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
                             outputOperands, resultTypes);
  SmallVector<AffineMap> indexingMaps(
      llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
        return getCollapsedOpIndexingMap(map, collapsingInfo);
      }));

  SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
      origOp.getIteratorTypesArray(), collapsingInfo));

  GenericOp collapsedOp = linalg::GenericOp::create(
      rewriter, origOp.getLoc(), resultTypes, inputOperands, outputOperands,
      indexingMaps, iteratorTypes,
      [](OpBuilder &builder, Location loc, ValueRange args) {});
  Block *origOpBlock = &origOp->getRegion(0).front();
  Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
                       collapsedOpBlock->getArguments());
  return collapsedOp;
}

static LinalgOp createCollapsedOp(LinalgOp op,
                                  const CollapsingInfo &collapsingInfo,
                                  RewriterBase &rewriter) {
  if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
    return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
  }
  return cloneToCollapsedOp(rewriter, op, collapsingInfo);
}

FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
    LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
    RewriterBase &rewriter) {
  if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
      llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
        return foldedDims.size() <= 1;
      }))
    return failure();

  CollapsingInfo collapsingInfo;
  if (failed(
          collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
    return rewriter.notifyMatchFailure(
        op, "illegal to collapse specified dimensions");
  }

  bool hasPureBufferSemantics = op.hasPureBufferSemantics();
  if (hasPureBufferSemantics &&
      !llvm::all_of(op->getOpOperands(), [&](OpOperand &opOperand) -> bool {
        MemRefType memRefToCollapse =
            dyn_cast<MemRefType>(opOperand.get().getType());
        if (!memRefToCollapse)
          return true;

        AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
        SmallVector<ReassociationIndices> operandReassociation =
            getOperandReassociation(indexingMap, collapsingInfo);
        return memref::CollapseShapeOp::isGuaranteedCollapsible(
            memRefToCollapse, operandReassociation);
      }))
    return rewriter.notifyMatchFailure(op,
                                       "memref is not guaranteed collapsible");

  SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
  auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
      return cast<IntegerAttr>(attr).getInt() == value;
    llvm::APInt actual;
    return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) &&
           actual.getSExtValue() == value;
  };
  if (!llvm::all_of(loopRanges, [&](Range range) {
        return opFoldIsConstantValue(range.offset, 0) &&
               opFoldIsConstantValue(range.stride, 1);
      })) {
    return rewriter.notifyMatchFailure(
        op, "expected all loop ranges to have zero start and unit stride");
  }

  LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);

  Location loc = op->getLoc();
  SmallVector<OpFoldResult> loopBound =
      llvm::map_to_vector(loopRanges, [](Range range) { return range.size; });

  if (collapsedOp.hasIndexSemantics()) {
    OpBuilder::InsertionGuard g(rewriter);
    generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
                                    collapsingInfo, loopBound, rewriter);
  }

  SmallVector<Value> results;
  for (const auto &originalResult : llvm::enumerate(op->getResults())) {
    Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
    auto originalResultType =
        cast<ShapedType>(originalResult.value().getType());
    auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
    if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
      AffineMap indexingMap =
          op.getIndexingMapMatchingResult(originalResult.value());
      SmallVector<ReassociationIndices> reassociation =
          getOperandReassociation(indexingMap, collapsingInfo);
      assert(indexingMap.isProjectedPermutation() &&
             "Expected indexing map to be a projected permutation for collapsing");
      SmallVector<OpFoldResult> resultShape =
          applyPermutationMap(indexingMap, ArrayRef(loopBound));
      Value result;
      if (isa<MemRefType>(collapsedOpResult.getType())) {
        MemRefType expandShapeResultType = MemRefType::get(
            originalResultType.getShape(), originalResultType.getElementType());
        result = memref::ExpandShapeOp::create(
            rewriter, loc, expandShapeResultType, collapsedOpResult,
            reassociation, resultShape);
      } else {
        result = tensor::ExpandShapeOp::create(
            rewriter, loc, originalResultType, collapsedOpResult, reassociation,
            resultShape);
      }
      results.push_back(result);
    } else {
      results.push_back(collapsedOpResult);
    }
  }
  return CollapseResult{results, collapsedOp};
}
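// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
// collapsing the iteration space of the generic op (the dual of fusion by
// expansion): folding the generic op's loops lets it consume the reshape's
// source directly, with expand_shape ops emitted for the results.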
class FoldWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<GenericOp> {
public:
  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
                                        ControlFusionFn foldReshapes,
                                        PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      tensor::ExpandShapeOp reshapeOp =
          opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
      if (!reshapeOp)
        continue;

      SmallVector<ReassociationIndices> collapsableIterationDims =
          getCollapsableIterationSpaceDims(genericOp, &opOperand,
                                           reshapeOp.getReassociationIndices());
      if (collapsableIterationDims.empty() ||
          !controlFoldingReshapes(&opOperand)) {
        continue;
      }

      std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
          genericOp, collapsableIterationDims, rewriter);
      if (!collapseResult) {
        return rewriter.notifyMatchFailure(
            genericOp, "failed to do the fusion by collapsing transformation");
      }

      rewriter.replaceOp(genericOp, collapseResult->results);
      return success();
    }
    return failure();
  }

struct FoldReshapeWithGenericOpByCollapsing
    : public OpRewritePattern<tensor::CollapseShapeOp> {
  FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp, "producer not a generic op");
    }

    SmallVector<ReassociationIndices> collapsableIterationDims =
        getCollapsableIterationSpaceDims(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            reshapeOp.getReassociationIndices());
    if (collapsableIterationDims.empty()) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<CollapseResult> collapseResult =
        collapseOpIterationDims(producer, collapsableIterationDims, rewriter);
    if (!collapseResult) {
      return rewriter.notifyMatchFailure(
          producer, "failed to do the fusion by collapsing transformation");
    }

    rewriter.replaceOp(producer, collapseResult->results);
    return success();
  }

class FoldPadWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
                                           ControlFusionFn foldReshapes,
                                           PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::ExpandShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    for (auto reInd : reassociations) {
      if (reInd.size() == 1)
        continue;
      if (llvm::any_of(reInd, [&](int64_t ind) {
            return low[ind] != 0 || high[ind] != 0;
          }))
        return failure();
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType collapsedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
    SmallVector<OpFoldResult> expandedPaddedSizes(
        getMixedValues(reshapeOp.getStaticOutputShape(),
                       reshapeOp.getOutputShape(), rewriter));
    AffineExpr d0, d1, d2;
    bindDims(rewriter.getContext(), d0, d1, d2);
    auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
    Location loc = reshapeOp->getLoc();
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
      OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
      if (reInd.size() == 1) {
        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
            rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
        expandedPaddedSizes[reInd[0]] = paddedSize;
      }
      newLow.push_back(l);
      newHigh.push_back(h);
    }

    RankedTensorType collapsedPaddedType =
        paddedType.clone(collapsedPaddedShape);
    auto newPadOp = tensor::PadOp::create(
        rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
        expandedPaddedSizes);
    return success();
  }

template <typename LinalgType>
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
public:
  CollapseLinalgDimensions(MLIRContext *context,
                           GetCollapsableDimensionsFn collapseDimensions,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<LinalgType>(context, benefit),
        controlCollapseDimension(std::move(collapseDimensions)) {}

  LogicalResult matchAndRewrite(LinalgType op,
                                PatternRewriter &rewriter) const override {
    SmallVector<ReassociationIndices> collapsableIterationDims =
        controlCollapseDimension(op);
    if (collapsableIterationDims.empty())
      return failure();

    if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
                                  collapsableIterationDims)) {
      return rewriter.notifyMatchFailure(
          op, "specified dimensions cannot be collapsed");
    }

    std::optional<CollapseResult> collapseResult =
        collapseOpIterationDims(op, collapsableIterationDims, rewriter);
    if (!collapseResult) {
      return failure();
    }
    rewriter.replaceOp(op, collapseResult->results);
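// Fold a scalar or splat constant input into a generic op: the constant is
// materialized inside the payload as an arith.constant, and the operand and
// its indexing map are dropped from the fused op.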
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasPureTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      TypedAttr constantAttr;
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          DenseElementsAttr splatAttr;
          if (matchPattern(def, m_Constant(&splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue<TypedAttr>();
            return true;
          }
        }
        {
          IntegerAttr intAttr;
          if (matchPattern(def, m_Constant(&intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          FloatAttr floatAttr;
          if (matchPattern(def, m_Constant(&floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = dyn_cast<OpResult>(opOperand->get());
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
        continue;

      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(genericOp->getNumOperands());
      fusedOperands.reserve(genericOp.getNumDpsInputs());
      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
      for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(
            genericOp.getMatchingIndexingMap(inputOperand));
        fusedOperands.push_back(inputValue);
        fusedLocs.push_back(inputValue.getLoc());
      }
      for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
        fusedIndexMaps.push_back(
            genericOp.getMatchingIndexingMap(&outputOperand));

      if (!inversePermutation(
              concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      Value scalarConstant =
          arith::ConstantOp::create(rewriter, def->getLoc(), constantAttr);

      SmallVector<Value> outputOperands = genericOp.getOutputs();
      auto fusedOp =
          GenericOp::create(rewriter, rewriter.getFusedLoc(fusedLocs),
                            genericOp->getResultTypes(),
                            genericOp.getIteratorTypes(),

      Region &region = genericOp->getRegion(0);

      Region &fusedRegion = fusedOp->getRegion(0);

      rewriter.replaceOp(genericOp, fusedOp->getResults());
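// Break the dependence on an init operand whose value is never read by the
// payload: replace it with a fresh tensor.empty of the same shape so that
// producers of the old init can be DCE'd or fused independently.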
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.startOpModification(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    for (OpOperand &opOperand : op.getDpsInitsMutable()) {
      if (!op.payloadUsesValueFromOperand(&opOperand)) {
        Value operandVal = opOperand.get();
        auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
        if (!operandType)
          continue;

        auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
        if (definingOp)
          continue;
        modifiedOutput = true;
        SmallVector<OpFoldResult> mixedSizes =
            tensor::getMixedSizes(rewriter, loc, operandVal);
        Value emptyTensor = tensor::EmptyOp::create(
            rewriter, loc, mixedSizes, operandType.getElementType());
        op->setOperand(opOperand.getOperandNumber(), emptyTensor);
      }
    }
    if (!modifiedOutput) {
      rewriter.cancelOpModification(op);
      return failure();
    }
    rewriter.finalizeOpModification(op);
    return success();
  }
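// Fold a linalg.fill producer into a consuming generic op by replacing the
// corresponding payload block argument with the fill value (cast to the
// element type if necessary).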
struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasPureTensorSemantics())
      return failure();
    bool fillFound = false;
    Block &payload = genericOp.getRegion().front();
    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
      if (!genericOp.payloadUsesValueFromOperand(opOperand))
        continue;
      FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
      if (!fillOp)
        continue;
      fillFound = true;
      Value fillVal = fillOp.value();
      auto resultType =
          cast<RankedTensorType>(fillOp.result().getType()).getElementType();
      Value convertedVal =
          convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
                               /*isUnsignedCast=*/false);
      payload.getArgument(opOperand->getOperandNumber())
          .replaceAllUsesWith(convertedVal);
    }
    return success(fillFound);
  }
};

void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                    controlFoldingReshapes);
  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                        controlFoldingReshapes);
  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
}

void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
      patterns.getContext(), controlFoldingReshapes);
  patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
                                                     controlFoldingReshapes);
}

void mlir::linalg::populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlElementwiseOpsFusion) {
  auto *context = patterns.getContext();
  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
               RemoveOutsDependency>(context);
}

void mlir::linalg::populateCollapseDimensions(
    RewritePatternSet &patterns,
    const GetCollapsableDimensionsFn &controlCollapseDimensions) {
  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
               CollapseLinalgDimensions<linalg::CopyOp>>(
      patterns.getContext(), controlCollapseDimensions);
}

struct LinalgElementwiseOpFusionPass
    : public impl::LinalgElementwiseOpFusionPassBase<
          LinalgElementwiseOpFusionPass> {
  using impl::LinalgElementwiseOpFusionPassBase<
      LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    RewritePatternSet patterns(context);

    ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
      Operation *producer = fusedOperand->get().getDefiningOp();
      return producer && producer->hasOneUse();
    };

    affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
    GenericOp::getCanonicalizationPatterns(patterns, context);
    tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
    tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);