#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
/// Compute the indexing map for an operand of the producer in the coordinates
/// of the fused operation.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with the inverse to get a map from producer result tensor
  // index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Compose t1 with the consumer's indexing map for the fused operand to get
  // a map from consumer (fused) loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}
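// A minimal standalone sketch (illustrative, not part of this file) of the
// composition chain above with hand-built maps. `composeExample` and the
// chosen maps are hypothetical; the APIs (bindDims, AffineMap::get,
// inversePermutation, compose) are the real MLIR ones.
static void composeExample(MLIRContext *ctx) {
  AffineExpr d0, d1;
  bindDims(ctx, d0, d1);
  // producerResultIndexMap: (d0, d1) -> (d1, d0), a permutation.
  AffineMap producerResultIndexMap = AffineMap::get(2, 0, {d1, d0}, ctx);
  // argMap: identity (d0, d1) -> (d0, d1).
  AffineMap argMap = AffineMap::get(2, 0, {d0, d1}, ctx);
  // tensor index -> producer loop: (d0, d1) -> (d1, d0).
  AffineMap inv = inversePermutation(producerResultIndexMap);
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(inv); // (d0, d1) -> (d1, d0)
  // If the consumer arg map happens to be the same swap, the fused-loop map
  // composes back to the identity.
  AffineMap fused = t1.compose(producerResultIndexMap); // (d0, d1) -> (d0, d1)
  (void)fused;
}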
/// Check whether the given operand can be dropped, i.e. whether the remaining
/// operands of the fused producer & consumer can still compute the loop
/// bounds of the op.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {
  SmallVector<AffineMap> indexingMaps;

  SmallVector<GenericOp> ops = {producer, consumer};
  for (auto &op : ops) {
    for (auto &opOperand : op->getOpOperands()) {
      if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
        continue;
      }
      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
    }
  }

  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped if
    // neither op has loops.
    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
  }

  // The concatenated indexing maps must be invertible, so the bounds of the
  // op can still be computed after dropping the selected operand.
  // inversePermutation returns an empty AffineMap when the concatenated
  // indexing maps are not invertible.
  return inversePermutation(concatAffineMaps(
             indexingMaps, producer.getContext())) != AffineMap();
}
/// Returns a set of indices of the producer's results which would be
/// preserved after the fusion.
llvm::SmallDenseSet<int>
mlir::linalg::getPreservedProducerResults(GenericOp producer,
                                          GenericOp consumer,
                                          OpOperand *fusedOperand) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  llvm::SmallVector<OpOperand *> opOperandsToIgnore;

  // The fused operand will be removed during the fusion.
  opOperandsToIgnore.emplace_back(fusedOperand);

  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    opOperandsToIgnore.emplace_back(outputOperand);
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
                                                  opOperandsToIgnore) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());

      // The operand cannot be dropped after all.
      (void)opOperandsToIgnore.pop_back_val();
    }
  }
  return preservedProducerResults;
}
/// Return true if two `linalg.generic` operations with a producer/consumer
/// relationship through `fusedOperand` can be fused using elementwise op
/// fusion.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // The producer must have pure tensor semantics and the fused operand must
  // be a ranked tensor.
  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // The producer must have all-"parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // The number of results of the consumer indexing map must match the number
  // of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // The producer result indexing map must be invertible; for now just verify
  // it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to get
  // the loop bounds: for reductions, after the fusion each loop dimension
  // must still be covered by at least one remaining operand.
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}
/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  // Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<affine::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // Consumer input operands up to `fusedOperand` (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Splice in the producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Remaining consumer input operands (past `fusedOperand`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // All of the producer's output operands that are preserved.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // All of the consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Clone all producer operations except for the yield and index operations.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // Map the consumer block argument for `fusedOperand`: just forward the
  // cloned, yielded producer value.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity check: if the replacement is not already in the mapper, it must be
  // produced outside the producer block.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // Include the final yield, which carries the remapped values of all yield
  // operands.
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}
FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // Find the results of the producer that have uses outside of the consumer
  // after the fusion.
  llvm::SmallDenseSet<int> preservedProducerResults =
      mlir::linalg::getPreservedProducerResults(producer, consumer,
                                                fusedOperand);

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // Consumer input operands/maps up to `fusedOperand` (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // Splice in the producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // Remaining consumer input operands/maps (past `fusedOperand`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // Collect all of the producer outputs that are preserved.
  for (const auto &opOperand :
       llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // All of the consumer's output operands.
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr, /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // The fused op has invalid indexing maps. Going ahead would result in
    // verification errors, so clean up and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index.
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop.
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}
/// Pattern to fuse a `linalg.generic` op with the producer of one of its
/// operands, subject to a caller-provided control function.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      Operation *producer = opOperand.get().getDefiningOp();

      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");

      // Perform the replacement, keeping the producer's own operands intact.
      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          // Only replace consumer uses.
          return use.get().getDefiningOp() != producer;
        });
      }
      // ...
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
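// Sketch of a caller-supplied control function (illustrative, not part of
// this file): restrict fusion to cases where producer and consumer live in
// the same block. The variable name is hypothetical.
ControlFusionFn sameBlockControlFn = [](OpOperand *fusedOperand) {
  Operation *producer = fusedOperand->get().getDefiningOp();
  Operation *consumer = fusedOperand->getOwner();
  return producer && producer->getBlock() == consumer->getBlock();
};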
/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - All the loops for the reshaped operand are parallel loops.
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
           return isParallelIterator(
               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
         });
}
/// Information needed to expand a generic operation to fold a reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from the dimensions of the original op to the
  // dimensions of the expanded op, given the `indexingMap` of the fused
  // operand/result, the reassociation maps of the reshape op, and the shapes
  // of the expanded and collapsed types.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        ArrayRef<int64_t> collapsedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }

private:
  /// Reassociation from the dims in the original operation to the dims of the
  /// expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from the extent of loops in the original operation to the extent
  /// of loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Extent of the loops in the original operation.
  SmallVector<int64_t> originalLoopExtent;
  unsigned expandedOpNumDims;
};
LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return rewriter.notifyMatchFailure(
        linalgOp, "expected reassociation maps to not be empty");

  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  SmallVector<int64_t> originalLoopRange = linalgOp.getStaticLoopRanges();
  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute the reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}
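// Worked example (illustrative, not from the source): for an op with
// iteration space (d0, d1) where the reshape expands the operand dimension
// indexed by d1 into two dimensions of static sizes 4 and 8,
// ExpansionInfo::compute produces
//   reassociation     = [[0], [1, 2]]   // d0 stays; d1 expands to (d1, d2)
//   expandedShapeMap  = [{e0}, {4, 8}]  // e0 = original extent of d0
//   expandedOpNumDims = 3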
/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices: accesses of indices in the original operation need
/// to be replaced with linearizations of indices in the expanded op, which
/// requires the expanded shapes (at least all but the most significant
/// dimension) to be static.
static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
                                          const ExpansionInfo &expansionInfo,
                                          PatternRewriter &rewriter) {
  if (!linalgOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            linalgOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}
/// Return the indexing map to use in the expanded op for a given `indexingMap`
/// of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}
/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
static RankedTensorType getExpandedType(RankedTensorType originalType,
                                        AffineMap indexingMap,
                                        const ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  return RankedTensorType::get(expandedShape, originalType.getElementType());
}
/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation.
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}
/// Update the body of an expanded linalg operation having index semantics.
/// The indices of the original operation are recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<affine::AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}
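// The linearization above follows Horner's rule: a dimension expanded into
// indices (i0, i1, i2) with trailing static sizes (s1, s2) is rebuilt as
//   ((i0 * s1) + i1) * s2 + i2
// one affine.apply of the form (idx + acc * size) at a time.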
/// Checks if a single dynamic dimension expanded into multiple dynamic
/// dimensions.
static LogicalResult
validateDynamicDimExpansion(LinalgOp linalgOp,
                            const ExpansionInfo &expansionInfo,
                            PatternRewriter &rewriter) {
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    bool foundDynamic = false;
    for (int64_t shape : expandedShape) {
      if (!ShapedType::isDynamic(shape))
        continue;
      if (foundDynamic) {
        return rewriter.notifyMatchFailure(
            linalgOp, "cannot infer expanded shape with multiple dynamic "
                      "dims in the same reassociation group");
      }
      foundDynamic = true;
    }
  }
  return success();
}
/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape
/// op and a generic op, as allowed by `isFusableWithReshapeByDimExpansion`.
/// Assumes those preconditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");

  Location loc = linalgOp.getLoc();
  // Check if the reshape is expanding or collapsing.
  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
  bool isExpanding = (expandingReshapeOp != nullptr);
  RankedTensorType expandedType = isExpanding
                                      ? expandingReshapeOp.getResultType()
                                      : collapsingReshapeOp.getSrcType();
  RankedTensorType collapsedType = isExpanding
                                       ? expandingReshapeOp.getSrcType()
                                       : collapsingReshapeOp.getResultType();

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(
          linalgOp, fusableOpOperand,
          isExpanding ? expandingReshapeOp.getReassociationMaps()
                      : collapsingReshapeOp.getReassociationMaps(),
          expandedType.getShape(), collapsedType.getShape(), rewriter)))
    return std::nullopt;

  if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
    return std::nullopt;
  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Set the insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
                                               : collapsingReshapeOp.getSrc());
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
      RankedTensorType expandedOperandType =
          getExpandedType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(linalgOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
            loc, expandedOperandType, opOperand->get(), reassociation));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  SmallVector<Value> outputs;
  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
    RankedTensorType expandedOutputType =
        getExpandedType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(linalgOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
          loc, expandedOutputType, opOperand.get(), reassociation));
    } else {
      outputs.push_back(opOperand.get());
    }
  }

  // The iterator types of the expanded op default to parallel and are then
  // updated from the original iterator types.
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
    for (auto j : expansionInfo.getExpandedDims(i))
      iteratorTypes[j] = type;

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  auto fusedOp =
      rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
                                 expandedOpOperands, outputs,
                                 expandedOpIndexingMaps, iteratorTypes);
  Region &fusedRegion = fusedOp->getRegion(0);
  Region &originalRegion = linalgOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);

  // Reshape the result values back to their original shape if this is a
  // collapsing reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : linalgOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              linalgOp.getMatchingIndexingMap(
                  linalgOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
          linalgOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  return resultVals;
}
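// Illustrative effect of the transformation (a sketch, not a test case):
//   %c = tensor.collapse_shape %a [[0, 1]] : tensor<4x8xf32> into tensor<32xf32>
//   %r = linalg.generic ... ins(%c : tensor<32xf32>) ...
// becomes a rank-2 linalg.generic directly on %a; other operands are wrapped
// in tensor.expand_shape ops and rank-changed results are collapsed back with
// tensor.collapse_shape.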
/// Pattern to fold a tensor.collapse_shape op with its consumer structured op,
/// by expanding the dimensionality of the loop in the consumer op.
class FoldWithProducerReshapeOpByExpansion
    : public OpInterfaceRewritePattern<LinalgOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if all constraints of fusing with reshape by expansion are
      // met and the control function allows it.
      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(linalgOp, *replacementValues);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
/// Pattern to fold a tensor.pad op with its producer tensor.collapse_shape op
/// by expanding the pad.
class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::CollapseShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
      if (reInd.size() != 1 && (l != 0 || h != 0))
        return failure();
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType expandedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      if (reInd.size() == 1) {
        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
      }
      for (size_t i = 0; i < reInd.size(); ++i) {
        newLow.push_back(padOp.getMixedLowPad()[idx]);
        newHigh.push_back(padOp.getMixedHighPad()[idx]);
      }
    }

    Location loc = padOp->getLoc();
    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
/// Pattern to fold a tensor.expand_shape op with its producer generic op by
/// expanding the dimensionality of the loops of the producer.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {
  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are
    // met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp, "producer not a generic op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    // Find the replacement for the reshape op. Since the replacements have
    // the same type as the results of the original generic op, the consumer
    // reshape op can be replaced by the source of the collapse_shape op that
    // defines the replacement.
    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain.
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
      }));
  // The projected permutation semantics ensure that there is no repetition of
  // the domain indices.
  return domainReassociation;
}
/// Return true if a given sequence of dimensions is contiguous in the range
/// of the specified indexing map.
bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
                                          ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert(dimSequence.begin(), dimSequence.end());

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if the sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            cast<AffineDimExpr>(
                indexingMap.getResult(expr.index() + dimInSequence.index()))
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence; the projected-permutation property guarantees the
      // dim expressions are unique, so no further checks are needed.
      return true;
    }
    // 2. If the position is part of the sequence but was not its start, the
    // entire sequence cannot exist contiguously in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. The sequence does not appear in the indexing map at all.
  return true;
}
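// Examples (following the definition above): for the map
//   (d0, d1, d2) -> (d2, d0, d1)
// the sequence [0, 1] is preserved, since d0 and d1 appear contiguously and
// in order in the results. For (d0, d1, d2) -> (d1, d0, d2) it is not, since
// d0 is followed by d2 rather than d1.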
/// Return true if all sequences of dimensions specified in `dimSequences` are
/// contiguous in all the ranges of the `maps`.
bool mlir::linalg::areDimSequencesPreserved(
    ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
  return llvm::all_of(maps, [&](AffineMap map) {
    return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
      return isDimSequencePreserved(map, dimSequence);
    });
  });
}
/// For a `GenericOp`, determine the dimensions of the iteration space that
/// can be collapsed when fusing along `fusableOperand` with a reshape whose
/// reassociation is `reassociation`.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
  // Some basic checks for this fusion to be valid.
  if (!genericOp.hasPureTensorSemantics())
    return {};

  if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
        return map.isProjectedPermutation();
      })) {
    return {};
  }

  // Compute all the loops with the reduction iterator types.
  SmallVector<unsigned> reductionDims;
  genericOp.getReductionDims(reductionDims);

  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
  auto iteratorTypes = genericOp.getIteratorTypesArray();
  SmallVector<ReassociationIndices> iterationSpaceReassociation;
  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");

    // Ignore dims that are not folded.
    if (foldedRangeDims.size() == 1)
      continue;

    ReassociationIndices foldedIterationSpaceDims =
        getDomainReassociation(indexingMap, foldedRangeDims);

    // Check that the folded iteration dims do not contain already processed
    // dims.
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return processedIterationDims.count(dim);
        }))
      return {};

    // Check that all folded iterator types are the same.
    utils::IteratorType startIteratorType =
        iteratorTypes[foldedIterationSpaceDims[0]];
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return iteratorTypes[dim] != startIteratorType;
        }))
      return {};

    // If the folded dimensions correspond to a "reduction" iterator type,
    // they additionally need to be "in-order".
    if (isReductionIterator(startIteratorType)) {
      bool isContiguous = false;
      for (const auto &startDim : llvm::enumerate(reductionDims)) {
        // Move the window in `reductionDims` to the start of the folded
        // iteration dims.
        if (startDim.value() != foldedIterationSpaceDims[0])
          continue;
        // If the sizes do not match, the dims are trivially not contiguous.
        if (startDim.index() + foldedIterationSpaceDims.size() >
            reductionDims.size())
          break;
        // Check that contiguity is maintained.
        isContiguous = true;
        for (const auto &foldedDim :
             llvm::enumerate(foldedIterationSpaceDims)) {
          if (reductionDims[foldedDim.index() + startDim.index()] !=
              foldedDim.value()) {
            isContiguous = false;
            break;
          }
        }
        break;
      }
      if (!isContiguous)
        return {};
    }

    // Check that the sequence is preserved in all indexing maps.
    if (llvm::any_of(genericOp.getIndexingMapsArray(),
                     [&](AffineMap indexingMap) {
                       return !isDimSequencePreserved(indexingMap,
                                                      foldedIterationSpaceDims);
                     }))
      return {};

    processedIterationDims.insert(foldedIterationSpaceDims.begin(),
                                  foldedIterationSpaceDims.end());
    iterationSpaceReassociation.emplace_back(
        std::move(foldedIterationSpaceDims));
  }

  return iterationSpaceReassociation;
}
/// Helper class to carry state while collapsing the iteration space of a
/// `linalg.generic` op.
class CollapsingInfo {
public:
  LogicalResult initialize(unsigned origNumLoops,
                           ArrayRef<ReassociationIndices> foldedIterationDims) {
    llvm::SmallDenseSet<int64_t, 4> processedDims;
    // 1. Find all the dims that are folded.
    for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
      if (foldedIterationDim.empty())
        continue;
      // Out-of-range dims and repetitions are an illegal specification.
      for (auto dim : foldedIterationDim) {
        if (dim >= origNumLoops)
          return failure();
        if (processedDims.count(dim))
          return failure();
        processedDims.insert(dim);
      }
      collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
                                                   foldedIterationDim.end());
    }
    if (processedDims.size() > origNumLoops)
      return failure();

    // 2. Fill in the remaining dims: each non-folded loop stays in its own
    // group of one.
    for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
      if (processedDims.count(dim))
        continue;
      collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
    }

    llvm::sort(collapsedOpToOrigOpIterationDim,
               [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
                 return lhs[0] < rhs[0];
               });
    origOpToCollapsedOpIterationDim.resize(origNumLoops);
    for (const auto &foldedDims :
         llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
      for (const auto &dim : enumerate(foldedDims.value()))
        origOpToCollapsedOpIterationDim[dim.value()] =
            std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
    }
    return success();
  }

  /// Return the mapping from collapsed loop domain to original loop domain.
  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
    return collapsedOpToOrigOpIterationDim;
  }

  /// Return the mapping from original loop domain to collapsed loop domain.
  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
    return origOpToCollapsedOpIterationDim;
  }

  /// Return the number of loops in the collapsed op.
  unsigned getCollapsedOpIterationRank() const {
    return collapsedOpToOrigOpIterationDim.size();
  }

private:
  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
};
/// Get the iterator types for the collapsed operation given the original
/// iterator types and the collapsed dimensions.
static SmallVector<utils::IteratorType>
getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
                            const CollapsingInfo &collapsingInfo) {
  SmallVector<utils::IteratorType> collapsedIteratorTypes;
  for (ReassociationIndicesRef foldedIterDims :
       collapsingInfo.getCollapsedOpToOrigOpMapping()) {
    assert(!foldedIterDims.empty() &&
           "reassociation indices expected to have non-empty sets");
    // Just pick the iterator type of the first folded dim: all dims in a
    // group are required to have the same iterator type.
    collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
  }
  return collapsedIteratorTypes;
}
/// Compute the indexing map in the collapsed op that corresponds to the given
/// `indexingMap` of the original operation.
static AffineMap
getCollapsedOpIndexingMap(AffineMap indexingMap,
                          const CollapsingInfo &collapsingInfo) {
  MLIRContext *context = indexingMap.getContext();
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");
  SmallVector<AffineExpr> resultExprs;
  auto origOpToCollapsedOpMapping =
      collapsingInfo.getOrigOpToCollapsedOpMapping();
  for (auto expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    // If the dim is not the first of a collapsed group, do nothing.
    if (origOpToCollapsedOpMapping[dim].second != 0)
      continue;
    // The first dim of each group is replaced by the collapsed dim.
    resultExprs.push_back(
        getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
  }
  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
                        resultExprs, context);
}
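// Worked example (illustrative): collapsing a 4-d iteration space with groups
// [[0, 1], [2, 3]] rewrites the original map
//   (d0, d1, d2, d3) -> (d2, d3, d0, d1)
// into the collapsed map
//   (d0, d1) -> (d1, d0)
// since only the leading dim of each folded group survives.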
/// Return the reassociation indices to use to collapse the operand when the
/// iteration space of a generic op is collapsed.
static SmallVector<ReassociationIndices>
getOperandReassociation(AffineMap indexingMap,
                        const CollapsingInfo &collapsingInfo) {
  unsigned counter = 0;
  SmallVector<ReassociationIndices> operandReassociation;
  auto origOpToCollapsedOpMapping =
      collapsingInfo.getOrigOpToCollapsedOpMapping();
  auto collapsedOpToOrigOpMapping =
      collapsingInfo.getCollapsedOpToOrigOpMapping();
  while (counter < indexingMap.getNumResults()) {
    unsigned dim =
        cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
    // This is the start of a collapsed group of iteration dims that is
    // guaranteed to be preserved in the indexing map. The number of folded
    // dims is obtained from the collapsed-op-to-original-op mapping.
    unsigned numFoldedDims =
        collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
            .size();
    if (origOpToCollapsedOpMapping[dim].second == 0) {
      auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
      operandReassociation.emplace_back(range.begin(), range.end());
    }
    counter += numFoldedDims;
  }
  return operandReassociation;
}
/// Get the new value to use for a given `OpOperand` in the collapsed
/// operation.
static Value getCollapsedOpOperand(Location loc, LinalgOp op,
                                   OpOperand *opOperand,
                                   const CollapsingInfo &collapsingInfo,
                                   OpBuilder &builder) {
  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
  SmallVector<ReassociationIndices> operandReassociation =
      getOperandReassociation(indexingMap, collapsingInfo);

  // If the number of entries in the reassociation equals the number of
  // results of the indexing map, there is nothing to do for this operand.
  Value operand = opOperand->get();
  if (operandReassociation.size() == indexingMap.getNumResults())
    return operand;

  // Insert a reshape to collapse the dimensions.
  if (isa<MemRefType>(operand.getType())) {
    return builder
        .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
        .getResult();
  }
  return builder
      .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
      .getResult();
}
/// Modify the `linalg.index` operations in the original generic op to their
/// values in the collapsed operation.
void generateCollapsedIndexingRegion(Location loc, Block *block,
                                     const CollapsingInfo &collapsingInfo,
                                     ValueRange loopRange,
                                     RewriterBase &rewriter) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointToStart(block);

  // Collect all the original index ops.
  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());

  // For each folded dimension list, resolve the original induction variable
  // values in terms of the folded dimension induction variable:
  //   i_{folded} = (i_0 * d1 + i_1) * d2 + i_2
  // can be inverted to
  //   i_2 = i_{folded} % d2
  //   i_1 = (i_{folded} / d2) % d1
  //   i_0 = i_{folded} / (d1 * d2)
  llvm::DenseMap<unsigned, Value> indexReplacementVals;
  for (auto foldedDims :
       enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
    ReassociationIndicesRef foldedDimsRef(foldedDims.value());
    Value newIndexVal =
        rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
    for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
      indexReplacementVals[dim] =
          rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
      newIndexVal =
          rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
    }
    indexReplacementVals[foldedDims.value().front()] = newIndexVal;
  }

  for (auto indexOp : indexOps) {
    auto dim = indexOp.getDim();
    rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
  }
}
/// Compute the collapsed input/output operands and result types of the
/// collapsed op.
void collapseOperandsAndResults(LinalgOp op,
                                const CollapsingInfo &collapsingInfo,
                                RewriterBase &rewriter,
                                SmallVectorImpl<Value> &inputOperands,
                                SmallVectorImpl<Value> &outputOperands,
                                SmallVectorImpl<Type> &resultTypes) {
  Location loc = op->getLoc();
  inputOperands =
      llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
        return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
                                     rewriter);
      });

  // Get the output operands and result types.
  resultTypes.reserve(op.getNumDpsInits());
  outputOperands.reserve(op.getNumDpsInits());
  for (OpOperand &output : op.getDpsInitsMutable()) {
    Value newOutput =
        getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
    outputOperands.push_back(newOutput);
    // If the op has pure buffer semantics, the init operands are memrefs and
    // the op has no results.
    if (!op.hasPureBufferSemantics())
      resultTypes.push_back(newOutput.getType());
  }
}
/// Clone a `LinalgOp` to a collapsed version of the same name.
template <typename OpTy>
OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
                        const CollapsingInfo &collapsingInfo) {
  return nullptr;
}

/// Collapse any `LinalgOp` that does not require any specialization such as
/// indexing_maps, iterator_types, etc.
template <>
LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
                                      const CollapsingInfo &collapsingInfo) {
  SmallVector<Value> inputOperands, outputOperands;
  SmallVector<Type> resultTypes;
  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
                             outputOperands, resultTypes);
  return clone(
      rewriter, origOp, resultTypes,
      llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
}
/// Collapse a `GenericOp`.
template <>
GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
                                        GenericOp origOp,
                                        const CollapsingInfo &collapsingInfo) {
  SmallVector<Value> inputOperands, outputOperands;
  SmallVector<Type> resultTypes;
  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
                             outputOperands, resultTypes);
  SmallVector<AffineMap> indexingMaps(
      llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
        return getCollapsedOpIndexingMap(map, collapsingInfo);
      }));

  SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
      origOp.getIteratorTypesArray(), collapsingInfo));

  GenericOp collapsedOp = rewriter.create<linalg::GenericOp>(
      origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
  Block *origOpBlock = &origOp->getRegion(0).front();
  Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
                       collapsedOpBlock->getArguments());
  return collapsedOp;
}
LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
                           RewriterBase &rewriter) {
  if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
    return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
  } else {
    return cloneToCollapsedOp(rewriter, op, collapsingInfo);
  }
}

/// Collapses dimensions of a linalg.generic/linalg.copy operation.
FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
    LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
    RewriterBase &rewriter) {
  // Bail on trivial no-op cases.
  if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
      llvm::all_of(foldedIterationDims,
                   [](ReassociationIndicesRef foldedDims) {
                     return foldedDims.size() <= 1;
                   }))
    return failure();

  bool hasPureBufferSemantics = op.hasPureBufferSemantics();
  if (hasPureBufferSemantics &&
      !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
        MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
        if (!memRefToCollapse)
          return true;

        return memref::CollapseShapeOp::isGuaranteedCollapsible(
            memRefToCollapse, foldedIterationDims);
      }))
    return rewriter.notifyMatchFailure(op,
                                       "memref is not guaranteed collapsible");

  CollapsingInfo collapsingInfo;
  if (failed(
          collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
    return rewriter.notifyMatchFailure(
        op, "illegal to collapse specified dimensions");
  }

  // Bail on non-canonical ranges.
  SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
  auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
      return cast<IntegerAttr>(attr).getInt() == value;
    llvm::APInt actual;
    return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) &&
           actual.getSExtValue() == value;
  };
  if (!llvm::all_of(loopRanges, [&](Range range) {
        return opFoldIsConstantValue(range.offset, 0) &&
               opFoldIsConstantValue(range.stride, 1);
      })) {
    return rewriter.notifyMatchFailure(
        op, "expected all loop ranges to have zero start and unit stride");
  }

  LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);

  Location loc = op->getLoc();
  if (collapsedOp.hasIndexSemantics()) {
    // Collect the loop ranges of the generic op.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(collapsedOp);
    SmallVector<Value> loopBound =
        llvm::map_to_vector(loopRanges, [&](Range range) {
          return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
        });
    generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
                                    collapsingInfo, loopBound, rewriter);
  }

  // Insert expanding reshapes for the results to get back the original result
  // types.
  SmallVector<Value> results;
  for (const auto &originalResult : llvm::enumerate(op->getResults())) {
    Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
    auto originalResultType =
        cast<ShapedType>(originalResult.value().getType());
    auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
    if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
      AffineMap indexingMap =
          op.getIndexingMapMatchingResult(originalResult.value());
      SmallVector<ReassociationIndices> reassociation =
          getOperandReassociation(indexingMap, collapsingInfo);
      Value result;
      if (isa<MemRefType>(collapsedOpResult.getType())) {
        MemRefType expandShapeResultType = MemRefType::get(
            originalResultType.getShape(),
            originalResultType.getElementType());
        result = rewriter.create<memref::ExpandShapeOp>(
            loc, expandShapeResultType, collapsedOpResult, reassociation);
      } else {
        result = rewriter.create<tensor::ExpandShapeOp>(
            loc, originalResultType, collapsedOpResult, reassociation);
      }
      results.push_back(result);
    } else {
      results.push_back(collapsedOpResult);
    }
  }
  return CollapseResult{results, collapsedOp};
}
/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
/// collapsing the dimensions of the generic op.
class FoldWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<GenericOp> {
public:
  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
                                        ControlFusionFn foldReshapes,
                                        PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      tensor::ExpandShapeOp reshapeOp =
          opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
      if (!reshapeOp)
        continue;

      SmallVector<ReassociationIndices> collapsableIterationDims =
          getCollapsableIterationSpaceDims(genericOp, &opOperand,
                                           reshapeOp.getReassociationIndices());
      if (collapsableIterationDims.empty() ||
          !controlFoldingReshapes(&opOperand)) {
        continue;
      }

      auto collapseResult = collapseOpIterationDims(
          genericOp, collapsableIterationDims, rewriter);
      if (!collapseResult) {
        return rewriter.notifyMatchFailure(
            genericOp, "failed to do the fusion by collapsing transformation");
      }

      rewriter.replaceOp(genericOp, collapseResult->results);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
/// Pattern to fold a tensor.pad op with its producer tensor.expand_shape op
/// by collapsing the pad.
class FoldPadWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
                                           ControlFusionFn foldReshapes,
                                           PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::ExpandShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    for (auto reInd : reassociations) {
      if (reInd.size() == 1)
        continue;
      if (llvm::any_of(reInd, [&](int64_t ind) {
            return low[ind] != 0 || high[ind] != 0;
          }))
        return failure();
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType collapsedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
    SmallVector<OpFoldResult> expandedPaddedSizes(
        getMixedValues(reshapeOp.getStaticOutputShape(),
                       reshapeOp.getOutputShape(), rewriter));
    AffineExpr d0, d1, d2;
    bindDims(rewriter.getContext(), d0, d1, d2);
    auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2}, rewriter.getContext());
    Location loc = reshapeOp->getLoc();
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      OpFoldResult l = padOp.getMixedLowPad()[idx];
      OpFoldResult h = padOp.getMixedHighPad()[idx];
      if (reInd.size() == 1) {
        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
            rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
        expandedPaddedSizes[reInd[0]] = paddedSize;
      }
      newLow.push_back(l);
      newHigh.push_back(h);
    }

    RankedTensorType collapsedPaddedType =
        paddedType.clone(collapsedPaddedShape);
    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
        expandedPaddedSizes);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
/// Pattern to collapse dimensions of a linalg op as specified by a control
/// function.
template <typename LinalgType>
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
public:
  CollapseLinalgDimensions(MLIRContext *context,
                           GetCollapsableDimensionsFn collapseDimensions,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<LinalgType>(context, benefit),
        controlCollapseDimension(std::move(collapseDimensions)) {}

  LogicalResult matchAndRewrite(LinalgType op,
                                PatternRewriter &rewriter) const override {
    SmallVector<ReassociationIndices> collapsableIterationDims =
        controlCollapseDimension(op);
    if (collapsableIterationDims.empty())
      return failure();

    // Check if the specified list of dimensions to collapse is a valid list.
    if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
                                  collapsableIterationDims)) {
      return rewriter.notifyMatchFailure(
          op, "specified dimensions cannot be collapsed");
    }

    std::optional<CollapseResult> collapseResult =
        collapseOpIterationDims(op, collapsableIterationDims, rewriter);
    if (!collapseResult) {
      return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
    }
    rewriter.replaceOp(op, collapseResult->results);
    return success();
  }

private:
  GetCollapsableDimensionsFn controlCollapseDimension;
};
/// Pattern to fold a generic op with a splat constant or scalar constant
/// input operand by inlining the constant into the payload.
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasPureTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      TypedAttr constantAttr;
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          DenseElementsAttr splatAttr;
          if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue<TypedAttr>();
            return true;
          }
        }
        {
          IntegerAttr intAttr;
          if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          FloatAttr floatAttr;
          if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = dyn_cast<OpResult>(opOperand->get());
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
        continue;

      // The operands and the indexing_maps of the fused operation are the
      // same as those of the generic operation, with the constant operand
      // dropped.
      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(genericOp->getNumOperands());
      fusedOperands.reserve(genericOp.getNumDpsInputs());
      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
      for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(
            genericOp.getMatchingIndexingMap(inputOperand));
        fusedOperands.push_back(inputValue);
        fusedLocs.push_back(inputValue.getLoc());
      }
      for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
        fusedIndexMaps.push_back(
            genericOp.getMatchingIndexingMap(&outputOperand));

      // Check if the operation's shapes-to-loops map is computable.
      if (!inversePermutation(
              concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Create a constant scalar value from the splat constant.
      Value scalarConstant =
          rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);

      auto fusedOp = rewriter.create<GenericOp>(
          rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
          /*inputs=*/fusedOperands,
          /*outputs=*/genericOp.getOutputs(),
          rewriter.getAffineMapArrayAttr(fusedIndexMaps),
          genericOp.getIteratorTypes(),
          /*doc=*/nullptr, /*library_call=*/nullptr);

      // Map the block argument corresponding to the replaced argument to the
      // scalar constant and clone the region.
      Region &region = genericOp->getRegion(0);
      Block &entryBlock = *region.begin();
      IRMapping mapping;
      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
                  scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(0);
      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                 mapping);
      rewriter.replaceOp(genericOp, fusedOp->getResults());
      return success();
    }
    return failure();
  }
};
/// Forces `outs` operands of linalg operations to use `tensor.empty` if the
/// value of the `outs` operand is not used within the payload, breaking the
/// dependence on the producer of the operand.
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.startOpModification(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    for (OpOperand &opOperand : op.getDpsInitsMutable()) {
      if (!op.payloadUsesValueFromOperand(&opOperand)) {
        Value operandVal = opOperand.get();
        auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
        if (!operandType)
          continue;

        // If the outs operand is already a `tensor.empty`, nothing to do.
        auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
        if (definingOp)
          continue;
        modifiedOutput = true;
        SmallVector<OpFoldResult> mixedSizes =
            tensor::getMixedSizes(rewriter, loc, operandVal);
        Value emptyTensor = rewriter.create<tensor::EmptyOp>(
            loc, mixedSizes, operandType.getElementType());
        op->setOperand(opOperand.getOperandNumber(), emptyTensor);
      }
    }
    if (!modifiedOutput) {
      rewriter.cancelOpModification(op);
      return failure();
    }
    rewriter.finalizeOpModification(op);
    return success();
  }
};
/// Fold `linalg.fill` into `linalg.generic` by forwarding the fill value to
/// the payload arguments fed by the fill.
struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasPureTensorSemantics())
      return failure();
    bool fillFound = false;
    Block &payload = genericOp.getRegion().front();
    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
      if (!genericOp.payloadUsesValueFromOperand(opOperand))
        continue;
      FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
      if (!fillOp)
        continue;
      fillFound = true;
      Value fillVal = fillOp.value();
      auto resultType =
          cast<RankedTensorType>(fillOp.result().getType()).getElementType();
      Value convertedVal =
          convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
                               /*isUnsignedCast=*/false);
      rewriter.replaceAllUsesWith(
          payload.getArgument(opOperand->getOperandNumber()), convertedVal);
    }
    return success(fillFound);
  }
};
void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                    controlFoldingReshapes);
  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                        controlFoldingReshapes);
  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
}

void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
      patterns.getContext(), controlFoldingReshapes);
}
void mlir::linalg::populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlElementwiseOpsFusion) {
  auto *context = patterns.getContext();
  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
               RemoveOutsDependency>(context);
  // Add the patterns that clean up dead operands and results.
  populateEraseUnusedOperandsAndResultsPatterns(patterns);
}
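// Minimal usage sketch (commented out; assumes a pass context with `op`,
// `context`, and `signalPassFailure` in scope, mirroring the pass below and
// its default single-use-producer control function):
//
//   RewritePatternSet patterns(context);
//   ControlFusionFn control = [](OpOperand *fusedOperand) {
//     Operation *producer = fusedOperand->get().getDefiningOp();
//     return producer && producer->hasOneUse();
//   };
//   linalg::populateElementwiseOpsFusionPatterns(patterns, control);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();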
void mlir::linalg::populateCollapseDimensions(
    RewritePatternSet &patterns,
    const GetCollapsableDimensionsFn &controlCollapseDimensions) {
  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
               CollapseLinalgDimensions<linalg::CopyOp>>(
      patterns.getContext(), controlCollapseDimensions);
}
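// Illustrative GetCollapsableDimensionsFn (a sketch, not part of this file;
// the variable name is hypothetical): request collapsing all loops of each op
// into a single dimension. Real clients would typically inspect the op before
// deciding.
GetCollapsableDimensionsFn collapseAllLoops = [](linalg::LinalgOp op) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices allLoops;
  for (int64_t dim : llvm::seq<int64_t>(0, op.getNumLoops()))
    allLoops.push_back(dim);
  reassociation.push_back(allLoops);
  return reassociation;
};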
namespace {

/// Pass that runs the elementwise op fusion patterns.
struct LinalgElementwiseOpFusionPass
    : public impl::LinalgElementwiseOpFusionPassBase<
          LinalgElementwiseOpFusionPass> {
  using impl::LinalgElementwiseOpFusionPassBase<
      LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    RewritePatternSet patterns(context);

    // Add elementwise op fusion patterns with a default control function that
    // only fuses producers with a single use.
    ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
      Operation *producer = fusedOperand->get().getDefiningOp();
      return producer && producer->hasOneUse();
    };
    populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
    // ...

    // General canonicalization patterns.
    affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
    GenericOp::getCanonicalizationPatterns(patterns, context);
    tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
    tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
    // ...
  }
};

} // namespace
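// The pass defined above is normally driven from the command line; assuming
// the standard flag name from the Linalg Passes.td definition, something
// like:
//   mlir-opt --linalg-fuse-elementwise-ops input.mlir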