33 #define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
34 #include "mlir/Dialect/Linalg/Passes.h.inc"
60 assert(invProducerResultIndexMap &&
61 "expected producer result indexing map to be invertible");
63 LinalgOp producer = cast<LinalgOp>(producerOpOperand->
getOwner());
65 AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);
73 return t1.
compose(fusedConsumerArgIndexMap);
80 GenericOp producer, GenericOp consumer,
85 for (
auto &op : ops) {
86 for (
auto &opOperand : op->getOpOperands()) {
87 if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
90 indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
93 if (indexingMaps.empty()) {
96 return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
104 indexingMaps, producer.getContext())) !=
AffineMap();
113 GenericOp producer, GenericOp consumer,
OpOperand *fusedOperand) {
114 llvm::SmallDenseSet<int> preservedProducerResults;
118 opOperandsToIgnore.emplace_back(fusedOperand);
120 for (
const auto &producerResult :
llvm::enumerate(producer->getResults())) {
121 auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
122 opOperandsToIgnore.emplace_back(outputOperand);
123 if (producer.payloadUsesValueFromOperand(outputOperand) ||
125 opOperandsToIgnore) ||
126 llvm::any_of(producerResult.value().getUsers(), [&](
Operation *user) {
127 return user != consumer.getOperation();
129 preservedProducerResults.insert(producerResult.index());
132 (void)opOperandsToIgnore.pop_back_val();
135 return preservedProducerResults;
144 auto consumer = dyn_cast<GenericOp>(fusedOperand->
getOwner());
147 if (!producer || !consumer)
153 if (!producer.hasPureTensorSemantics() ||
154 !isa<RankedTensorType>(fusedOperand->
get().
getType()))
159 if (producer.getNumParallelLoops() != producer.getNumLoops())
164 if (!consumer.isDpsInput(fusedOperand))
169 AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
170 if (consumerIndexMap.
getNumResults() != producer.getNumLoops())
176 producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
184 if ((consumer.getNumReductionLoops())) {
185 BitVector coveredDims(consumer.getNumLoops(),
false);
187 auto addToCoveredDims = [&](
AffineMap map) {
188 for (
auto result : map.getResults())
189 if (
auto dimExpr = dyn_cast<AffineDimExpr>(result))
190 coveredDims[dimExpr.getPosition()] =
true;
194 llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
195 Value operand = std::get<0>(pair);
196 if (operand == fusedOperand->
get())
198 AffineMap operandMap = std::get<1>(pair);
199 addToCoveredDims(operandMap);
202 for (
OpOperand *operand : producer.getDpsInputOperands()) {
205 operand, producerResultIndexMap, consumerIndexMap);
206 addToCoveredDims(newIndexingMap);
208 if (!coveredDims.all())
220 unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
222 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
224 Block &producerBlock = producer->getRegion(0).
front();
225 Block &consumerBlock = consumer->getRegion(0).
front();
232 if (producer.hasIndexSemantics()) {
234 unsigned numFusedOpLoops = fusedOp.getNumLoops();
236 fusedIndices.reserve(numFusedOpLoops);
237 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
238 std::back_inserter(fusedIndices), [&](uint64_t dim) {
239 return rewriter.
create<IndexOp>(producer.getLoc(), dim);
241 for (IndexOp indexOp :
242 llvm::make_early_inc_range(producerBlock.
getOps<IndexOp>())) {
243 Value newIndex = rewriter.
create<affine::AffineApplyOp>(
245 consumerToProducerLoopsMap.
getSubMap(indexOp.getDim()), fusedIndices);
246 mapper.
map(indexOp.getResult(), newIndex);
250 assert(consumer.isDpsInput(fusedOperand) &&
251 "expected producer of input operand");
255 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
262 producerBlock.
getArguments().take_front(producer.getNumDpsInputs()))
263 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
268 .take_front(consumer.getNumDpsInputs())
270 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
274 producerBlock.
getArguments().take_back(producer.getNumDpsInits()))) {
275 if (!preservedProducerResults.count(bbArg.index()))
277 mapper.
map(bbArg.value(), fusedBlock->
addArgument(bbArg.value().getType(),
278 bbArg.value().getLoc()));
283 consumerBlock.
getArguments().take_back(consumer.getNumDpsInits()))
284 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
289 if (!isa<IndexOp>(op))
290 rewriter.
clone(op, mapper);
294 auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.
getTerminator());
295 unsigned producerResultNumber =
296 cast<OpResult>(fusedOperand->
get()).getResultNumber();
298 mapper.
lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
302 if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
303 if (
auto bb = dyn_cast<BlockArgument>(replacement))
304 assert(bb.getOwner() != &producerBlock &&
305 "yielded block argument must have been mapped");
308 "yielded value must have been mapped");
314 rewriter.
clone(op, mapper);
318 auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.
getTerminator());
320 fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
321 consumerYieldOp.getNumOperands());
322 for (
const auto &producerYieldVal :
324 if (preservedProducerResults.count(producerYieldVal.index()))
325 fusedYieldValues.push_back(
328 for (
auto consumerYieldVal : consumerYieldOp.getOperands())
330 rewriter.
create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
334 "Ill-formed GenericOp region");
337 FailureOr<mlir::linalg::ElementwiseOpFusionResult>
341 "expected elementwise operation pre-conditions to pass");
342 auto producerResult = cast<OpResult>(fusedOperand->
get());
343 auto producer = cast<GenericOp>(producerResult.getOwner());
344 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
346 assert(consumer.isDpsInput(fusedOperand) &&
347 "expected producer of input operand");
350 llvm::SmallDenseSet<int> preservedProducerResults =
358 fusedInputOperands.reserve(producer.getNumDpsInputs() +
359 consumer.getNumDpsInputs());
360 fusedOutputOperands.reserve(preservedProducerResults.size() +
361 consumer.getNumDpsInits());
362 fusedResultTypes.reserve(preservedProducerResults.size() +
363 consumer.getNumDpsInits());
364 fusedIndexMaps.reserve(producer->getNumOperands() +
365 consumer->getNumOperands());
368 auto consumerInputs = consumer.getDpsInputOperands();
369 auto *it = llvm::find_if(consumerInputs, [&](
OpOperand *operand) {
370 return operand == fusedOperand;
372 assert(it != consumerInputs.end() &&
"expected to find the consumer operand");
373 for (
OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
374 fusedInputOperands.push_back(opOperand->get());
375 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
379 producer.getIndexingMapMatchingResult(producerResult);
380 for (
OpOperand *opOperand : producer.getDpsInputOperands()) {
381 fusedInputOperands.push_back(opOperand->get());
384 opOperand, producerResultIndexMap,
385 consumer.getMatchingIndexingMap(fusedOperand));
386 fusedIndexMaps.push_back(map);
391 llvm::make_range(std::next(it), consumerInputs.end())) {
392 fusedInputOperands.push_back(opOperand->get());
393 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
397 for (
const auto &opOperand :
llvm::enumerate(producer.getDpsInitsMutable())) {
398 if (!preservedProducerResults.count(opOperand.index()))
401 fusedOutputOperands.push_back(opOperand.value().get());
403 &opOperand.value(), producerResultIndexMap,
404 consumer.getMatchingIndexingMap(fusedOperand));
405 fusedIndexMaps.push_back(map);
406 fusedResultTypes.push_back(opOperand.value().get().getType());
410 for (
OpOperand &opOperand : consumer.getDpsInitsMutable()) {
411 fusedOutputOperands.push_back(opOperand.get());
412 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
413 Type resultType = opOperand.get().getType();
414 if (!isa<MemRefType>(resultType))
415 fusedResultTypes.push_back(resultType);
419 auto fusedOp = rewriter.
create<GenericOp>(
420 consumer.getLoc(), fusedResultTypes, fusedInputOperands,
422 consumer.getIteratorTypes(),
425 if (!fusedOp.getShapesToLoopsMap()) {
431 fusedOp,
"fused op failed loop bound computation check");
437 consumer.getMatchingIndexingMap(fusedOperand);
441 assert(invProducerResultIndexMap &&
442 "expected producer result indexing map to be invertible");
445 invProducerResultIndexMap.
compose(consumerResultIndexMap);
448 rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
449 consumer.getNumLoops(), preservedProducerResults);
453 for (
auto [index, producerResult] :
llvm::enumerate(producer->getResults()))
454 if (preservedProducerResults.count(index))
455 result.
replacements[producerResult] = fusedOp->getResult(resultNum++);
456 for (
auto consumerResult : consumer->getResults())
457 result.
replacements[consumerResult] = fusedOp->getResult(resultNum++);
468 controlFn(std::move(fun)) {}
470 LogicalResult matchAndRewrite(GenericOp genericOp,
473 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
476 if (!controlFn(&opOperand))
479 Operation *producer = opOperand.get().getDefiningOp();
482 FailureOr<ElementwiseOpFusionResult> fusionResult =
484 if (failed(fusionResult))
488 for (
auto [origVal, replacement] : fusionResult->replacements) {
491 return use.
get().getDefiningOp() != producer;
570 linalgOp.getIteratorTypesArray();
571 AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
572 return linalgOp.hasPureTensorSemantics() &&
573 llvm::all_of(linalgOp.getIndexingMaps().getValue(),
575 return cast<AffineMapAttr>(attr)
577 .isProjectedPermutation();
585 class ExpansionInfo {
591 LogicalResult compute(LinalgOp linalgOp,
OpOperand *fusableOpOperand,
// Returns the number of iteration (loop) dimensions of the original,
// pre-expansion op. `reassociation` is sized to the original rank by
// compute() — presumably one folded-dim group per original dim; confirm
// against ExpansionInfo::compute.
595 unsigned getOrigOpNumDims()
const {
return reassociation.size(); }
// Returns the number of iteration (loop) dimensions of the op after
// expansion. Set by compute() as the running sum of the expanded sizes of
// each folded dimension group.
596 unsigned getExpandedOpNumDims()
const {
return expandedOpNumDims; }
598 return reassociation[i];
601 return expandedShapeMap[i];
614 unsigned expandedOpNumDims;
618 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
623 if (reassociationMaps.empty())
625 AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
629 originalLoopExtent = llvm::map_to_vector(
630 linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
631 [](
Range r) { return r.size; });
633 reassociation.clear();
634 expandedShapeMap.clear();
638 expandedShapeMap.resize(fusedIndexMap.
getNumDims());
640 unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
641 AffineMap foldedDims = reassociationMaps[resultExpr.index()];
644 expandedShape.slice(foldedDims.
getDimPosition(0), numExpandedDims[pos]);
645 expandedShapeMap[pos].assign(shape.begin(), shape.end());
648 for (
unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.
getNumDims()))
649 if (expandedShapeMap[i].empty())
650 expandedShapeMap[i] = {originalLoopExtent[i]};
654 reassociation.reserve(fusedIndexMap.
getNumDims());
656 auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
657 reassociation.emplace_back(seq.begin(), seq.end());
658 sum += numFoldedDim.value();
660 expandedOpNumDims = sum;
668 const ExpansionInfo &expansionInfo) {
671 unsigned pos = cast<AffineDimExpr>(expr).getPosition();
673 llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
674 return builder.getAffineDimExpr(static_cast<unsigned>(v));
676 newExprs.append(expandedExprs.begin(), expandedExprs.end());
685 static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
687 const ExpansionInfo &expansionInfo) {
690 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
692 expansionInfo.getExpandedShapeOfDim(dim);
693 expandedShape.append(dimExpansion.begin(), dimExpansion.end());
696 std::tie(expandedStaticShape, std::ignore) =
699 originalType.getElementType())};
710 const ExpansionInfo &expansionInfo) {
712 unsigned numReshapeDims = 0;
714 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
715 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
717 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
718 reassociation.emplace_back(std::move(indices));
719 numReshapeDims += numExpandedDims;
721 return reassociation;
731 const ExpansionInfo &expansionInfo) {
733 for (IndexOp indexOp :
734 llvm::make_early_inc_range(fusedRegion.
front().
getOps<IndexOp>())) {
736 expansionInfo.getExpandedDims(indexOp.getDim());
737 assert(!expandedDims.empty() &&
"expected valid expansion info");
740 if (expandedDims.size() == 1 &&
741 expandedDims.front() == (int64_t)indexOp.getDim())
748 expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
750 expandedIndices.reserve(expandedDims.size() - 1);
752 expandedDims.drop_front(), std::back_inserter(expandedIndices),
753 [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
755 rewriter.
create<IndexOp>(loc, expandedDims.front()).getResult();
756 for (
auto [expandedShape, expandedIndex] :
757 llvm::zip(expandedDimsShape, expandedIndices)) {
762 rewriter, indexOp.getLoc(), idx + acc * shape,
767 rewriter.
replaceOp(indexOp, newIndexVal);
789 TransposeOp transposeOp,
791 ExpansionInfo &expansionInfo) {
794 auto reassoc = expansionInfo.getExpandedDims(perm);
795 for (int64_t dim : reassoc) {
796 newPerm.push_back(dim);
799 return rewriter.
create<TransposeOp>(transposeOp.getLoc(), expandedInput,
810 expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
812 for (
auto [i, type] :
llvm::enumerate(linalgOp.getIteratorTypesArray()))
813 for (
auto j : expansionInfo.getExpandedDims(i))
814 iteratorTypes[
j] = type;
817 linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
818 expandedOpIndexingMaps, iteratorTypes);
821 Region &originalRegion = linalgOp->getRegion(0);
839 ExpansionInfo &expansionInfo) {
842 .Case<TransposeOp>([&](TransposeOp transposeOp) {
844 expandedOpOperands[0], outputs[0],
847 .Case<FillOp, CopyOp>([&](
Operation *op) {
848 return clone(rewriter, linalgOp, resultTypes,
849 llvm::to_vector(llvm::concat<Value>(
850 llvm::to_vector(expandedOpOperands),
851 llvm::to_vector(outputs))));
855 expandedOpOperands, outputs,
856 expansionInfo, expandedOpIndexingMaps);
863 static std::optional<SmallVector<Value>>
868 "preconditions for fuse operation failed");
874 if (
auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
878 rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
881 expandedShape = expandingReshapeOp.getMixedOutputShape();
882 reassociationIndices = expandingReshapeOp.getReassociationMaps();
883 src = expandingReshapeOp.getSrc();
885 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
886 if (!collapsingReshapeOp)
890 rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
891 reassociationIndices = collapsingReshapeOp.getReassociationMaps();
892 src = collapsingReshapeOp.getSrc();
895 ExpansionInfo expansionInfo;
896 if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
897 reassociationIndices, expandedShape,
902 llvm::map_range(linalgOp.getIndexingMapsArray(), [&](
AffineMap m) {
903 return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
911 expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
912 for (
OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
913 if (opOperand == fusableOpOperand) {
914 expandedOpOperands.push_back(src);
917 if (
auto opOperandType =
918 dyn_cast<RankedTensorType>(opOperand->get().getType())) {
919 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
921 RankedTensorType expandedOperandType;
922 std::tie(expandedOperandShape, expandedOperandType) =
924 if (expandedOperandType != opOperand->get().getType()) {
929 [&](
const Twine &msg) {
932 opOperandType.getShape(), expandedOperandType.getShape(),
936 expandedOpOperands.push_back(rewriter.
create<tensor::ExpandShapeOp>(
937 loc, expandedOperandType, opOperand->get(), reassociation,
938 expandedOperandShape));
942 expandedOpOperands.push_back(opOperand->get());
946 for (
OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
947 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
948 auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
950 RankedTensorType expandedOutputType;
951 std::tie(expandedOutputShape, expandedOutputType) =
953 if (expandedOutputType != opOperand.get().getType()) {
957 [&](
const Twine &msg) {
960 opOperandType.getShape(), expandedOutputType.getShape(),
964 outputs.push_back(rewriter.
create<tensor::ExpandShapeOp>(
965 loc, expandedOutputType, opOperand.get(), reassociation,
966 expandedOutputShape));
968 outputs.push_back(opOperand.get());
975 outputs, expandedOpIndexingMaps, expansionInfo);
979 for (
OpResult opResult : linalgOp->getOpResults()) {
980 int64_t resultNumber = opResult.getResultNumber();
981 if (resultTypes[resultNumber] != opResult.getType()) {
984 linalgOp.getMatchingIndexingMap(
985 linalgOp.getDpsInitOperand(resultNumber)),
987 resultVals.push_back(rewriter.
create<tensor::CollapseShapeOp>(
988 linalgOp.getLoc(), opResult.getType(),
989 fusedOp->
getResult(resultNumber), reassociation));
991 resultVals.push_back(fusedOp->
getResult(resultNumber));
1003 class FoldWithProducerReshapeOpByExpansion
1006 FoldWithProducerReshapeOpByExpansion(
MLIRContext *context,
1010 controlFoldingReshapes(std::move(foldReshapes)) {}
1012 LogicalResult matchAndRewrite(LinalgOp linalgOp,
1014 for (
OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
1015 tensor::CollapseShapeOp reshapeOp =
1016 opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
1023 (!controlFoldingReshapes(opOperand)))
1026 std::optional<SmallVector<Value>> replacementValues =
1028 if (!replacementValues)
1030 rewriter.
replaceOp(linalgOp, *replacementValues);
1040 class FoldPadWithProducerReshapeOpByExpansion
1043 FoldPadWithProducerReshapeOpByExpansion(
MLIRContext *context,
1047 controlFoldingReshapes(std::move(foldReshapes)) {}
1049 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1051 tensor::CollapseShapeOp reshapeOp =
1052 padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
1055 if (!reshapeOp->hasOneUse())
1058 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1060 "fusion blocked by control function");
1066 reshapeOp.getReassociationIndices();
1068 for (
auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
1069 if (reInd.size() != 1 && (l != 0 || h != 0))
1074 RankedTensorType expandedType = reshapeOp.getSrcType();
1075 RankedTensorType paddedType = padOp.getResultType();
1078 if (reInd.size() == 1) {
1079 expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
1081 for (
size_t i = 0; i < reInd.size(); ++i) {
1082 newLow.push_back(padOp.getMixedLowPad()[idx]);
1083 newHigh.push_back(padOp.getMixedHighPad()[idx]);
1088 RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
1089 auto newPadOp = rewriter.
create<tensor::PadOp>(
1090 loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1091 padOp.getConstantPaddingValue(), padOp.getNofold());
1094 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
1105 struct FoldReshapeWithGenericOpByExpansion
1108 FoldReshapeWithGenericOpByExpansion(
MLIRContext *context,
1112 controlFoldingReshapes(std::move(foldReshapes)) {}
1114 LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
1117 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1118 if (!producerResult) {
1120 "source not produced by an operation");
1123 auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
1126 "producer not a generic op");
1131 producer.getDpsInitOperand(producerResult.getResultNumber()))) {
1133 reshapeOp,
"failed preconditions of fusion with producer generic op");
1136 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1138 "fusion blocked by control function");
1141 std::optional<SmallVector<Value>> replacementValues =
1143 producer, reshapeOp,
1144 producer.getDpsInitOperand(producerResult.getResultNumber()),
1146 if (!replacementValues) {
1148 "fusion by expansion failed");
1155 Value reshapeReplacement =
1156 (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
1157 .getResultNumber()];
1158 if (
auto collapseOp =
1159 reshapeReplacement.
getDefiningOp<tensor::CollapseShapeOp>()) {
1160 reshapeReplacement = collapseOp.getSrc();
1162 rewriter.
replaceOp(reshapeOp, reshapeReplacement);
1163 rewriter.
replaceOp(producer, *replacementValues);
1185 "expected projected permutation");
1188 llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
1189 return cast<AffineDimExpr>(indexingMap.
getResults()[pos]).getPosition();
1193 return domainReassociation;
1201 assert(!dimSequence.empty() &&
1202 "expected non-empty list for dimension sequence");
1204 "expected indexing map to be projected permutation");
1206 llvm::SmallDenseSet<unsigned, 4> sequenceElements;
1207 sequenceElements.insert_range(dimSequence);
1209 unsigned dimSequenceStart = dimSequence[0];
1211 unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
1213 if (dimInMapStart == dimSequenceStart) {
1214 if (expr.index() + dimSequence.size() > indexingMap.
getNumResults())
1217 for (
const auto &dimInSequence :
enumerate(dimSequence)) {
1219 cast<AffineDimExpr>(
1220 indexingMap.
getResult(expr.index() + dimInSequence.index()))
1222 if (dimInMap != dimInSequence.value())
1233 if (sequenceElements.count(dimInMapStart))
1242 return llvm::all_of(maps, [&](
AffineMap map) {
1299 if (!genericOp.hasPureTensorSemantics())
1302 if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](
AffineMap map) {
1303 return map.isProjectedPermutation();
1310 genericOp.getReductionDims(reductionDims);
1312 llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1313 AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1314 auto iteratorTypes = genericOp.getIteratorTypesArray();
1317 assert(!foldedRangeDims.empty() &&
"unexpected empty reassociation");
1320 if (foldedRangeDims.size() == 1)
1328 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1329 return processedIterationDims.count(dim);
1334 utils::IteratorType startIteratorType =
1335 iteratorTypes[foldedIterationSpaceDims[0]];
1339 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1340 return iteratorTypes[dim] != startIteratorType;
1349 bool isContiguous =
false;
1352 if (startDim.value() != foldedIterationSpaceDims[0])
1356 if (startDim.index() + foldedIterationSpaceDims.size() >
1357 reductionDims.size())
1360 isContiguous =
true;
1361 for (
const auto &foldedDim :
1363 if (reductionDims[foldedDim.index() + startDim.index()] !=
1364 foldedDim.value()) {
1365 isContiguous =
false;
1376 if (llvm::any_of(genericOp.getIndexingMapsArray(),
1378 return !isDimSequencePreserved(indexingMap,
1379 foldedIterationSpaceDims);
1383 processedIterationDims.insert_range(foldedIterationSpaceDims);
1384 iterationSpaceReassociation.emplace_back(
1385 std::move(foldedIterationSpaceDims));
1388 return iterationSpaceReassociation;
1393 class CollapsingInfo {
1395 LogicalResult initialize(
unsigned origNumLoops,
1397 llvm::SmallDenseSet<int64_t, 4> processedDims;
1400 if (foldedIterationDim.empty())
1404 for (
auto dim : foldedIterationDim) {
1405 if (dim >= origNumLoops)
1407 if (processedDims.count(dim))
1409 processedDims.insert(dim);
1411 collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1412 foldedIterationDim.end());
1414 if (processedDims.size() > origNumLoops)
1419 for (
auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1420 if (processedDims.count(dim))
1425 llvm::sort(collapsedOpToOrigOpIterationDim,
1427 return lhs[0] < rhs[0];
1429 origOpToCollapsedOpIterationDim.resize(origNumLoops);
1430 for (
const auto &foldedDims :
1432 for (
const auto &dim :
enumerate(foldedDims.value()))
1433 origOpToCollapsedOpIterationDim[dim.value()] =
1434 std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1441 return collapsedOpToOrigOpIterationDim;
1465 return origOpToCollapsedOpIterationDim;
1469 unsigned getCollapsedOpIterationRank()
const {
1470 return collapsedOpToOrigOpIterationDim.size();
1488 const CollapsingInfo &collapsingInfo) {
1491 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1492 assert(!foldedIterDims.empty() &&
1493 "reassociation indices expected to have non-empty sets");
1497 collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1499 return collapsedIteratorTypes;
1506 const CollapsingInfo &collapsingInfo) {
1509 "expected indexing map to be projected permutation");
1511 auto origOpToCollapsedOpMapping =
1512 collapsingInfo.getOrigOpToCollapsedOpMapping();
1514 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1516 if (origOpToCollapsedOpMapping[dim].second != 0)
1520 resultExprs.push_back(
1523 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1524 resultExprs, context);
1531 const CollapsingInfo &collapsingInfo) {
1532 unsigned counter = 0;
1534 auto origOpToCollapsedOpMapping =
1535 collapsingInfo.getOrigOpToCollapsedOpMapping();
1536 auto collapsedOpToOrigOpMapping =
1537 collapsingInfo.getCollapsedOpToOrigOpMapping();
1540 cast<AffineDimExpr>(indexingMap.
getResult(counter)).getPosition();
1544 unsigned numFoldedDims =
1545 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1547 if (origOpToCollapsedOpMapping[dim].second == 0) {
1548 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1549 operandReassociation.emplace_back(range.begin(), range.end());
1551 counter += numFoldedDims;
1553 return operandReassociation;
1559 const CollapsingInfo &collapsingInfo,
1561 AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1569 if (operandReassociation.size() == indexingMap.
getNumResults())
1573 if (isa<MemRefType>(operand.
getType())) {
1575 .
create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
1579 .
create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
1586 Location loc,
Block *block,
const CollapsingInfo &collapsingInfo,
1592 auto indexOps = llvm::to_vector(block->
getOps<linalg::IndexOp>());
1602 for (
auto foldedDims :
1603 enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1606 rewriter.
create<linalg::IndexOp>(loc, foldedDims.index());
1607 for (
auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1610 indexReplacementVals[dim] =
1611 rewriter.
createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
1613 rewriter.
createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
1615 indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1618 for (
auto indexOp : indexOps) {
1619 auto dim = indexOp.getDim();
1620 rewriter.
replaceOp(indexOp, indexReplacementVals[dim]);
1625 const CollapsingInfo &collapsingInfo,
1632 llvm::map_to_vector(op.getDpsInputOperands(), [&](
OpOperand *opOperand) {
1633 return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1638 resultTypes.reserve(op.getNumDpsInits());
1639 outputOperands.reserve(op.getNumDpsInits());
1640 for (
OpOperand &output : op.getDpsInitsMutable()) {
1643 outputOperands.push_back(newOutput);
1646 if (!op.hasPureBufferSemantics())
1647 resultTypes.push_back(newOutput.
getType());
1652 template <
typename OpTy>
1654 const CollapsingInfo &collapsingInfo) {
1662 const CollapsingInfo &collapsingInfo) {
1666 outputOperands, resultTypes);
1669 rewriter, origOp, resultTypes,
1670 llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1677 const CollapsingInfo &collapsingInfo) {
1681 outputOperands, resultTypes);
1683 llvm::map_range(origOp.getIndexingMapsArray(), [&](
AffineMap map) {
1684 return getCollapsedOpIndexingMap(map, collapsingInfo);
1688 origOp.getIteratorTypesArray(), collapsingInfo));
1690 GenericOp collapsedOp = rewriter.
create<linalg::GenericOp>(
1691 origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
1693 Block *origOpBlock = &origOp->getRegion(0).
front();
1694 Block *collapsedOpBlock = &collapsedOp->getRegion(0).
front();
1695 rewriter.
mergeBlocks(origOpBlock, collapsedOpBlock,
1702 if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1714 if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1716 return foldedDims.size() <= 1;
1720 bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1721 if (hasPureBufferSemantics &&
1722 !llvm::all_of(op->getOperands(), [&](
Value operand) ->
bool {
1723 MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
1724 if (!memRefToCollapse)
1727 return memref::CollapseShapeOp::isGuaranteedCollapsible(
1728 memRefToCollapse, foldedIterationDims);
1731 "memref is not guaranteed collapsible");
1733 CollapsingInfo collapsingInfo;
1735 collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1737 op,
"illegal to collapse specified dimensions");
1742 auto opFoldIsConstantValue = [](
OpFoldResult ofr, int64_t value) {
1743 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1744 return cast<IntegerAttr>(attr).getInt() == value;
1747 actual.getSExtValue() == value;
1749 if (!llvm::all_of(loopRanges, [&](
Range range) {
1750 return opFoldIsConstantValue(range.
offset, 0) &&
1751 opFoldIsConstantValue(range.
stride, 1);
1754 op,
"expected all loop ranges to have zero start and unit stride");
1761 llvm::map_to_vector(loopRanges, [](
Range range) {
return range.
size; });
1763 if (collapsedOp.hasIndexSemantics()) {
1768 collapsingInfo, loopBound, rewriter);
1774 for (
const auto &originalResult :
llvm::enumerate(op->getResults())) {
1775 Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1776 auto originalResultType =
1777 cast<ShapedType>(originalResult.value().getType());
1778 auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.
getType());
1779 if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1781 op.getIndexingMapMatchingResult(originalResult.value());
1786 "Expected indexing map to be a projected permutation for collapsing");
1790 if (isa<MemRefType>(collapsedOpResult.
getType())) {
1792 originalResultType.getShape(), originalResultType.getElementType());
1793 result = rewriter.
create<memref::ExpandShapeOp>(
1794 loc, expandShapeResultType, collapsedOpResult, reassociation,
1797 result = rewriter.
create<tensor::ExpandShapeOp>(
1798 loc, originalResultType, collapsedOpResult, reassociation,
1801 results.push_back(result);
1803 results.push_back(collapsedOpResult);
1813 class FoldWithProducerReshapeOpByCollapsing
1817 FoldWithProducerReshapeOpByCollapsing(
MLIRContext *context,
1821 controlFoldingReshapes(std::move(foldReshapes)) {}
1823 LogicalResult matchAndRewrite(GenericOp genericOp,
1825 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
1826 tensor::ExpandShapeOp reshapeOp =
1827 opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1833 reshapeOp.getReassociationIndices());
1834 if (collapsableIterationDims.empty() ||
1835 !controlFoldingReshapes(&opOperand)) {
1840 genericOp, collapsableIterationDims, rewriter);
1841 if (!collapseResult) {
1843 genericOp,
"failed to do the fusion by collapsing transformation");
1846 rewriter.
replaceOp(genericOp, collapseResult->results);
1858 struct FoldReshapeWithGenericOpByCollapsing
1861 FoldReshapeWithGenericOpByCollapsing(
MLIRContext *context,
1865 controlFoldingReshapes(std::move(foldReshapes)) {}
1867 LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
1871 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1872 if (!producerResult) {
1874 "source not produced by an operation");
1878 auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
1881 "producer not a generic op");
1887 producer.getDpsInitOperand(producerResult.getResultNumber()),
1888 reshapeOp.getReassociationIndices());
1889 if (collapsableIterationDims.empty()) {
1891 reshapeOp,
"failed preconditions of fusion with producer generic op");
1894 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1896 "fusion blocked by control function");
1902 std::optional<CollapseResult> collapseResult =
1904 if (!collapseResult) {
1906 producer,
"failed to do the fusion by collapsing transformation");
1909 rewriter.
replaceOp(producer, collapseResult->results);
1917 class FoldPadWithProducerReshapeOpByCollapsing
1920 FoldPadWithProducerReshapeOpByCollapsing(
MLIRContext *context,
1924 controlFoldingReshapes(std::move(foldReshapes)) {}
1926 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1928 tensor::ExpandShapeOp reshapeOp =
1929 padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
1932 if (!reshapeOp->hasOneUse())
1935 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1937 "fusion blocked by control function");
1943 reshapeOp.getReassociationIndices();
1945 for (
auto reInd : reassociations) {
1946 if (reInd.size() == 1)
1948 if (llvm::any_of(reInd, [&](int64_t ind) {
1949 return low[ind] != 0 || high[ind] != 0;
1956 RankedTensorType collapsedType = reshapeOp.getSrcType();
1957 RankedTensorType paddedType = padOp.getResultType();
1961 reshapeOp.getOutputShape(), rewriter));
1965 Location loc = reshapeOp->getLoc();
1969 if (reInd.size() == 1) {
1970 collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
1972 rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
1973 expandedPaddedSizes[reInd[0]] = paddedSize;
1975 newLow.push_back(l);
1976 newHigh.push_back(h);
1979 RankedTensorType collapsedPaddedType =
1980 paddedType.clone(collapsedPaddedShape);
1981 auto newPadOp = rewriter.
create<tensor::PadOp>(
1982 loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1983 padOp.getConstantPaddingValue(), padOp.getNofold());
1986 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
1987 expandedPaddedSizes);
1997 template <
typename LinalgType>
2004 controlCollapseDimension(std::move(collapseDimensions)) {}
2006 LogicalResult matchAndRewrite(LinalgType op,
2009 controlCollapseDimension(op);
2010 if (collapsableIterationDims.empty())
2015 collapsableIterationDims)) {
2017 op,
"specified dimensions cannot be collapsed");
2020 std::optional<CollapseResult> collapseResult =
2022 if (!collapseResult) {
2025 rewriter.replaceOp(op, collapseResult->results);
2047 LogicalResult matchAndRewrite(GenericOp genericOp,
2049 if (!genericOp.hasPureTensorSemantics())
2051 for (
OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2052 Operation *def = opOperand->get().getDefiningOp();
2053 TypedAttr constantAttr;
2054 auto isScalarOrSplatConstantOp = [&constantAttr](
Operation *def) ->
bool {
2057 if (
matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
2059 splatAttr.
getType().getElementType().isIntOrFloat()) {
2065 IntegerAttr intAttr;
2066 if (
matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
2067 constantAttr = intAttr;
2072 FloatAttr floatAttr;
2073 if (
matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
2074 constantAttr = floatAttr;
2081 auto resultValue = dyn_cast<OpResult>(opOperand->get());
2082 if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
2091 fusedIndexMaps.reserve(genericOp->getNumOperands());
2092 fusedOperands.reserve(genericOp.getNumDpsInputs());
2093 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
2094 for (
OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
2095 if (inputOperand == opOperand)
2097 Value inputValue = inputOperand->get();
2098 fusedIndexMaps.push_back(
2099 genericOp.getMatchingIndexingMap(inputOperand));
2100 fusedOperands.push_back(inputValue);
2101 fusedLocs.push_back(inputValue.
getLoc());
2103 for (
OpOperand &outputOperand : genericOp.getDpsInitsMutable())
2104 fusedIndexMaps.push_back(
2105 genericOp.getMatchingIndexingMap(&outputOperand));
2111 genericOp,
"fused op loop bound computation failed");
2115 Value scalarConstant =
2116 rewriter.
create<arith::ConstantOp>(def->
getLoc(), constantAttr);
2119 auto fusedOp = rewriter.
create<GenericOp>(
2120 rewriter.
getFusedLoc(fusedLocs), genericOp->getResultTypes(),
2124 genericOp.getIteratorTypes(),
2130 Region &region = genericOp->getRegion(0);
2133 mapping.
map(entryBlock.
getArgument(opOperand->getOperandNumber()),
2135 Region &fusedRegion = fusedOp->getRegion(0);
2138 rewriter.replaceOp(genericOp, fusedOp->getResults());
2159 LogicalResult matchAndRewrite(GenericOp op,
2162 bool modifiedOutput =
false;
2164 for (
OpOperand &opOperand : op.getDpsInitsMutable()) {
2165 if (!op.payloadUsesValueFromOperand(&opOperand)) {
2166 Value operandVal = opOperand.get();
2167 auto operandType = dyn_cast<RankedTensorType>(operandVal.
getType());
2176 auto definingOp = operandVal.
getDefiningOp<tensor::EmptyOp>();
2179 modifiedOutput =
true;
2182 Value emptyTensor = rewriter.
create<tensor::EmptyOp>(
2183 loc, mixedSizes, operandType.getElementType());
2184 op->
setOperand(opOperand.getOperandNumber(), emptyTensor);
2187 if (!modifiedOutput) {
2200 LogicalResult matchAndRewrite(GenericOp genericOp,
2202 if (!genericOp.hasPureTensorSemantics())
2204 bool fillFound =
false;
2205 Block &payload = genericOp.getRegion().
front();
2206 for (
OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2207 if (!genericOp.payloadUsesValueFromOperand(opOperand))
2209 FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2213 Value fillVal = fillOp.value();
2215 cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2216 Value convertedVal =
2220 payload.
getArgument(opOperand->getOperandNumber()), convertedVal);
2222 return success(fillFound);
2231 controlFoldingReshapes);
2232 patterns.add<FoldPadWithProducerReshapeOpByExpansion>(
patterns.getContext(),
2233 controlFoldingReshapes);
2235 controlFoldingReshapes);
2242 controlFoldingReshapes);
2243 patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2244 patterns.getContext(), controlFoldingReshapes);
2246 controlFoldingReshapes);
2252 auto *context =
patterns.getContext();
2253 patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2254 patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2255 RemoveOutsDependency>(context);
2263 patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2264 CollapseLinalgDimensions<linalg::CopyOp>>(
2265 patterns.getContext(), controlCollapseDimensions);
2280 struct LinalgElementwiseOpFusionPass
2281 :
public impl::LinalgElementwiseOpFusionPassBase<
2282 LinalgElementwiseOpFusionPass> {
2283 using impl::LinalgElementwiseOpFusionPassBase<
2284 LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2285 void runOnOperation()
override {
2292 Operation *producer = fusedOperand->get().getDefiningOp();
2293 return producer && producer->
hasOneUse();
2302 affine::AffineApplyOp::getCanonicalizationPatterns(
patterns, context);
2303 GenericOp::getCanonicalizationPatterns(
patterns, context);
2304 tensor::ExpandShapeOp::getCanonicalizationPatterns(
patterns, context);
2305 tensor::CollapseShapeOp::getCanonicalizationPatterns(
patterns, context);
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(GenericOp producer, GenericOp consumer, ArrayRef< OpOperand * > opOperandsToIgnore)
OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp, const CollapsingInfo &collapsingInfo)
Clone a LinalgOp to a collapsed version of same name.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(OpOperand *producerOpOperand, AffineMap producerResultIndexMap, AffineMap fusedConsumerArgIndexMap)
Append to fusedOpIndexingMapAttrs the indexing maps for the operands of the producer to use in the fu...
static SmallVector< ReassociationIndices > getOperandReassociation(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Return the reassociation indices to use to collapse the operand when the iteration space of a generic...
static Operation * createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp, Value expandedInput, Value output, ExpansionInfo &expansionInfo)
static std::tuple< SmallVector< OpFoldResult >, RankedTensorType > getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the shape and type of the operand/result to use in the expanded op given the type in the origi...
static SmallVector< utils::IteratorType > getCollapsedOpIteratorTypes(ArrayRef< utils::IteratorType > iteratorTypes, const CollapsingInfo &collapsingInfo)
Get the iterator types for the collapsed operation given the original iterator types and collapsed di...
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, Location loc, Region &fusedRegion, const ExpansionInfo &expansionInfo)
Update the body of an expanded linalg operation having index semantics.
static void generateCollapsedIndexingRegion(Location loc, Block *block, const CollapsingInfo &collapsingInfo, ArrayRef< OpFoldResult > loopRange, RewriterBase &rewriter)
Modify the linalg.index operations in the original generic op, to its value in the collapsed operatio...
static Operation * createExpandedGenericOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > &expandedOpOperands, ArrayRef< Value > outputs, ExpansionInfo &expansionInfo, ArrayRef< AffineMap > expandedOpIndexingMaps)
GenericOp cloneToCollapsedOp< GenericOp >(RewriterBase &rewriter, GenericOp origOp, const CollapsingInfo &collapsingInfo)
Collapse a GenericOp
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp, OpOperand *fusableOpOperand)
Conditions for folding a structured linalg operation with a reshape op by expanding the iteration spa...
void collapseOperandsAndResults(LinalgOp op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter, SmallVectorImpl< Value > &inputOperands, SmallVectorImpl< Value > &outputOperands, SmallVectorImpl< Type > &resultTypes)
static Operation * createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > expandedOpOperands, ArrayRef< Value > outputs, ArrayRef< AffineMap > expandedOpIndexingMaps, ExpansionInfo &expansionInfo)
static ReassociationIndices getDomainReassociation(AffineMap indexingMap, ReassociationIndicesRef rangeReassociation)
For a given list of indices in the range of the indexingMap that are folded, return the indices of th...
static SmallVector< ReassociationIndices > getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, ArrayRef< ReassociationIndices > reassociation)
static std::optional< SmallVector< Value > > fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, OpOperand *fusableOpOperand, PatternRewriter &rewriter)
Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op and a generic op as expl...
static void generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand, unsigned nloops, llvm::SmallDenseSet< int > &preservedProducerResults)
Generate the region of the fused tensor operation.
static SmallVector< ReassociationIndices > getReassociationForExpansion(AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Returns the reassociation maps to use in the tensor.expand_shape operation to convert the operands of...
static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Compute the indexing map in the collapsed op that corresponds to the given indexingMap of the origina...
LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter)
LinalgOp cloneToCollapsedOp< LinalgOp >(RewriterBase &rewriter, LinalgOp origOp, const CollapsingInfo &collapsingInfo)
Collapse any LinalgOp that does not require any specialization such as indexing_maps,...
static AffineMap getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the indexing map to use in the expanded op for a given the indexingMap of the original operati...
static Value getCollapsedOpOperand(Location loc, LinalgOp op, OpOperand *opOperand, const CollapsingInfo &collapsingInfo, OpBuilder &builder)
Get the new value to use for a given OpOperand in the collapsed operation.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteConfig & setUseTopDownTraversal(bool use=true)
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool areDimSequencesPreserved(ArrayRef< AffineMap > maps, ArrayRef< ReassociationIndices > dimSequences)
Return true if all sequences of dimensions specified in dimSequences are contiguous in all the ranges...
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence)
Return true if a given sequence of dimensions are contiguous in the range of the specified indexing m...
void populateFoldReshapeOpsByCollapsingPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding tensor.expand_shape operation with its producer generic operation by co...
FailureOr< ElementwiseOpFusionResult > fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand)
llvm::SmallDenseSet< int > getPreservedProducerResults(GenericOp producer, GenericOp consumer, OpOperand *fusedOperand)
Returns a set of indices of the producer's results which would be preserved after the fusion.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
void populateCollapseDimensions(RewritePatternSet &patterns, const GetCollapsableDimensionsFn &controlCollapseDimensions)
Pattern to collapse dimensions in a linalg.generic op.
bool areElementwiseOpsFusable(OpOperand *fusedOperand)
Return true if two linalg.generic operations with producer/consumer relationship through fusedOperand...
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
std::function< SmallVector< ReassociationIndices >(linalg::LinalgOp)> GetCollapsableDimensionsFn
Function type to control generic op dimension collapsing.
void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding (collapsing) tensor_reshape operation with its producer (consumer) gene...
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void populateElementwiseOpsFusionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion)
Patterns for fusing linalg operation on tensors.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify that shapes of the reshaped types using following rule: if a dimension in the collapsed type i...
ArrayRef< int64_t > ReassociationIndicesRef
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, Operation *insertionPoint, DominanceInfo &dominance)
Move definitions of values before an insertion point.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Fuse two linalg.generic operations that have a producer-consumer relationship captured through fusedO...
llvm::DenseMap< Value, Value > replacements
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.