31 #define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
32 #include "mlir/Dialect/Linalg/Passes.h.inc"
58 assert(invProducerResultIndexMap &&
59 "expected producer result indexing map to be invertible");
61 LinalgOp producer = cast<LinalgOp>(producerOpOperand->
getOwner());
63 AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);
71 return t1.
compose(fusedConsumerArgIndexMap);
78 GenericOp producer, GenericOp consumer,
83 for (
auto &op : ops) {
84 for (
auto &opOperand : op->getOpOperands()) {
85 if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
88 indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
105 GenericOp producer, GenericOp consumer,
OpOperand *fusedOperand) {
106 llvm::SmallDenseSet<int> preservedProducerResults;
110 opOperandsToIgnore.emplace_back(fusedOperand);
112 for (
const auto &producerResult :
llvm::enumerate(producer->getResults())) {
113 auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
114 opOperandsToIgnore.emplace_back(outputOperand);
115 if (producer.payloadUsesValueFromOperand(outputOperand) ||
117 opOperandsToIgnore) ||
118 llvm::any_of(producerResult.value().getUsers(), [&](
Operation *user) {
119 return user != consumer.getOperation();
121 preservedProducerResults.insert(producerResult.index());
124 (void)opOperandsToIgnore.pop_back_val();
127 return preservedProducerResults;
136 auto consumer = dyn_cast<GenericOp>(fusedOperand->
getOwner());
139 if (!producer || !consumer)
145 if (!producer.hasPureTensorSemantics() ||
146 !isa<RankedTensorType>(fusedOperand->
get().
getType()))
151 if (producer.getNumParallelLoops() != producer.getNumLoops())
156 if (!consumer.isDpsInput(fusedOperand))
161 AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
162 if (consumerIndexMap.
getNumResults() != producer.getNumLoops())
168 producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
176 if ((consumer.getNumReductionLoops())) {
177 BitVector coveredDims(consumer.getNumLoops(),
false);
179 auto addToCoveredDims = [&](
AffineMap map) {
180 for (
auto result : map.getResults())
181 if (
auto dimExpr = dyn_cast<AffineDimExpr>(result))
182 coveredDims[dimExpr.getPosition()] =
true;
186 llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
187 Value operand = std::get<0>(pair);
188 if (operand == fusedOperand->
get())
190 AffineMap operandMap = std::get<1>(pair);
191 addToCoveredDims(operandMap);
194 for (
OpOperand *operand : producer.getDpsInputOperands()) {
197 operand, producerResultIndexMap, consumerIndexMap);
198 addToCoveredDims(newIndexingMap);
200 if (!coveredDims.all())
212 unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
214 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
216 Block &producerBlock = producer->getRegion(0).
front();
217 Block &consumerBlock = consumer->getRegion(0).
front();
224 if (producer.hasIndexSemantics()) {
226 unsigned numFusedOpLoops =
227 std::max(producer.getNumLoops(), consumer.getNumLoops());
229 fusedIndices.reserve(numFusedOpLoops);
230 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
231 std::back_inserter(fusedIndices), [&](uint64_t dim) {
232 return rewriter.
create<IndexOp>(producer.getLoc(), dim);
234 for (IndexOp indexOp :
235 llvm::make_early_inc_range(producerBlock.
getOps<IndexOp>())) {
236 Value newIndex = rewriter.
create<affine::AffineApplyOp>(
238 consumerToProducerLoopsMap.
getSubMap(indexOp.getDim()), fusedIndices);
239 mapper.
map(indexOp.getResult(), newIndex);
243 assert(consumer.isDpsInput(fusedOperand) &&
244 "expected producer of input operand");
248 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
255 producerBlock.
getArguments().take_front(producer.getNumDpsInputs()))
256 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
261 .take_front(consumer.getNumDpsInputs())
263 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
267 producerBlock.
getArguments().take_back(producer.getNumDpsInits()))) {
268 if (!preservedProducerResults.count(bbArg.index()))
270 mapper.
map(bbArg.value(), fusedBlock->
addArgument(bbArg.value().getType(),
271 bbArg.value().getLoc()));
276 consumerBlock.
getArguments().take_back(consumer.getNumDpsInits()))
277 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
282 if (!isa<IndexOp>(op))
283 rewriter.
clone(op, mapper);
287 auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.
getTerminator());
288 unsigned producerResultNumber =
289 cast<OpResult>(fusedOperand->
get()).getResultNumber();
291 mapper.
lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
295 if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
296 if (
auto bb = dyn_cast<BlockArgument>(replacement))
297 assert(bb.getOwner() != &producerBlock &&
298 "yielded block argument must have been mapped");
301 "yielded value must have been mapped");
307 rewriter.
clone(op, mapper);
311 auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.
getTerminator());
313 fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
314 consumerYieldOp.getNumOperands());
315 for (
const auto &producerYieldVal :
317 if (preservedProducerResults.count(producerYieldVal.index()))
318 fusedYieldValues.push_back(
321 for (
auto consumerYieldVal : consumerYieldOp.getOperands())
323 rewriter.
create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
327 "Ill-formed GenericOp region");
330 FailureOr<mlir::linalg::ElementwiseOpFusionResult>
334 "expected elementwise operation pre-conditions to pass");
335 auto producerResult = cast<OpResult>(fusedOperand->
get());
336 auto producer = cast<GenericOp>(producerResult.getOwner());
337 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
339 assert(consumer.isDpsInput(fusedOperand) &&
340 "expected producer of input operand");
343 llvm::SmallDenseSet<int> preservedProducerResults =
351 fusedInputOperands.reserve(producer.getNumDpsInputs() +
352 consumer.getNumDpsInputs());
353 fusedOutputOperands.reserve(preservedProducerResults.size() +
354 consumer.getNumDpsInits());
355 fusedResultTypes.reserve(preservedProducerResults.size() +
356 consumer.getNumDpsInits());
357 fusedIndexMaps.reserve(producer->getNumOperands() +
358 consumer->getNumOperands());
361 auto consumerInputs = consumer.getDpsInputOperands();
362 auto *it = llvm::find_if(consumerInputs, [&](
OpOperand *operand) {
363 return operand == fusedOperand;
365 assert(it != consumerInputs.end() &&
"expected to find the consumer operand");
366 for (
OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
367 fusedInputOperands.push_back(opOperand->get());
368 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
372 producer.getIndexingMapMatchingResult(producerResult);
373 for (
OpOperand *opOperand : producer.getDpsInputOperands()) {
374 fusedInputOperands.push_back(opOperand->get());
377 opOperand, producerResultIndexMap,
378 consumer.getMatchingIndexingMap(fusedOperand));
379 fusedIndexMaps.push_back(map);
384 llvm::make_range(std::next(it), consumerInputs.end())) {
385 fusedInputOperands.push_back(opOperand->get());
386 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
390 for (
const auto &opOperand :
llvm::enumerate(producer.getDpsInitsMutable())) {
391 if (!preservedProducerResults.count(opOperand.index()))
394 fusedOutputOperands.push_back(opOperand.value().get());
396 &opOperand.value(), producerResultIndexMap,
397 consumer.getMatchingIndexingMap(fusedOperand));
398 fusedIndexMaps.push_back(map);
399 fusedResultTypes.push_back(opOperand.value().get().getType());
403 for (
OpOperand &opOperand : consumer.getDpsInitsMutable()) {
404 fusedOutputOperands.push_back(opOperand.get());
405 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
406 Type resultType = opOperand.get().getType();
407 if (!isa<MemRefType>(resultType))
408 fusedResultTypes.push_back(resultType);
412 auto fusedOp = rewriter.
create<GenericOp>(
413 consumer.getLoc(), fusedResultTypes, fusedInputOperands,
415 consumer.getIteratorTypes(),
418 if (!fusedOp.getShapesToLoopsMap()) {
424 fusedOp,
"fused op failed loop bound computation check");
430 consumer.getMatchingIndexingMap(fusedOperand);
434 assert(invProducerResultIndexMap &&
435 "expected producer result indexig map to be invertible");
438 invProducerResultIndexMap.
compose(consumerResultIndexMap);
441 rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
442 consumer.getNumLoops(), preservedProducerResults);
446 for (
auto [index, producerResult] :
llvm::enumerate(producer->getResults()))
447 if (preservedProducerResults.count(index))
448 result.
replacements[producerResult] = fusedOp->getResult(resultNum++);
449 for (
auto consumerResult : consumer->getResults())
450 result.
replacements[consumerResult] = fusedOp->getResult(resultNum++);
461 controlFn(std::move(fun)) {}
463 LogicalResult matchAndRewrite(GenericOp genericOp,
466 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
469 if (!controlFn(&opOperand))
472 Operation *producer = opOperand.get().getDefiningOp();
475 FailureOr<ElementwiseOpFusionResult> fusionResult =
477 if (failed(fusionResult))
481 for (
auto [origVal, replacement] : fusionResult->replacements) {
484 return use.
get().getDefiningOp() != producer;
564 linalgOp.getIteratorTypesArray();
565 AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
566 return linalgOp.hasPureTensorSemantics() &&
567 llvm::all_of(linalgOp.getIndexingMaps().getValue(),
569 return cast<AffineMapAttr>(attr)
571 .isProjectedPermutation();
575 return isParallelIterator(
576 iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
583 class ExpansionInfo {
589 LogicalResult compute(LinalgOp linalgOp,
OpOperand *fusableOpOperand,
594 unsigned getOrigOpNumDims()
const {
return reassociation.size(); }
595 unsigned getExpandedOpNumDims()
const {
return expandedOpNumDims; }
597 return reassociation[i];
600 return expandedShapeMap[i];
613 unsigned expandedOpNumDims;
617 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
623 if (reassociationMaps.empty())
625 AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
628 originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
630 reassociation.clear();
631 expandedShapeMap.clear();
635 expandedShapeMap.resize(fusedIndexMap.
getNumDims());
637 unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
638 AffineMap foldedDims = reassociationMaps[resultExpr.index()];
641 expandedShape.slice(foldedDims.
getDimPosition(0), numExpandedDims[pos]);
642 expandedShapeMap[pos].assign(shape.begin(), shape.end());
645 for (
unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.
getNumDims()))
646 if (expandedShapeMap[i].empty())
647 expandedShapeMap[i] = {originalLoopExtent[i]};
651 reassociation.reserve(fusedIndexMap.
getNumDims());
653 auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
654 reassociation.emplace_back(seq.begin(), seq.end());
655 sum += numFoldedDim.value();
657 expandedOpNumDims = sum;
670 const ExpansionInfo &expansionInfo,
672 if (!linalgOp.hasIndexSemantics())
674 for (
unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
676 if (expandedShape.size() == 1)
678 for (int64_t shape : expandedShape.drop_front()) {
679 if (ShapedType::isDynamic(shape)) {
681 linalgOp,
"cannot expand due to index semantics and dynamic dims");
692 const ExpansionInfo &expansionInfo) {
695 unsigned pos = cast<AffineDimExpr>(expr).getPosition();
697 llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
698 return builder.getAffineDimExpr(static_cast<unsigned>(v));
700 newExprs.append(expandedExprs.begin(), expandedExprs.end());
711 const ExpansionInfo &expansionInfo) {
714 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
715 auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
716 expandedShape.append(dimExpansion.begin(), dimExpansion.end());
729 const ExpansionInfo &expansionInfo) {
731 unsigned numReshapeDims = 0;
733 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
734 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
736 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
737 reassociation.emplace_back(std::move(indices));
738 numReshapeDims += numExpandedDims;
740 return reassociation;
750 const ExpansionInfo &expansionInfo) {
752 for (IndexOp indexOp :
753 llvm::make_early_inc_range(fusedRegion.
front().
getOps<IndexOp>())) {
755 expansionInfo.getExpandedDims(indexOp.getDim());
756 assert(!expandedDims.empty() &&
"expected valid expansion info");
759 if (expandedDims.size() == 1 &&
760 expandedDims.front() == (int64_t)indexOp.getDim())
767 expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
769 expandedIndices.reserve(expandedDims.size() - 1);
771 expandedDims.drop_front(), std::back_inserter(expandedIndices),
772 [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
773 Value newIndex = rewriter.
create<IndexOp>(loc, expandedDims.front());
774 for (
auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
775 assert(!ShapedType::isDynamic(std::get<0>(it)));
778 newIndex = rewriter.
create<affine::AffineApplyOp>(
779 indexOp.getLoc(), idx + acc * std::get<0>(it),
790 const ExpansionInfo &expansionInfo,
792 for (
unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
794 if (expandedShape.size() == 1)
796 bool foundDynamic =
false;
797 for (int64_t shape : expandedShape) {
798 if (!ShapedType::isDynamic(shape))
802 linalgOp,
"cannot infer expanded shape with multiple dynamic "
803 "dims in the same reassociation group");
814 static std::optional<SmallVector<Value>>
819 "preconditions for fuse operation failed");
823 auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
824 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
825 bool isExpanding = (expandingReshapeOp !=
nullptr);
826 RankedTensorType expandedType = isExpanding
827 ? expandingReshapeOp.getResultType()
828 : collapsingReshapeOp.getSrcType();
829 RankedTensorType collapsedType = isExpanding
830 ? expandingReshapeOp.getSrcType()
831 : collapsingReshapeOp.getResultType();
833 ExpansionInfo expansionInfo;
834 if (failed(expansionInfo.compute(
835 linalgOp, fusableOpOperand,
836 isExpanding ? expandingReshapeOp.getReassociationMaps()
837 : collapsingReshapeOp.getReassociationMaps(),
838 expandedType.getShape(), collapsedType.getShape(), rewriter)))
850 llvm::map_range(linalgOp.getIndexingMapsArray(), [&](
AffineMap m) {
851 return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
859 expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
860 for (
OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
861 if (opOperand == fusableOpOperand) {
862 expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
863 : collapsingReshapeOp.getSrc());
866 if (
auto opOperandType =
867 dyn_cast<RankedTensorType>(opOperand->get().getType())) {
868 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
869 RankedTensorType expandedOperandType =
871 if (expandedOperandType != opOperand->get().getType()) {
876 [&](
const Twine &msg) {
879 opOperandType.getShape(), expandedOperandType.getShape(),
883 expandedOpOperands.push_back(rewriter.
create<tensor::ExpandShapeOp>(
884 loc, expandedOperandType, opOperand->get(), reassociation));
888 expandedOpOperands.push_back(opOperand->get());
892 for (
OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
893 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
894 auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
895 RankedTensorType expandedOutputType =
897 if (expandedOutputType != opOperand.get().getType()) {
901 [&](
const Twine &msg) {
904 opOperandType.getShape(), expandedOutputType.getShape(),
908 outputs.push_back(rewriter.
create<tensor::ExpandShapeOp>(
909 loc, expandedOutputType, opOperand.get(), reassociation));
911 outputs.push_back(opOperand.get());
917 expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
918 for (
auto [i, type] :
llvm::enumerate(linalgOp.getIteratorTypesArray()))
919 for (
auto j : expansionInfo.getExpandedDims(i))
920 iteratorTypes[
j] = type;
924 rewriter.
create<GenericOp>(linalgOp.getLoc(), resultTypes,
925 expandedOpOperands, outputs,
926 expandedOpIndexingMaps, iteratorTypes);
927 Region &fusedRegion = fusedOp->getRegion(0);
928 Region &originalRegion = linalgOp->getRegion(0);
937 for (
OpResult opResult : linalgOp->getOpResults()) {
938 int64_t resultNumber = opResult.getResultNumber();
939 if (resultTypes[resultNumber] != opResult.getType()) {
942 linalgOp.getMatchingIndexingMap(
943 linalgOp.getDpsInitOperand(resultNumber)),
945 resultVals.push_back(rewriter.
create<tensor::CollapseShapeOp>(
946 linalgOp.getLoc(), opResult.getType(),
947 fusedOp->getResult(resultNumber), reassociation));
949 resultVals.push_back(fusedOp->getResult(resultNumber));
961 class FoldWithProducerReshapeOpByExpansion
964 FoldWithProducerReshapeOpByExpansion(
MLIRContext *context,
968 controlFoldingReshapes(std::move(foldReshapes)) {}
970 LogicalResult matchAndRewrite(LinalgOp linalgOp,
972 for (
OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
973 tensor::CollapseShapeOp reshapeOp =
974 opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
981 (!controlFoldingReshapes(opOperand)))
984 std::optional<SmallVector<Value>> replacementValues =
986 if (!replacementValues)
988 rewriter.
replaceOp(linalgOp, *replacementValues);
998 class FoldPadWithProducerReshapeOpByExpansion
1001 FoldPadWithProducerReshapeOpByExpansion(
MLIRContext *context,
1005 controlFoldingReshapes(std::move(foldReshapes)) {}
1007 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1009 tensor::CollapseShapeOp reshapeOp =
1010 padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
1013 if (!reshapeOp->hasOneUse())
1016 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1018 "fusion blocked by control function");
1024 reshapeOp.getReassociationIndices();
1026 for (
auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
1027 if (reInd.size() != 1 && (l != 0 || h != 0))
1032 RankedTensorType expandedType = reshapeOp.getSrcType();
1033 RankedTensorType paddedType = padOp.getResultType();
1036 if (reInd.size() == 1) {
1037 expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
1039 for (
size_t i = 0; i < reInd.size(); ++i) {
1040 newLow.push_back(padOp.getMixedLowPad()[idx]);
1041 newHigh.push_back(padOp.getMixedHighPad()[idx]);
1046 RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
1047 auto newPadOp = rewriter.
create<tensor::PadOp>(
1048 loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1049 padOp.getConstantPaddingValue(), padOp.getNofold());
1052 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
1063 struct FoldReshapeWithGenericOpByExpansion
1066 FoldReshapeWithGenericOpByExpansion(
MLIRContext *context,
1070 controlFoldingReshapes(std::move(foldReshapes)) {}
1072 LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
1075 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1076 if (!producerResult) {
1078 "source not produced by an operation");
1081 auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
1084 "producer not a generic op");
1089 producer.getDpsInitOperand(producerResult.getResultNumber()))) {
1091 reshapeOp,
"failed preconditions of fusion with producer generic op");
1094 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1096 "fusion blocked by control function");
1099 std::optional<SmallVector<Value>> replacementValues =
1101 producer, reshapeOp,
1102 producer.getDpsInitOperand(producerResult.getResultNumber()),
1104 if (!replacementValues) {
1106 "fusion by expansion failed");
1113 Value reshapeReplacement =
1114 (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
1115 .getResultNumber()];
1116 if (
auto collapseOp =
1117 reshapeReplacement.
getDefiningOp<tensor::CollapseShapeOp>()) {
1118 reshapeReplacement = collapseOp.getSrc();
1120 rewriter.
replaceOp(reshapeOp, reshapeReplacement);
1121 rewriter.
replaceOp(producer, *replacementValues);
1143 "expected projected permutation");
1146 llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
1147 return cast<AffineDimExpr>(indexingMap.
getResults()[pos]).getPosition();
1151 return domainReassociation;
1159 assert(!dimSequence.empty() &&
1160 "expected non-empty list for dimension sequence");
1162 "expected indexing map to be projected permutation");
1164 llvm::SmallDenseSet<unsigned, 4> sequenceElements;
1165 sequenceElements.insert(dimSequence.begin(), dimSequence.end());
1167 unsigned dimSequenceStart = dimSequence[0];
1169 unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
1171 if (dimInMapStart == dimSequenceStart) {
1172 if (expr.index() + dimSequence.size() > indexingMap.
getNumResults())
1175 for (
const auto &dimInSequence :
enumerate(dimSequence)) {
1177 cast<AffineDimExpr>(
1178 indexingMap.
getResult(expr.index() + dimInSequence.index()))
1180 if (dimInMap != dimInSequence.value())
1191 if (sequenceElements.count(dimInMapStart))
1200 return llvm::all_of(maps, [&](
AffineMap map) {
1257 if (!genericOp.hasPureTensorSemantics())
1260 if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](
AffineMap map) {
1261 return map.isProjectedPermutation();
1268 genericOp.getReductionDims(reductionDims);
1270 llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1271 AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1272 auto iteratorTypes = genericOp.getIteratorTypesArray();
1275 assert(!foldedRangeDims.empty() &&
"unexpected empty reassociation");
1278 if (foldedRangeDims.size() == 1)
1286 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1287 return processedIterationDims.count(dim);
1292 utils::IteratorType startIteratorType =
1293 iteratorTypes[foldedIterationSpaceDims[0]];
1297 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1298 return iteratorTypes[dim] != startIteratorType;
1307 bool isContiguous =
false;
1310 if (startDim.value() != foldedIterationSpaceDims[0])
1314 if (startDim.index() + foldedIterationSpaceDims.size() >
1315 reductionDims.size())
1318 isContiguous =
true;
1319 for (
const auto &foldedDim :
1321 if (reductionDims[foldedDim.index() + startDim.index()] !=
1322 foldedDim.value()) {
1323 isContiguous =
false;
1334 if (llvm::any_of(genericOp.getIndexingMapsArray(),
1336 return !isDimSequencePreserved(indexingMap,
1337 foldedIterationSpaceDims);
1341 processedIterationDims.insert(foldedIterationSpaceDims.begin(),
1342 foldedIterationSpaceDims.end());
1343 iterationSpaceReassociation.emplace_back(
1344 std::move(foldedIterationSpaceDims));
1347 return iterationSpaceReassociation;
1352 class CollapsingInfo {
1354 LogicalResult initialize(
unsigned origNumLoops,
1356 llvm::SmallDenseSet<int64_t, 4> processedDims;
1359 if (foldedIterationDim.empty())
1363 for (
auto dim : foldedIterationDim) {
1364 if (dim >= origNumLoops)
1366 if (processedDims.count(dim))
1368 processedDims.insert(dim);
1370 collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1371 foldedIterationDim.end());
1373 if (processedDims.size() > origNumLoops)
1378 for (
auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1379 if (processedDims.count(dim))
1384 llvm::sort(collapsedOpToOrigOpIterationDim,
1386 return lhs[0] < rhs[0];
1388 origOpToCollapsedOpIterationDim.resize(origNumLoops);
1389 for (
const auto &foldedDims :
1391 for (
const auto &dim :
enumerate(foldedDims.value()))
1392 origOpToCollapsedOpIterationDim[dim.value()] =
1393 std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1400 return collapsedOpToOrigOpIterationDim;
1424 return origOpToCollapsedOpIterationDim;
1428 unsigned getCollapsedOpIterationRank()
const {
1429 return collapsedOpToOrigOpIterationDim.size();
1447 const CollapsingInfo &collapsingInfo) {
1450 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1451 assert(!foldedIterDims.empty() &&
1452 "reassociation indices expected to have non-empty sets");
1456 collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1458 return collapsedIteratorTypes;
1465 const CollapsingInfo &collapsingInfo) {
1468 "expected indexing map to be projected permutation");
1470 auto origOpToCollapsedOpMapping =
1471 collapsingInfo.getOrigOpToCollapsedOpMapping();
1473 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1475 if (origOpToCollapsedOpMapping[dim].second != 0)
1479 resultExprs.push_back(
1482 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1483 resultExprs, context);
1490 const CollapsingInfo &collapsingInfo) {
1491 unsigned counter = 0;
1493 auto origOpToCollapsedOpMapping =
1494 collapsingInfo.getOrigOpToCollapsedOpMapping();
1495 auto collapsedOpToOrigOpMapping =
1496 collapsingInfo.getCollapsedOpToOrigOpMapping();
1499 cast<AffineDimExpr>(indexingMap.
getResult(counter)).getPosition();
1503 unsigned numFoldedDims =
1504 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1506 if (origOpToCollapsedOpMapping[dim].second == 0) {
1507 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1508 operandReassociation.emplace_back(range.begin(), range.end());
1510 counter += numFoldedDims;
1512 return operandReassociation;
1518 const CollapsingInfo &collapsingInfo,
1520 AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1528 if (operandReassociation.size() == indexingMap.
getNumResults())
1532 if (isa<MemRefType>(operand.
getType())) {
1534 .
create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
1538 .
create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
1545 const CollapsingInfo &collapsingInfo,
1552 auto indexOps = llvm::to_vector(block->
getOps<linalg::IndexOp>());
1562 for (
auto foldedDims :
1563 enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1566 rewriter.
create<linalg::IndexOp>(loc, foldedDims.index());
1567 for (
auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1568 indexReplacementVals[dim] =
1569 rewriter.
create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
1571 rewriter.
create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
1573 indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1576 for (
auto indexOp : indexOps) {
1577 auto dim = indexOp.getDim();
1578 rewriter.
replaceOp(indexOp, indexReplacementVals[dim]);
1583 const CollapsingInfo &collapsingInfo,
1590 llvm::map_to_vector(op.getDpsInputOperands(), [&](
OpOperand *opOperand) {
1591 return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1596 resultTypes.reserve(op.getNumDpsInits());
1597 outputOperands.reserve(op.getNumDpsInits());
1598 for (
OpOperand &output : op.getDpsInitsMutable()) {
1601 outputOperands.push_back(newOutput);
1604 if (!op.hasPureBufferSemantics())
1605 resultTypes.push_back(newOutput.
getType());
1610 template <
typename OpTy>
1612 const CollapsingInfo &collapsingInfo) {
1620 const CollapsingInfo &collapsingInfo) {
1624 outputOperands, resultTypes);
1627 rewriter, origOp, resultTypes,
1628 llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1635 const CollapsingInfo &collapsingInfo) {
1639 outputOperands, resultTypes);
1641 llvm::map_range(origOp.getIndexingMapsArray(), [&](
AffineMap map) {
1642 return getCollapsedOpIndexingMap(map, collapsingInfo);
1646 origOp.getIteratorTypesArray(), collapsingInfo));
1648 GenericOp collapsedOp = rewriter.
create<linalg::GenericOp>(
1649 origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
1651 Block *origOpBlock = &origOp->getRegion(0).
front();
1652 Block *collapsedOpBlock = &collapsedOp->getRegion(0).
front();
1653 rewriter.
mergeBlocks(origOpBlock, collapsedOpBlock,
1660 if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1672 if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1674 return foldedDims.size() <= 1;
1678 bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1679 if (hasPureBufferSemantics &&
1680 !llvm::all_of(op->getOperands(), [&](
Value operand) ->
bool {
1681 MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
1682 if (!memRefToCollapse)
1685 return memref::CollapseShapeOp::isGuaranteedCollapsible(
1686 memRefToCollapse, foldedIterationDims);
1689 "memref is not guaranteed collapsible");
1691 CollapsingInfo collapsingInfo;
1693 collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1695 op,
"illegal to collapse specified dimensions");
1700 auto opFoldIsConstantValue = [](
OpFoldResult ofr, int64_t value) {
1701 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1702 return cast<IntegerAttr>(attr).getInt() == value;
1705 actual.getSExtValue() == value;
1707 if (!llvm::all_of(loopRanges, [&](
Range range) {
1708 return opFoldIsConstantValue(range.
offset, 0) &&
1709 opFoldIsConstantValue(range.
stride, 1);
1712 op,
"expected all loop ranges to have zero start and unit stride");
1718 if (collapsedOp.hasIndexSemantics()) {
1723 llvm::map_to_vector(loopRanges, [&](
Range range) {
1727 collapsingInfo, loopBound, rewriter);
1733 for (
const auto &originalResult :
llvm::enumerate(op->getResults())) {
1734 Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1735 auto originalResultType =
1736 cast<ShapedType>(originalResult.value().getType());
1737 auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.
getType());
1738 if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1740 op.getIndexingMapMatchingResult(originalResult.value());
1744 if (isa<MemRefType>(collapsedOpResult.
getType())) {
1746 originalResultType.getShape(), originalResultType.getElementType());
1747 result = rewriter.
create<memref::ExpandShapeOp>(
1748 loc, expandShapeResultType, collapsedOpResult, reassociation);
1750 result = rewriter.
create<tensor::ExpandShapeOp>(
1751 loc, originalResultType, collapsedOpResult, reassociation);
1753 results.push_back(result);
1755 results.push_back(collapsedOpResult);
1765 class FoldWithProducerReshapeOpByCollapsing
1768 FoldWithProducerReshapeOpByCollapsing(
MLIRContext *context,
1772 controlFoldingReshapes(std::move(foldReshapes)) {}
1774 LogicalResult matchAndRewrite(GenericOp genericOp,
1776 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
1777 tensor::ExpandShapeOp reshapeOp =
1778 opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1784 reshapeOp.getReassociationIndices());
1785 if (collapsableIterationDims.empty() ||
1786 !controlFoldingReshapes(&opOperand)) {
1791 genericOp, collapsableIterationDims, rewriter);
1792 if (!collapseResult) {
1794 genericOp,
"failed to do the fusion by collapsing transformation");
1797 rewriter.
replaceOp(genericOp, collapseResult->results);
1807 class FoldPadWithProducerReshapeOpByCollapsing
1810 FoldPadWithProducerReshapeOpByCollapsing(
MLIRContext *context,
1814 controlFoldingReshapes(std::move(foldReshapes)) {}
1816 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1818 tensor::ExpandShapeOp reshapeOp =
1819 padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
1822 if (!reshapeOp->hasOneUse())
1825 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1827 "fusion blocked by control function");
1833 reshapeOp.getReassociationIndices();
1835 for (
auto reInd : reassociations) {
1836 if (reInd.size() == 1)
1838 if (llvm::any_of(reInd, [&](int64_t ind) {
1839 return low[ind] != 0 || high[ind] != 0;
1846 RankedTensorType collapsedType = reshapeOp.getSrcType();
1847 RankedTensorType paddedType = padOp.getResultType();
1851 reshapeOp.getOutputShape(), rewriter));
1855 Location loc = reshapeOp->getLoc();
1859 if (reInd.size() == 1) {
1860 collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
1862 rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
1863 expandedPaddedSizes[reInd[0]] = paddedSize;
1865 newLow.push_back(l);
1866 newHigh.push_back(h);
1869 RankedTensorType collapsedPaddedType =
1870 paddedType.clone(collapsedPaddedShape);
1871 auto newPadOp = rewriter.
create<tensor::PadOp>(
1872 loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1873 padOp.getConstantPaddingValue(), padOp.getNofold());
1876 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
1877 expandedPaddedSizes);
1887 template <
typename LinalgType>
1894 controlCollapseDimension(std::move(collapseDimensions)) {}
1896 LogicalResult matchAndRewrite(LinalgType op,
1899 controlCollapseDimension(op);
1900 if (collapsableIterationDims.empty())
1905 collapsableIterationDims)) {
1907 op,
"specified dimensions cannot be collapsed");
1910 std::optional<CollapseResult> collapseResult =
1912 if (!collapseResult) {
1915 rewriter.
replaceOp(op, collapseResult->results);
1937 LogicalResult matchAndRewrite(GenericOp genericOp,
1939 if (!genericOp.hasPureTensorSemantics())
1941 for (
OpOperand *opOperand : genericOp.getDpsInputOperands()) {
1942 Operation *def = opOperand->get().getDefiningOp();
1943 TypedAttr constantAttr;
1944 auto isScalarOrSplatConstantOp = [&constantAttr](
Operation *def) ->
bool {
1947 if (
matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1949 splatAttr.
getType().getElementType().isIntOrFloat()) {
1955 IntegerAttr intAttr;
1956 if (
matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1957 constantAttr = intAttr;
1962 FloatAttr floatAttr;
1963 if (
matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1964 constantAttr = floatAttr;
1971 auto resultValue = dyn_cast<OpResult>(opOperand->get());
1972 if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
1981 fusedIndexMaps.reserve(genericOp->getNumOperands());
1982 fusedOperands.reserve(genericOp.getNumDpsInputs());
1983 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
1984 for (
OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
1985 if (inputOperand == opOperand)
1987 Value inputValue = inputOperand->get();
1988 fusedIndexMaps.push_back(
1989 genericOp.getMatchingIndexingMap(inputOperand));
1990 fusedOperands.push_back(inputValue);
1991 fusedLocs.push_back(inputValue.
getLoc());
1993 for (
OpOperand &outputOperand : genericOp.getDpsInitsMutable())
1994 fusedIndexMaps.push_back(
1995 genericOp.getMatchingIndexingMap(&outputOperand));
2000 genericOp,
"fused op loop bound computation failed");
2004 Value scalarConstant =
2005 rewriter.
create<arith::ConstantOp>(def->
getLoc(), constantAttr);
2008 auto fusedOp = rewriter.
create<GenericOp>(
2009 rewriter.
getFusedLoc(fusedLocs), genericOp->getResultTypes(),
2013 genericOp.getIteratorTypes(),
2019 Region ®ion = genericOp->getRegion(0);
2022 mapping.
map(entryBlock.
getArgument(opOperand->getOperandNumber()),
2024 Region &fusedRegion = fusedOp->getRegion(0);
2027 rewriter.
replaceOp(genericOp, fusedOp->getResults());
2048 LogicalResult matchAndRewrite(GenericOp op,
2051 bool modifiedOutput =
false;
2053 for (
OpOperand &opOperand : op.getDpsInitsMutable()) {
2054 if (!op.payloadUsesValueFromOperand(&opOperand)) {
2055 Value operandVal = opOperand.get();
2056 auto operandType = dyn_cast<RankedTensorType>(operandVal.
getType());
2065 auto definingOp = operandVal.
getDefiningOp<tensor::EmptyOp>();
2068 modifiedOutput =
true;
2071 Value emptyTensor = rewriter.
create<tensor::EmptyOp>(
2072 loc, mixedSizes, operandType.getElementType());
2073 op->
setOperand(opOperand.getOperandNumber(), emptyTensor);
2076 if (!modifiedOutput) {
2089 LogicalResult matchAndRewrite(GenericOp genericOp,
2091 if (!genericOp.hasPureTensorSemantics())
2093 bool fillFound =
false;
2094 Block &payload = genericOp.getRegion().
front();
2095 for (
OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2096 if (!genericOp.payloadUsesValueFromOperand(opOperand))
2098 FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2102 Value fillVal = fillOp.value();
2104 cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2105 Value convertedVal =
2109 payload.
getArgument(opOperand->getOperandNumber()), convertedVal);
2111 return success(fillFound);
2119 patterns.
add<FoldReshapeWithGenericOpByExpansion>(patterns.
getContext(),
2120 controlFoldingReshapes);
2121 patterns.
add<FoldPadWithProducerReshapeOpByExpansion>(patterns.
getContext(),
2122 controlFoldingReshapes);
2123 patterns.
add<FoldWithProducerReshapeOpByExpansion>(patterns.
getContext(),
2124 controlFoldingReshapes);
2130 patterns.
add<FoldWithProducerReshapeOpByCollapsing>(patterns.
getContext(),
2131 controlFoldingReshapes);
2132 patterns.
add<FoldPadWithProducerReshapeOpByCollapsing>(
2133 patterns.
getContext(), controlFoldingReshapes);
2140 patterns.
add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2141 patterns.
add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2142 RemoveOutsDependency>(context);
2150 patterns.
add<CollapseLinalgDimensions<linalg::GenericOp>,
2151 CollapseLinalgDimensions<linalg::CopyOp>>(
2152 patterns.
getContext(), controlCollapseDimensions);
2167 struct LinalgElementwiseOpFusionPass
2168 :
public impl::LinalgElementwiseOpFusionPassBase<
2169 LinalgElementwiseOpFusionPass> {
2170 using impl::LinalgElementwiseOpFusionPassBase<
2171 LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2172 void runOnOperation()
override {
2179 Operation *producer = fusedOperand->get().getDefiningOp();
2180 return producer && producer->
hasOneUse();
2189 affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
2190 GenericOp::getCanonicalizationPatterns(patterns, context);
2191 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
2192 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(GenericOp producer, GenericOp consumer, ArrayRef< OpOperand * > opOperandsToIgnore)
OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp, const CollapsingInfo &collapsingInfo)
Clone a LinalgOp to a collapsed version of the same name.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(OpOperand *producerOpOperand, AffineMap producerResultIndexMap, AffineMap fusedConsumerArgIndexMap)
Append to fusedOpIndexingMapAttrs the indexing maps for the operands of the producer to use in the fu...
static SmallVector< ReassociationIndices > getOperandReassociation(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Return the reassociation indices to use to collapse the operand when the iteration space of a generic...
static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp, const ExpansionInfo &expansionInfo, PatternRewriter &rewriter)
Expanding the body of a linalg operation requires adaptations of the accessed loop indices.
static SmallVector< utils::IteratorType > getCollapsedOpIteratorTypes(ArrayRef< utils::IteratorType > iteratorTypes, const CollapsingInfo &collapsingInfo)
Get the iterator types for the collapsed operation given the original iterator types and collapsed di...
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, Location loc, Region &fusedRegion, const ExpansionInfo &expansionInfo)
Update the body of an expanded linalg operation having index semantics.
GenericOp cloneToCollapsedOp< GenericOp >(RewriterBase &rewriter, GenericOp origOp, const CollapsingInfo &collapsingInfo)
Collapse a GenericOp
void generateCollapsedIndexingRegion(Location loc, Block *block, const CollapsingInfo &collapsingInfo, ValueRange loopRange, RewriterBase &rewriter)
Modify the linalg.index operations in the original generic op, to its value in the collapsed operatio...
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp, OpOperand *fusableOpOperand)
Conditions for folding a structured linalg operation with a reshape op by expanding the iteration spa...
static LogicalResult validateDynamicDimExpansion(LinalgOp linalgOp, const ExpansionInfo &expansionInfo, PatternRewriter &rewriter)
Checks if a single dynamic dimension expanded into multiple dynamic dimensions.
void collapseOperandsAndResults(LinalgOp op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter, SmallVectorImpl< Value > &inputOperands, SmallVectorImpl< Value > &outputOperands, SmallVectorImpl< Type > &resultTypes)
static ReassociationIndices getDomainReassociation(AffineMap indexingMap, ReassociationIndicesRef rangeReassociation)
For a given list of indices in the range of the indexingMap that are folded, return the indices of th...
static SmallVector< ReassociationIndices > getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, ArrayRef< ReassociationIndices > reassociation)
static std::optional< SmallVector< Value > > fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, OpOperand *fusableOpOperand, PatternRewriter &rewriter)
Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op and a generic op as expl...
static void generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand, unsigned nloops, llvm::SmallDenseSet< int > &preservedProducerResults)
Generate the region of the fused tensor operation.
static SmallVector< ReassociationIndices > getReassociationForExpansion(AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Returns the reassociation maps to use in the tensor.expand_shape operation to convert the operands of...
static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Compute the indexing map in the collapsed op that corresponds to the given indexingMap of the origina...
LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter)
static RankedTensorType getExpandedType(RankedTensorType originalType, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the type of the operand/result to use in the expanded op given the type in the original op.
LinalgOp cloneToCollapsedOp< LinalgOp >(RewriterBase &rewriter, LinalgOp origOp, const CollapsingInfo &collapsingInfo)
Collapse any LinalgOp that does not require any specialization such as indexing_maps,...
static AffineMap getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the indexing map to use in the expanded op for a given the indexingMap of the original operati...
static Value getCollapsedOpOperand(Location loc, LinalgOp op, OpOperand *opOperand, const CollapsingInfo &collapsingInfo, OpBuilder &builder)
Get the new value to use for a given OpOperand in the collapsed operation.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class allows control over how the GreedyPatternRewriteDriver works.
bool useTopDownTraversal
This specifies the order of the initial traversal that populates the rewriter's worklist.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
bool hasOneUse()
Returns true if this operation has exactly one use.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool areDimSequencesPreserved(ArrayRef< AffineMap > maps, ArrayRef< ReassociationIndices > dimSequences)
Return true if all sequences of dimensions specified in dimSequences are contiguous in all the ranges...
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence)
Return true if a given sequence of dimensions are contiguous in the range of the specified indexing m...
void populateFoldReshapeOpsByCollapsingPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding tensor.expand_shape operation with its producer generic operation by co...
FailureOr< ElementwiseOpFusionResult > fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand)
llvm::SmallDenseSet< int > getPreservedProducerResults(GenericOp producer, GenericOp consumer, OpOperand *fusedOperand)
Returns a set of indices of the producer's results which would be preserved after the fusion.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
void populateCollapseDimensions(RewritePatternSet &patterns, const GetCollapsableDimensionsFn &controlCollapseDimensions)
Pattern to collapse dimensions in a linalg.generic op.
bool areElementwiseOpsFusable(OpOperand *fusedOperand)
Return true if two linalg.generic operations with producer/consumer relationship through fusedOperand...
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
std::function< SmallVector< ReassociationIndices >(linalg::LinalgOp)> GetCollapsableDimensionsFn
Function type to control generic op dimension collapsing.
void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding (collapsing) tensor_reshape operation with its producer (consumer) gene...
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void populateElementwiseOpsFusionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion)
Patterns for fusing linalg operation on tensors.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExprs at positions [0 .. N-1].
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify that shapes of the reshaped types using following rule: if a dimension in the collapsed type i...
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
ArrayRef< int64_t > ReassociationIndicesRef
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Fuse two linalg.generic operations that have a producer-consumer relationship captured through fusedO...
llvm::DenseMap< Value, Value > replacements
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.