33 #define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
34 #include "mlir/Dialect/Linalg/Passes.h.inc"
// NOTE(review): fragment (interior lines elided in this extract; baked-in
// numbers 60-73 are original file line numbers). Appears to remap a producer
// operand's indexing map into the fused consumer's iteration space by
// composing through the inverted producer-result map — TODO confirm against
// the full source.
60 assert(invProducerResultIndexMap &&
61 "expected producer result indexing map to be invertible");
63 LinalgOp producer = cast<LinalgOp>(producerOpOperand->
getOwner());
65 AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);
// Presumably t1 = invProducerResultIndexMap.compose(argMap) on an elided line.
73 return t1.
compose(fusedConsumerArgIndexMap);
// NOTE(review): fragment (original lines 80-104; interiors elided). Collects
// the indexing maps of all operands of both ops except those in
// `opOperandsToIgnore`, then checks invertibility of their concatenation —
// i.e. whether the ignored operand can be dropped without losing loop
// coverage. TODO confirm against the full source.
80 GenericOp producer, GenericOp consumer,
85 for (
auto &op : ops) {
86 for (
auto &opOperand : op->getOpOperands()) {
87 if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
90 indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
// No surviving maps: only valid when both ops are 0-loop (scalar) ops.
93 if (indexingMaps.empty()) {
96 return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
104 indexingMaps, producer.getContext())) !=
AffineMap();
// NOTE(review): fragment (original lines 113-135; interiors elided). Computes
// the set of producer result indices that must be preserved after fusion: a
// result is kept if its init operand is used in the payload, if it cannot be
// dropped (per the helper above, presumably), or if it has a user other than
// the consumer. TODO confirm elided conditions against the full source.
113 GenericOp producer, GenericOp consumer,
OpOperand *fusedOperand) {
114 llvm::SmallDenseSet<int> preservedProducerResults;
118 opOperandsToIgnore.emplace_back(fusedOperand);
120 for (
const auto &producerResult :
llvm::enumerate(producer->getResults())) {
121 auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
122 opOperandsToIgnore.emplace_back(outputOperand);
123 if (producer.payloadUsesValueFromOperand(outputOperand) ||
125 opOperandsToIgnore) ||
126 llvm::any_of(producerResult.value().getUsers(), [&](
Operation *user) {
127 return user != consumer.getOperation();
129 preservedProducerResults.insert(producerResult.index());
// Undo the speculative push of this init operand before the next iteration.
132 (void)opOperandsToIgnore.pop_back_val();
135 return preservedProducerResults;
// NOTE(review): fragment (original lines 144-208; interiors elided). The
// fusability precondition check: requires generic producer/consumer, pure
// tensor semantics with ranked-tensor fused operand, an all-parallel producer,
// the fused operand being a consumer DPS input, and a consumer map whose
// result count matches the producer loop count. The reduction-loop branch
// verifies every consumer loop dim is still covered after fusion.
144 auto consumer = dyn_cast<GenericOp>(fusedOperand->
getOwner());
147 if (!producer || !consumer)
153 if (!producer.hasPureTensorSemantics() ||
154 !isa<RankedTensorType>(fusedOperand->
get().
getType()))
159 if (producer.getNumParallelLoops() != producer.getNumLoops())
164 if (!consumer.isDpsInput(fusedOperand))
169 AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
170 if (consumerIndexMap.
getNumResults() != producer.getNumLoops())
176 producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
// With consumer reductions, track which consumer loop dims stay covered.
184 if ((consumer.getNumReductionLoops())) {
185 BitVector coveredDims(consumer.getNumLoops(),
false);
187 auto addToCoveredDims = [&](
AffineMap map) {
188 for (
auto result : map.getResults())
189 if (
auto dimExpr = dyn_cast<AffineDimExpr>(result))
190 coveredDims[dimExpr.getPosition()] =
true;
// All consumer operands except the fused one contribute coverage directly.
194 llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
195 Value operand = std::get<0>(pair);
196 if (operand == fusedOperand->
get())
198 AffineMap operandMap = std::get<1>(pair);
199 addToCoveredDims(operandMap);
// Producer inputs contribute via their maps remapped into consumer coords.
202 for (
OpOperand *operand : producer.getDpsInputOperands()) {
205 operand, producerResultIndexMap, consumerIndexMap);
206 addToCoveredDims(newIndexingMap);
208 if (!coveredDims.all())
// NOTE(review): fragment (original lines 220-335; interiors elided). Builds
// the fused op's payload region: remaps producer linalg.index ops through an
// affine.apply using the consumer-to-producer loops map, maps block arguments
// of both regions into the fused block (skipping non-preserved producer
// inits), clones both payloads, and emits a combined linalg.yield.
220 unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
222 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
224 Block &producerBlock = producer->getRegion(0).
front();
225 Block &consumerBlock = consumer->getRegion(0).
front();
// Producer uses linalg.index: materialize one IndexOp per fused loop and
// rewrite the producer's index ops in fused coordinates.
232 if (producer.hasIndexSemantics()) {
234 unsigned numFusedOpLoops =
235 std::max(producer.getNumLoops(), consumer.getNumLoops());
237 fusedIndices.reserve(numFusedOpLoops);
238 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
239 std::back_inserter(fusedIndices), [&](uint64_t dim) {
240 return rewriter.
create<IndexOp>(producer.getLoc(), dim);
242 for (IndexOp indexOp :
243 llvm::make_early_inc_range(producerBlock.
getOps<IndexOp>())) {
244 Value newIndex = rewriter.
create<affine::AffineApplyOp>(
246 consumerToProducerLoopsMap.
getSubMap(indexOp.getDim()), fusedIndices);
247 mapper.
map(indexOp.getResult(), newIndex);
251 assert(consumer.isDpsInput(fusedOperand) &&
252 "expected producer of input operand");
// Map consumer inputs before the fused operand, then producer inputs,
// then remaining consumer inputs — preserving operand order in the fused op.
256 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
263 producerBlock.
getArguments().take_front(producer.getNumDpsInputs()))
264 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
269 .take_front(consumer.getNumDpsInputs())
271 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
// Producer init block args: only preserved results get fused block args.
275 producerBlock.
getArguments().take_back(producer.getNumDpsInits()))) {
276 if (!preservedProducerResults.count(bbArg.index()))
278 mapper.
map(bbArg.value(), fusedBlock->
addArgument(bbArg.value().getType(),
279 bbArg.value().getLoc()));
284 consumerBlock.
getArguments().take_back(consumer.getNumDpsInits()))
285 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
// Clone producer payload (index ops were handled above).
290 if (!isa<IndexOp>(op))
291 rewriter.
clone(op, mapper);
// Wire the producer's yielded value for the fused operand into the consumer's
// corresponding block-argument uses.
295 auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.
getTerminator());
296 unsigned producerResultNumber =
297 cast<OpResult>(fusedOperand->
get()).getResultNumber();
299 mapper.
lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
// Sanity check: the yielded value must have been remapped into the fused block.
303 if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
304 if (
auto bb = dyn_cast<BlockArgument>(replacement))
305 assert(bb.getOwner() != &producerBlock &&
306 "yielded block argument must have been mapped");
309 "yielded value must have been mapped");
315 rewriter.
clone(op, mapper);
// Fused yield = preserved producer yields followed by all consumer yields.
319 auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.
getTerminator());
321 fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
322 consumerYieldOp.getNumOperands());
323 for (
const auto &producerYieldVal :
325 if (preservedProducerResults.count(producerYieldVal.index()))
326 fusedYieldValues.push_back(
329 for (
auto consumerYieldVal : consumerYieldOp.getOperands())
331 rewriter.
create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
335 "Ill-formed GenericOp region");
// NOTE(review): fragment (original lines 338-438; interiors elided). Main
// entry of elementwise fusion: gathers fused input/output operands, result
// types, and index maps in a fixed order (consumer inputs before the fused
// operand, producer inputs, remaining consumer inputs, preserved producer
// inits, consumer inits), then creates the fused GenericOp.
338 FailureOr<mlir::linalg::ElementwiseOpFusionResult>
342 "expected elementwise operation pre-conditions to pass");
343 auto producerResult = cast<OpResult>(fusedOperand->
get());
344 auto producer = cast<GenericOp>(producerResult.getOwner());
345 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
347 assert(consumer.isDpsInput(fusedOperand) &&
348 "expected producer of input operand");
351 llvm::SmallDenseSet<int> preservedProducerResults =
// Reserve to the worst case so the pushes below never reallocate.
359 fusedInputOperands.reserve(producer.getNumDpsInputs() +
360 consumer.getNumDpsInputs());
361 fusedOutputOperands.reserve(preservedProducerResults.size() +
362 consumer.getNumDpsInits());
363 fusedResultTypes.reserve(preservedProducerResults.size() +
364 consumer.getNumDpsInits());
365 fusedIndexMaps.reserve(producer->getNumOperands() +
366 consumer->getNumOperands());
369 auto consumerInputs = consumer.getDpsInputOperands();
370 auto *it = llvm::find_if(consumerInputs, [&](
OpOperand *operand) {
371 return operand == fusedOperand;
373 assert(it != consumerInputs.end() &&
"expected to find the consumer operand");
// 1. Consumer inputs preceding the fused operand keep their maps unchanged.
374 for (
OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
375 fusedInputOperands.push_back(opOperand->get());
376 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
// 2. Producer inputs: maps are remapped into the fused (consumer) coords.
380 producer.getIndexingMapMatchingResult(producerResult);
381 for (
OpOperand *opOperand : producer.getDpsInputOperands()) {
382 fusedInputOperands.push_back(opOperand->get());
385 opOperand, producerResultIndexMap,
386 consumer.getMatchingIndexingMap(fusedOperand));
387 fusedIndexMaps.push_back(map);
// 3. Consumer inputs after the fused operand.
392 llvm::make_range(std::next(it), consumerInputs.end())) {
393 fusedInputOperands.push_back(opOperand->get());
394 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
// 4. Producer inits — only those whose results are preserved.
398 for (
const auto &opOperand :
llvm::enumerate(producer.getDpsInitsMutable())) {
399 if (!preservedProducerResults.count(opOperand.index()))
402 fusedOutputOperands.push_back(opOperand.value().get());
404 &opOperand.value(), producerResultIndexMap,
405 consumer.getMatchingIndexingMap(fusedOperand));
406 fusedIndexMaps.push_back(map);
407 fusedResultTypes.push_back(opOperand.value().get().getType());
// 5. Consumer inits; memref-typed inits yield no results.
411 for (
OpOperand &opOperand : consumer.getDpsInitsMutable()) {
412 fusedOutputOperands.push_back(opOperand.get());
413 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
414 Type resultType = opOperand.get().getType();
415 if (!isa<MemRefType>(resultType))
416 fusedResultTypes.push_back(resultType);
420 auto fusedOp = rewriter.
create<GenericOp>(
421 consumer.getLoc(), fusedResultTypes, fusedInputOperands,
423 consumer.getIteratorTypes(),
// Bail out if the fused op's loop bounds cannot be computed from its shapes.
426 if (!fusedOp.getShapesToLoopsMap()) {
432 fusedOp,
"fused op failed loop bound computation check");
438 consumer.getMatchingIndexingMap(fusedOperand);
442 assert(invProducerResultIndexMap &&
443 "expected producer result indexig map to be invertible");
// NOTE(review): fragment (original lines 446-458; interiors elided).
// Computes the consumer-to-producer loops map, generates the fused region,
// and records result replacements: preserved producer results first, then all
// consumer results, in fused-op result order.
446 invProducerResultIndexMap.
compose(consumerResultIndexMap);
449 rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
450 consumer.getNumLoops(), preservedProducerResults);
454 for (
auto [index, producerResult] :
llvm::enumerate(producer->getResults()))
455 if (preservedProducerResults.count(index))
456 result.
replacements[producerResult] = fusedOp->getResult(resultNum++);
457 for (
auto consumerResult : consumer->getResults())
458 result.
replacements[consumerResult] = fusedOp->getResult(resultNum++);
// NOTE(review): fragment (original lines 469-492; interiors elided). The
// rewrite-pattern driver: walks the generic op's operands, consults the
// user-provided control function, attempts fusion with each operand's
// defining op, and (presumably) replaces uses outside the producer.
469 controlFn(std::move(fun)) {}
471 LogicalResult matchAndRewrite(GenericOp genericOp,
474 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
477 if (!controlFn(&opOperand))
480 Operation *producer = opOperand.get().getDefiningOp();
483 FailureOr<ElementwiseOpFusionResult> fusionResult =
485 if (failed(fusionResult))
// Only replace uses that are not inside the (about to die) producer.
489 for (
auto [origVal, replacement] : fusionResult->replacements) {
492 return use.
get().getDefiningOp() != producer;
// NOTE(review): fragment (original lines 571-578; interiors elided). The
// reshape-by-expansion fusability check: requires pure tensor semantics and
// all indexing maps to be projected permutations. TODO confirm the elided
// conditions against the full source.
571 linalgOp.getIteratorTypesArray();
572 AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
573 return linalgOp.hasPureTensorSemantics() &&
574 llvm::all_of(linalgOp.getIndexingMaps().getValue(),
576 return cast<AffineMapAttr>(attr)
578 .isProjectedPermutation();
// NOTE(review): fragment (original lines 586-592; interiors elided). Helper
// class holding the dimension reassociation / expanded-shape bookkeeping used
// by reshape-fusion-by-expansion; populated via `compute`.
586 class ExpansionInfo {
592 LogicalResult compute(LinalgOp linalgOp,
OpOperand *fusableOpOperand,
596 unsigned getOrigOpNumDims()
const {
return reassociation.size(); }
597 unsigned getExpandedOpNumDims()
const {
return expandedOpNumDims; }
// NOTE(review): fragment (original lines 599-661; interiors elided).
// Accessors return the per-dimension reassociation / expanded shape.
599 return reassociation[i];
602 return expandedShapeMap[i];
615 unsigned expandedOpNumDims;
// ExpansionInfo::compute — derives, for each original loop dim, the list of
// expanded dims and their shapes from the reshape's reassociation maps; dims
// untouched by the reshape keep their original single extent.
619 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
624 if (reassociationMaps.empty())
626 AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
// Original loop extents, used for dims that are not expanded.
630 originalLoopExtent = llvm::map_to_vector(
631 linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
632 [](
Range r) { return r.size; });
634 reassociation.clear();
635 expandedShapeMap.clear();
639 expandedShapeMap.resize(fusedIndexMap.
getNumDims());
641 unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
642 AffineMap foldedDims = reassociationMaps[resultExpr.index()];
645 expandedShape.slice(foldedDims.
getDimPosition(0), numExpandedDims[pos]);
646 expandedShapeMap[pos].assign(shape.begin(), shape.end());
// Dims not covered by the reshape expand trivially to themselves.
649 for (
unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.
getNumDims()))
650 if (expandedShapeMap[i].empty())
651 expandedShapeMap[i] = {originalLoopExtent[i]};
// Build contiguous reassociation groups and count the expanded dims.
655 reassociation.reserve(fusedIndexMap.
getNumDims());
657 auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
658 reassociation.emplace_back(seq.begin(), seq.end());
659 sum += numFoldedDim.value();
661 expandedOpNumDims = sum;
// NOTE(review): fragments of three expansion helpers (original lines 669-722;
// interiors elided): (1) rewrite an indexing map by replacing each dim expr
// with its expanded dim exprs; (2) compute the expanded shape/type of an
// operand; (3) build the operand reassociation for the expand_shape ops.
669 const ExpansionInfo &expansionInfo) {
672 unsigned pos = cast<AffineDimExpr>(expr).getPosition();
674 llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
675 return builder.getAffineDimExpr(static_cast<unsigned>(v));
677 newExprs.append(expandedExprs.begin(), expandedExprs.end());
686 static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
688 const ExpansionInfo &expansionInfo) {
691 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
693 expansionInfo.getExpandedShapeOfDim(dim);
694 expandedShape.append(dimExpansion.begin(), dimExpansion.end());
697 std::tie(expandedStaticShape, std::ignore) =
700 originalType.getElementType())};
711 const ExpansionInfo &expansionInfo) {
713 unsigned numReshapeDims = 0;
715 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
716 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
718 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
719 reassociation.emplace_back(std::move(indices));
720 numReshapeDims += numExpandedDims;
722 return reassociation;
// NOTE(review): fragment (original lines 732-768; interiors elided). Rewrites
// linalg.index ops inside the expanded op's region: an index of a dim that
// was expanded into several dims is recomputed by linearizing the new per-dim
// indices (idx + acc * shape folding).
732 const ExpansionInfo &expansionInfo) {
734 for (IndexOp indexOp :
735 llvm::make_early_inc_range(fusedRegion.
front().
getOps<IndexOp>())) {
737 expansionInfo.getExpandedDims(indexOp.getDim());
738 assert(!expandedDims.empty() &&
"expected valid expansion info");
// Unexpanded dim mapping to itself: nothing to rewrite.
741 if (expandedDims.size() == 1 &&
742 expandedDims.front() == (int64_t)indexOp.getDim())
749 expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
751 expandedIndices.reserve(expandedDims.size() - 1);
753 expandedDims.drop_front(), std::back_inserter(expandedIndices),
754 [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
756 rewriter.
create<IndexOp>(loc, expandedDims.front()).getResult();
// Linearize: fold each (shape, index) pair into the accumulated index.
757 for (
auto [expandedShape, expandedIndex] :
758 llvm::zip(expandedDimsShape, expandedIndices)) {
763 rewriter, indexOp.getLoc(), idx + acc * shape,
768 rewriter.
replaceOp(indexOp, newIndexVal);
// NOTE(review): fragments (original lines 790-857; interiors elided) of the
// expanded-op builders: a TransposeOp special case that expands each permuted
// dim into its reassociation group; a GenericOp path that expands iterator
// types per reassociated dim and clones the original region; and a dispatcher
// with Case<> arms for TransposeOp and FillOp/CopyOp (plain clone).
790 TransposeOp transposeOp,
792 ExpansionInfo &expansionInfo) {
795 auto reassoc = expansionInfo.getExpandedDims(perm);
796 for (int64_t dim : reassoc) {
797 newPerm.push_back(dim);
800 return rewriter.
create<TransposeOp>(transposeOp.getLoc(), expandedInput,
// Generic path: every expanded dim inherits its source dim's iterator type.
811 expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
813 for (
auto [i, type] :
llvm::enumerate(linalgOp.getIteratorTypesArray()))
814 for (
auto j : expansionInfo.getExpandedDims(i))
815 iteratorTypes[
j] = type;
818 linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
819 expandedOpIndexingMaps, iteratorTypes);
822 Region &originalRegion = linalgOp->getRegion(0);
840 ExpansionInfo &expansionInfo) {
843 .Case<TransposeOp>([&](TransposeOp transposeOp) {
845 expandedOpOperands[0], outputs[0],
848 .Case<FillOp, CopyOp>([&](
Operation *op) {
849 return clone(rewriter, linalgOp, resultTypes,
850 llvm::to_vector(llvm::concat<Value>(
851 llvm::to_vector(expandedOpOperands),
852 llvm::to_vector(outputs))));
856 expandedOpOperands, outputs,
857 expansionInfo, expandedOpIndexingMaps);
// NOTE(review): fragment (original lines 864-992; interiors elided). Fuses a
// tensor.expand_shape/collapse_shape with a LinalgOp by expanding the op's
// iteration space: derives ExpansionInfo, expands all indexing maps, wraps
// non-fused operands/inits in expand_shape ops as needed, builds the expanded
// op, and collapses results whose type changed back to the original type.
864 static std::optional<SmallVector<Value>>
869 "preconditions for fuse operation failed");
// Two entry shapes: the reshape may be an expand (use its output shape) or a
// collapse (use the source's shape).
875 if (
auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
879 rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
882 expandedShape = expandingReshapeOp.getMixedOutputShape();
883 reassociationIndices = expandingReshapeOp.getReassociationMaps();
884 src = expandingReshapeOp.getSrc();
886 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
887 if (!collapsingReshapeOp)
891 rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
892 reassociationIndices = collapsingReshapeOp.getReassociationMaps();
893 src = collapsingReshapeOp.getSrc();
896 ExpansionInfo expansionInfo;
897 if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
898 reassociationIndices, expandedShape,
903 llvm::map_range(linalgOp.getIndexingMapsArray(), [&](
AffineMap m) {
904 return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
// Inputs: the fused operand is replaced by the reshape source; other ranked
// tensor inputs whose expanded type differs get an expand_shape.
912 expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
913 for (
OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
914 if (opOperand == fusableOpOperand) {
915 expandedOpOperands.push_back(src);
918 if (
auto opOperandType =
919 dyn_cast<RankedTensorType>(opOperand->get().getType())) {
920 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
922 RankedTensorType expandedOperandType;
923 std::tie(expandedOperandShape, expandedOperandType) =
925 if (expandedOperandType != opOperand->get().getType()) {
930 [&](
const Twine &msg) {
933 opOperandType.getShape(), expandedOperandType.getShape(),
937 expandedOpOperands.push_back(rewriter.
create<tensor::ExpandShapeOp>(
938 loc, expandedOperandType, opOperand->get(), reassociation,
939 expandedOperandShape));
943 expandedOpOperands.push_back(opOperand->get());
// Inits get the same expand-if-needed treatment.
947 for (
OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
948 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
949 auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
951 RankedTensorType expandedOutputType;
952 std::tie(expandedOutputShape, expandedOutputType) =
954 if (expandedOutputType != opOperand.get().getType()) {
958 [&](
const Twine &msg) {
961 opOperandType.getShape(), expandedOutputType.getShape(),
965 outputs.push_back(rewriter.
create<tensor::ExpandShapeOp>(
966 loc, expandedOutputType, opOperand.get(), reassociation,
967 expandedOutputShape));
969 outputs.push_back(opOperand.get());
976 outputs, expandedOpIndexingMaps, expansionInfo);
// Results with changed rank are collapsed back to their original type.
980 for (
OpResult opResult : linalgOp->getOpResults()) {
981 int64_t resultNumber = opResult.getResultNumber();
982 if (resultTypes[resultNumber] != opResult.getType()) {
985 linalgOp.getMatchingIndexingMap(
986 linalgOp.getDpsInitOperand(resultNumber)),
988 resultVals.push_back(rewriter.
create<tensor::CollapseShapeOp>(
989 linalgOp.getLoc(), opResult.getType(),
990 fusedOp->
getResult(resultNumber), reassociation));
992 resultVals.push_back(fusedOp->
getResult(resultNumber));
// NOTE(review): pattern fragment (original lines 1004-1031; interiors
// elided). Folds a producing tensor.collapse_shape into a LinalgOp consumer
// by expansion, gated by the control function.
1004 class FoldWithProducerReshapeOpByExpansion
1007 FoldWithProducerReshapeOpByExpansion(
MLIRContext *context,
1011 controlFoldingReshapes(std::move(foldReshapes)) {}
1013 LogicalResult matchAndRewrite(LinalgOp linalgOp,
1015 for (
OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
1016 tensor::CollapseShapeOp reshapeOp =
1017 opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
1024 (!controlFoldingReshapes(opOperand)))
1027 std::optional<SmallVector<Value>> replacementValues =
1029 if (!replacementValues)
1031 rewriter.
replaceOp(linalgOp, *replacementValues);
// NOTE(review): pattern fragment (original lines 1041-1095; interiors
// elided). Swaps a tensor.pad with its producing tensor.collapse_shape: pads
// the expanded source instead and re-collapses. Bails out when a folded
// (multi-dim) reassociation group carries nonzero padding.
1041 class FoldPadWithProducerReshapeOpByExpansion
1044 FoldPadWithProducerReshapeOpByExpansion(
MLIRContext *context,
1048 controlFoldingReshapes(std::move(foldReshapes)) {}
1050 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1052 tensor::CollapseShapeOp reshapeOp =
1053 padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
1056 if (!reshapeOp->hasOneUse())
1059 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1061 "fusion blocked by control function");
1067 reshapeOp.getReassociationIndices();
// Padding on a folded group cannot be distributed across its dims.
1069 for (
auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
1070 if (reInd.size() != 1 && (l != 0 || h != 0))
1075 RankedTensorType expandedType = reshapeOp.getSrcType();
1076 RankedTensorType paddedType = padOp.getResultType();
1079 if (reInd.size() == 1) {
1080 expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
1082 for (
size_t i = 0; i < reInd.size(); ++i) {
1083 newLow.push_back(padOp.getMixedLowPad()[idx]);
1084 newHigh.push_back(padOp.getMixedHighPad()[idx]);
1089 RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
1090 auto newPadOp = rewriter.
create<tensor::PadOp>(
1091 loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1092 padOp.getConstantPaddingValue(), padOp.getNofold());
1095 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
// NOTE(review): pattern fragment (original lines 1106-1164; interiors
// elided). Folds a tensor.expand_shape whose source is a LinalgOp result into
// that producer by expansion. Also elides a collapse_shape that the expansion
// would otherwise re-introduce for the reshape's own replacement value.
1106 struct FoldReshapeWithGenericOpByExpansion
1109 FoldReshapeWithGenericOpByExpansion(
MLIRContext *context,
1113 controlFoldingReshapes(std::move(foldReshapes)) {}
1115 LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
1118 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1119 if (!producerResult) {
1121 "source not produced by an operation");
1124 auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
1127 "producer not a generic op");
1132 producer.getDpsInitOperand(producerResult.getResultNumber()))) {
1134 reshapeOp,
"failed preconditions of fusion with producer generic op");
1137 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1139 "fusion blocked by control function");
1142 std::optional<SmallVector<Value>> replacementValues =
1144 producer, reshapeOp,
1145 producer.getDpsInitOperand(producerResult.getResultNumber()),
1147 if (!replacementValues) {
1149 "fusion by expansion failed");
// If the replacement is itself a collapse of the expanded result, skip it —
// the reshape's users can take the uncollapsed value directly.
1156 Value reshapeReplacement =
1157 (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
1158 .getResultNumber()];
1159 if (
auto collapseOp =
1160 reshapeReplacement.
getDefiningOp<tensor::CollapseShapeOp>()) {
1161 reshapeReplacement = collapseOp.getSrc();
1163 rewriter.
replaceOp(reshapeOp, reshapeReplacement);
1164 rewriter.
replaceOp(producer, *replacementValues);
// NOTE(review): fragments (original lines 1186-1243; interiors elided) of
// the collapse-side helpers: mapping a range reassociation to domain dims,
// and checking that an indexing map keeps a dim sequence contiguous and in
// order (required for folding those dims into one collapsed loop).
1186 "expected projected permutation");
1189 llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
1190 return cast<AffineDimExpr>(indexingMap.
getResults()[pos]).getPosition();
1194 return domainReassociation;
// isDimSequencePreserved: scan the map's results for the sequence start,
// then require the whole sequence to appear contiguously and in order.
1202 assert(!dimSequence.empty() &&
1203 "expected non-empty list for dimension sequence");
1205 "expected indexing map to be projected permutation");
1207 llvm::SmallDenseSet<unsigned, 4> sequenceElements;
1208 sequenceElements.insert_range(dimSequence);
1210 unsigned dimSequenceStart = dimSequence[0];
1212 unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
1214 if (dimInMapStart == dimSequenceStart) {
1215 if (expr.index() + dimSequence.size() > indexingMap.
getNumResults())
1218 for (
const auto &dimInSequence :
enumerate(dimSequence)) {
1220 cast<AffineDimExpr>(
1221 indexingMap.
getResult(expr.index() + dimInSequence.index()))
1223 if (dimInMap != dimInSequence.value())
// A sequence member appearing outside the contiguous run breaks the fold.
1234 if (sequenceElements.count(dimInMapStart))
1243 return llvm::all_of(maps, [&](
AffineMap map) {
// NOTE(review): fragment (original lines 1300-1389; interiors elided).
// Derives which iteration-space dim groups can be collapsed for a reshape on
// `fusableOperand`: requires pure tensor semantics, projected-permutation
// maps, uniform iterator type per group, contiguous reduction groups, and
// each group being sequence-preserved in every indexing map.
1300 if (!genericOp.hasPureTensorSemantics())
1303 if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](
AffineMap map) {
1304 return map.isProjectedPermutation();
1311 genericOp.getReductionDims(reductionDims);
1313 llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1314 AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1315 auto iteratorTypes = genericOp.getIteratorTypesArray();
1318 assert(!foldedRangeDims.empty() &&
"unexpected empty reassociation");
// Singleton groups collapse trivially; skip them.
1321 if (foldedRangeDims.size() == 1)
1329 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1330 return processedIterationDims.count(dim);
// All dims in a group must share one iterator type.
1335 utils::IteratorType startIteratorType =
1336 iteratorTypes[foldedIterationSpaceDims[0]];
1340 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1341 return iteratorTypes[dim] != startIteratorType;
// Reduction groups must be contiguous within the op's reduction dims.
1350 bool isContiguous =
false;
1353 if (startDim.value() != foldedIterationSpaceDims[0])
1357 if (startDim.index() + foldedIterationSpaceDims.size() >
1358 reductionDims.size())
1361 isContiguous =
true;
1362 for (
const auto &foldedDim :
1364 if (reductionDims[foldedDim.index() + startDim.index()] !=
1365 foldedDim.value()) {
1366 isContiguous =
false;
// Every indexing map must keep the group contiguous and ordered.
1377 if (llvm::any_of(genericOp.getIndexingMapsArray(),
1379 return !isDimSequencePreserved(indexingMap,
1380 foldedIterationSpaceDims);
1384 processedIterationDims.insert_range(foldedIterationSpaceDims);
1385 iterationSpaceReassociation.emplace_back(
1386 std::move(foldedIterationSpaceDims));
1389 return iterationSpaceReassociation;
// NOTE(review): fragment (original lines 1394-1471; interiors elided).
// CollapsingInfo validates a set of folded iteration-dim groups (in-range,
// no duplicates), completes it with singleton groups for untouched dims,
// sorts groups by leading dim, and builds the two-way dim mappings used by
// the collapse rewrites.
1394 class CollapsingInfo {
1396 LogicalResult initialize(
unsigned origNumLoops,
1398 llvm::SmallDenseSet<int64_t, 4> processedDims;
1401 if (foldedIterationDim.empty())
// Reject out-of-range or repeated dims across groups.
1405 for (
auto dim : foldedIterationDim) {
1406 if (dim >= origNumLoops)
1408 if (processedDims.count(dim))
1410 processedDims.insert(dim);
1412 collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1413 foldedIterationDim.end());
1415 if (processedDims.size() > origNumLoops)
// Untouched dims become their own singleton groups.
1420 for (
auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1421 if (processedDims.count(dim))
1426 llvm::sort(collapsedOpToOrigOpIterationDim,
1428 return lhs[0] < rhs[0];
// origOp dim -> (collapsed group index, position within group).
1430 origOpToCollapsedOpIterationDim.resize(origNumLoops);
1431 for (
const auto &foldedDims :
1433 for (
const auto &dim :
enumerate(foldedDims.value()))
1434 origOpToCollapsedOpIterationDim[dim.value()] =
1435 std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1442 return collapsedOpToOrigOpIterationDim;
1466 return origOpToCollapsedOpIterationDim;
1470 unsigned getCollapsedOpIterationRank()
const {
1471 return collapsedOpToOrigOpIterationDim.size();
// NOTE(review): fragments (original lines 1489-1554; interiors elided) of
// three collapse helpers: collapsed iterator types (each group inherits its
// first dim's type), collapsed indexing maps (only the group-leading dim of
// each expr survives), and per-operand reassociation for collapse_shape ops.
1489 const CollapsingInfo &collapsingInfo) {
1492 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1493 assert(!foldedIterDims.empty() &&
1494 "reassociation indices expected to have non-empty sets");
1498 collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1500 return collapsedIteratorTypes;
// getCollapsedOpIndexingMap: drop exprs that are not the leading dim of
// their group; remaining dims are renumbered into the collapsed rank.
1507 const CollapsingInfo &collapsingInfo) {
1510 "expected indexing map to be projected permutation");
1512 auto origOpToCollapsedOpMapping =
1513 collapsingInfo.getOrigOpToCollapsedOpMapping();
1515 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1517 if (origOpToCollapsedOpMapping[dim].second != 0)
1521 resultExprs.push_back(
1524 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1525 resultExprs, context);
// getOperandReassociation: group the operand's result positions so each
// collapsed loop group maps to one contiguous run of operand dims.
1532 const CollapsingInfo &collapsingInfo) {
1533 unsigned counter = 0;
1535 auto origOpToCollapsedOpMapping =
1536 collapsingInfo.getOrigOpToCollapsedOpMapping();
1537 auto collapsedOpToOrigOpMapping =
1538 collapsingInfo.getCollapsedOpToOrigOpMapping();
1541 cast<AffineDimExpr>(indexingMap.
getResult(counter)).getPosition();
1545 unsigned numFoldedDims =
1546 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1548 if (origOpToCollapsedOpMapping[dim].second == 0) {
1549 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1550 operandReassociation.emplace_back(range.begin(), range.end());
1552 counter += numFoldedDims;
1554 return operandReassociation;
// NOTE(review): fragments (original lines 1560-1621; interiors elided).
// getCollapsedOpOperand: collapse an operand via memref/tensor collapse_shape
// unless its reassociation is the identity (same result count). The second
// helper rebuilds linalg.index values of the original dims from each
// collapsed loop index using remsi/divsi peeling.
1560 const CollapsingInfo &collapsingInfo,
1562 AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
// Identity reassociation: the operand shape is unchanged, no reshape needed.
1570 if (operandReassociation.size() == indexingMap.
getNumResults())
1574 if (isa<MemRefType>(operand.
getType())) {
1576 .
create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
1580 .
create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
// generateCollapsedIndexingRegion — delinearize each collapsed index back
// into the original per-dim indices.
1587 Location loc,
Block *block,
const CollapsingInfo &collapsingInfo,
1593 auto indexOps = llvm::to_vector(block->
getOps<linalg::IndexOp>());
1603 for (
auto foldedDims :
1604 enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1607 rewriter.
create<linalg::IndexOp>(loc, foldedDims.index());
// Peel dims from the back: index = rem, then divide out the extent.
1608 for (
auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1611 indexReplacementVals[dim] =
1612 rewriter.
createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
1614 rewriter.
createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
1616 indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1619 for (
auto indexOp : indexOps) {
1620 auto dim = indexOp.getDim();
1621 rewriter.
replaceOp(indexOp, indexReplacementVals[dim]);
// NOTE(review): fragments (original lines 1626-1696; interiors elided) of the
// collapsed-op builders: collapse all inputs/inits (result types only kept
// for tensor semantics), a templated clone-based path, and the GenericOp path
// that collapses indexing maps and iterator types and merges the original
// region into the new op.
1626 const CollapsingInfo &collapsingInfo,
1633 llvm::map_to_vector(op.getDpsInputOperands(), [&](
OpOperand *opOperand) {
1634 return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1639 resultTypes.reserve(op.getNumDpsInits());
1640 outputOperands.reserve(op.getNumDpsInits());
1641 for (
OpOperand &output : op.getDpsInitsMutable()) {
1644 outputOperands.push_back(newOutput);
// Buffer-semantics ops produce no results.
1647 if (!op.hasPureBufferSemantics())
1648 resultTypes.push_back(newOutput.
getType());
1653 template <
typename OpTy>
1655 const CollapsingInfo &collapsingInfo) {
1663 const CollapsingInfo &collapsingInfo) {
1667 outputOperands, resultTypes);
1670 rewriter, origOp, resultTypes,
1671 llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
// GenericOp path: collapsed maps + iterator types, then region merge.
1678 const CollapsingInfo &collapsingInfo) {
1682 outputOperands, resultTypes);
1684 llvm::map_range(origOp.getIndexingMapsArray(), [&](
AffineMap map) {
1685 return getCollapsedOpIndexingMap(map, collapsingInfo);
1689 origOp.getIteratorTypesArray(), collapsingInfo));
1691 GenericOp collapsedOp = rewriter.
create<linalg::GenericOp>(
1692 origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
1694 Block *origOpBlock = &origOp->getRegion(0).
front();
1695 Block *collapsedOpBlock = &collapsedOp->getRegion(0).
front();
1696 rewriter.
mergeBlocks(origOpBlock, collapsedOpBlock,
// NOTE(review): fragment (original lines 1703-1804; interiors elided). The
// driver for dimension collapsing: validates the request (multi-dim groups,
// collapsible memrefs under buffer semantics, legal CollapsingInfo,
// zero-offset/unit-stride loop ranges), builds the collapsed op, regenerates
// index ops if needed, and expand_shape's results back to original ranks.
1703 if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1715 if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1717 return foldedDims.size() <= 1;
// Buffer semantics: every memref operand must be provably collapsible.
1721 bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1722 if (hasPureBufferSemantics &&
1723 !llvm::all_of(op->getOperands(), [&](
Value operand) ->
bool {
1724 MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
1725 if (!memRefToCollapse)
1728 return memref::CollapseShapeOp::isGuaranteedCollapsible(
1729 memRefToCollapse, foldedIterationDims);
1732 "memref is not guaranteed collapsible");
1734 CollapsingInfo collapsingInfo;
1736 collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1738 op,
"illegal to collapse specified dimensions");
// Collapsing assumes normalized loops (offset 0, stride 1).
1743 auto opFoldIsConstantValue = [](
OpFoldResult ofr, int64_t value) {
1744 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1745 return cast<IntegerAttr>(attr).getInt() == value;
1748 actual.getSExtValue() == value;
1750 if (!llvm::all_of(loopRanges, [&](
Range range) {
1751 return opFoldIsConstantValue(range.
offset, 0) &&
1752 opFoldIsConstantValue(range.
stride, 1);
1755 op,
"expected all loop ranges to have zero start and unit stride");
1762 llvm::map_to_vector(loopRanges, [](
Range range) {
return range.
size; });
1764 if (collapsedOp.hasIndexSemantics()) {
1769 collapsingInfo, loopBound, rewriter);
// Expand each rank-changed result back to the original shaped type.
1775 for (
const auto &originalResult :
llvm::enumerate(op->getResults())) {
1776 Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1777 auto originalResultType =
1778 cast<ShapedType>(originalResult.value().getType());
1779 auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.
getType());
1780 if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1782 op.getIndexingMapMatchingResult(originalResult.value());
1787 "Expected indexing map to be a projected permutation for collapsing");
1791 if (isa<MemRefType>(collapsedOpResult.
getType())) {
1793 originalResultType.getShape(), originalResultType.getElementType());
1794 result = rewriter.
create<memref::ExpandShapeOp>(
1795 loc, expandShapeResultType, collapsedOpResult, reassociation,
1798 result = rewriter.
create<tensor::ExpandShapeOp>(
1799 loc, originalResultType, collapsedOpResult, reassociation,
1802 results.push_back(result);
1804 results.push_back(collapsedOpResult);
// NOTE(review): pattern fragment (original lines 1814-1847; interiors
// elided). Folds a producing tensor.expand_shape into a GenericOp consumer by
// collapsing the matching iteration dims, gated by the control function.
1814 class FoldWithProducerReshapeOpByCollapsing
1818 FoldWithProducerReshapeOpByCollapsing(
MLIRContext *context,
1822 controlFoldingReshapes(std::move(foldReshapes)) {}
1824 LogicalResult matchAndRewrite(GenericOp genericOp,
1826 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
1827 tensor::ExpandShapeOp reshapeOp =
1828 opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1834 reshapeOp.getReassociationIndices());
1835 if (collapsableIterationDims.empty() ||
1836 !controlFoldingReshapes(&opOperand)) {
1841 genericOp, collapsableIterationDims, rewriter);
1842 if (!collapseResult) {
1844 genericOp,
"failed to do the fusion by collapsing transformation");
1847 rewriter.
replaceOp(genericOp, collapseResult->results);
// NOTE(review): pattern fragment (original lines 1859-1910; interiors
// elided). Folds a tensor.collapse_shape whose source is a GenericOp result
// into the producer by collapsing the producer's iteration dims.
1859 struct FoldReshapeWithGenericOpByCollapsing
1862 FoldReshapeWithGenericOpByCollapsing(
MLIRContext *context,
1866 controlFoldingReshapes(std::move(foldReshapes)) {}
1868 LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
1872 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1873 if (!producerResult) {
1875 "source not produced by an operation");
1879 auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
1882 "producer not a generic op");
1888 producer.getDpsInitOperand(producerResult.getResultNumber()),
1889 reshapeOp.getReassociationIndices());
1890 if (collapsableIterationDims.empty()) {
1892 reshapeOp,
"failed preconditions of fusion with producer generic op");
1895 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1897 "fusion blocked by control function");
1903 std::optional<CollapseResult> collapseResult =
1905 if (!collapseResult) {
1907 producer,
"failed to do the fusion by collapsing transformation");
1910 rewriter.
replaceOp(producer, collapseResult->results);
// NOTE(review): pattern fragment (original lines 1918-1988; interiors
// elided). Swaps a tensor.pad with its producing tensor.expand_shape: pads
// the collapsed source and re-expands. Multi-dim reassociation groups must
// carry zero padding on all their dims.
1918 class FoldPadWithProducerReshapeOpByCollapsing
1921 FoldPadWithProducerReshapeOpByCollapsing(
MLIRContext *context,
1925 controlFoldingReshapes(std::move(foldReshapes)) {}
1927 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1929 tensor::ExpandShapeOp reshapeOp =
1930 padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
1933 if (!reshapeOp->hasOneUse())
1936 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1938 "fusion blocked by control function");
1944 reshapeOp.getReassociationIndices();
1946 for (
auto reInd : reassociations) {
1947 if (reInd.size() == 1)
1949 if (llvm::any_of(reInd, [&](int64_t ind) {
1950 return low[ind] != 0 || high[ind] != 0;
1957 RankedTensorType collapsedType = reshapeOp.getSrcType();
1958 RankedTensorType paddedType = padOp.getResultType();
1962 reshapeOp.getOutputShape(), rewriter));
1966 Location loc = reshapeOp->getLoc();
// Singleton groups: pad on the collapsed dim and track the padded size.
1970 if (reInd.size() == 1) {
1971 collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
1973 rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
1974 expandedPaddedSizes[reInd[0]] = paddedSize;
1976 newLow.push_back(l);
1977 newHigh.push_back(h);
1980 RankedTensorType collapsedPaddedType =
1981 paddedType.clone(collapsedPaddedShape);
1982 auto newPadOp = rewriter.
create<tensor::PadOp>(
1983 loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1984 padOp.getConstantPaddingValue(), padOp.getNofold());
1987 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
1988 expandedPaddedSizes);
1998 template <
typename LinalgType>
2005 controlCollapseDimension(std::move(collapseDimensions)) {}
2007 LogicalResult matchAndRewrite(LinalgType op,
2010 controlCollapseDimension(op);
2011 if (collapsableIterationDims.empty())
2016 collapsableIterationDims)) {
2018 op,
"specified dimensions cannot be collapsed");
2021 std::optional<CollapseResult> collapseResult =
2023 if (!collapseResult) {
2026 rewriter.
replaceOp(op, collapseResult->results);
2048 LogicalResult matchAndRewrite(GenericOp genericOp,
2050 if (!genericOp.hasPureTensorSemantics())
2052 for (
OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2053 Operation *def = opOperand->get().getDefiningOp();
2054 TypedAttr constantAttr;
2055 auto isScalarOrSplatConstantOp = [&constantAttr](
Operation *def) ->
bool {
2058 if (
matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
2060 splatAttr.
getType().getElementType().isIntOrFloat()) {
2066 IntegerAttr intAttr;
2067 if (
matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
2068 constantAttr = intAttr;
2073 FloatAttr floatAttr;
2074 if (
matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
2075 constantAttr = floatAttr;
2082 auto resultValue = dyn_cast<OpResult>(opOperand->get());
2083 if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
2092 fusedIndexMaps.reserve(genericOp->getNumOperands());
2093 fusedOperands.reserve(genericOp.getNumDpsInputs());
2094 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
2095 for (
OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
2096 if (inputOperand == opOperand)
2098 Value inputValue = inputOperand->get();
2099 fusedIndexMaps.push_back(
2100 genericOp.getMatchingIndexingMap(inputOperand));
2101 fusedOperands.push_back(inputValue);
2102 fusedLocs.push_back(inputValue.
getLoc());
2104 for (
OpOperand &outputOperand : genericOp.getDpsInitsMutable())
2105 fusedIndexMaps.push_back(
2106 genericOp.getMatchingIndexingMap(&outputOperand));
2112 genericOp,
"fused op loop bound computation failed");
2116 Value scalarConstant =
2117 rewriter.
create<arith::ConstantOp>(def->
getLoc(), constantAttr);
2120 auto fusedOp = rewriter.
create<GenericOp>(
2121 rewriter.
getFusedLoc(fusedLocs), genericOp->getResultTypes(),
2125 genericOp.getIteratorTypes(),
2131 Region ®ion = genericOp->getRegion(0);
2134 mapping.
map(entryBlock.
getArgument(opOperand->getOperandNumber()),
2136 Region &fusedRegion = fusedOp->getRegion(0);
2139 rewriter.
replaceOp(genericOp, fusedOp->getResults());
2160 LogicalResult matchAndRewrite(GenericOp op,
2163 bool modifiedOutput =
false;
2165 for (
OpOperand &opOperand : op.getDpsInitsMutable()) {
2166 if (!op.payloadUsesValueFromOperand(&opOperand)) {
2167 Value operandVal = opOperand.get();
2168 auto operandType = dyn_cast<RankedTensorType>(operandVal.
getType());
2177 auto definingOp = operandVal.
getDefiningOp<tensor::EmptyOp>();
2180 modifiedOutput =
true;
2183 Value emptyTensor = rewriter.
create<tensor::EmptyOp>(
2184 loc, mixedSizes, operandType.getElementType());
2185 op->
setOperand(opOperand.getOperandNumber(), emptyTensor);
2188 if (!modifiedOutput) {
2201 LogicalResult matchAndRewrite(GenericOp genericOp,
2203 if (!genericOp.hasPureTensorSemantics())
2205 bool fillFound =
false;
2206 Block &payload = genericOp.getRegion().
front();
2207 for (
OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2208 if (!genericOp.payloadUsesValueFromOperand(opOperand))
2210 FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2214 Value fillVal = fillOp.value();
2216 cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2217 Value convertedVal =
2221 payload.
getArgument(opOperand->getOperandNumber()), convertedVal);
2223 return success(fillFound);
2232 controlFoldingReshapes);
2233 patterns.add<FoldPadWithProducerReshapeOpByExpansion>(
patterns.getContext(),
2234 controlFoldingReshapes);
2236 controlFoldingReshapes);
2243 controlFoldingReshapes);
2244 patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2245 patterns.getContext(), controlFoldingReshapes);
2247 controlFoldingReshapes);
2253 auto *context =
patterns.getContext();
2254 patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2255 patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2256 RemoveOutsDependency>(context);
2264 patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2265 CollapseLinalgDimensions<linalg::CopyOp>>(
2266 patterns.getContext(), controlCollapseDimensions);
2281 struct LinalgElementwiseOpFusionPass
2282 :
public impl::LinalgElementwiseOpFusionPassBase<
2283 LinalgElementwiseOpFusionPass> {
2284 using impl::LinalgElementwiseOpFusionPassBase<
2285 LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2286 void runOnOperation()
override {
2293 Operation *producer = fusedOperand->get().getDefiningOp();
2294 return producer && producer->
hasOneUse();
2303 affine::AffineApplyOp::getCanonicalizationPatterns(
patterns, context);
2304 GenericOp::getCanonicalizationPatterns(
patterns, context);
2305 tensor::ExpandShapeOp::getCanonicalizationPatterns(
patterns, context);
2306 tensor::CollapseShapeOp::getCanonicalizationPatterns(
patterns, context);
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(GenericOp producer, GenericOp consumer, ArrayRef< OpOperand * > opOperandsToIgnore)
OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp, const CollapsingInfo &collapsingInfo)
Clone a LinalgOp to a collapsed version of same name.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(OpOperand *producerOpOperand, AffineMap producerResultIndexMap, AffineMap fusedConsumerArgIndexMap)
Append to fusedOpIndexingMapAttrs the indexing maps for the operands of the producer to use in the fu...
static SmallVector< ReassociationIndices > getOperandReassociation(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Return the reassociation indices to use to collapse the operand when the iteration space of a generic...
static Operation * createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp, Value expandedInput, Value output, ExpansionInfo &expansionInfo)
static std::tuple< SmallVector< OpFoldResult >, RankedTensorType > getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the shape and type of the operand/result to use in the expanded op given the type in the origi...
static SmallVector< utils::IteratorType > getCollapsedOpIteratorTypes(ArrayRef< utils::IteratorType > iteratorTypes, const CollapsingInfo &collapsingInfo)
Get the iterator types for the collapsed operation given the original iterator types and collapsed di...
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, Location loc, Region &fusedRegion, const ExpansionInfo &expansionInfo)
Update the body of an expanded linalg operation having index semantics.
static void generateCollapsedIndexingRegion(Location loc, Block *block, const CollapsingInfo &collapsingInfo, ArrayRef< OpFoldResult > loopRange, RewriterBase &rewriter)
Modify the linalg.index operations in the original generic op, to its value in the collapsed operatio...
static Operation * createExpandedGenericOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > &expandedOpOperands, ArrayRef< Value > outputs, ExpansionInfo &expansionInfo, ArrayRef< AffineMap > expandedOpIndexingMaps)
GenericOp cloneToCollapsedOp< GenericOp >(RewriterBase &rewriter, GenericOp origOp, const CollapsingInfo &collapsingInfo)
Collapse a GenericOp
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp, OpOperand *fusableOpOperand)
Conditions for folding a structured linalg operation with a reshape op by expanding the iteration spa...
void collapseOperandsAndResults(LinalgOp op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter, SmallVectorImpl< Value > &inputOperands, SmallVectorImpl< Value > &outputOperands, SmallVectorImpl< Type > &resultTypes)
static Operation * createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > expandedOpOperands, ArrayRef< Value > outputs, ArrayRef< AffineMap > expandedOpIndexingMaps, ExpansionInfo &expansionInfo)
static ReassociationIndices getDomainReassociation(AffineMap indexingMap, ReassociationIndicesRef rangeReassociation)
For a given list of indices in the range of the indexingMap that are folded, return the indices of th...
static SmallVector< ReassociationIndices > getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, ArrayRef< ReassociationIndices > reassociation)
static std::optional< SmallVector< Value > > fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, OpOperand *fusableOpOperand, PatternRewriter &rewriter)
Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op and a generic op as expl...
static void generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand, unsigned nloops, llvm::SmallDenseSet< int > &preservedProducerResults)
Generate the region of the fused tensor operation.
static SmallVector< ReassociationIndices > getReassociationForExpansion(AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Returns the reassociation maps to use in the tensor.expand_shape operation to convert the operands of...
static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Compute the indexing map in the collapsed op that corresponds to the given indexingMap of the origina...
LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter)
LinalgOp cloneToCollapsedOp< LinalgOp >(RewriterBase &rewriter, LinalgOp origOp, const CollapsingInfo &collapsingInfo)
Collapse any LinalgOp that does not require any specialization such as indexing_maps,...
static AffineMap getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the indexing map to use in the expanded op for a given the indexingMap of the original operati...
static Value getCollapsedOpOperand(Location loc, LinalgOp op, OpOperand *opOperand, const CollapsingInfo &collapsingInfo, OpBuilder &builder)
Get the new value to use for a given OpOperand in the collapsed operation.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteConfig & setUseTopDownTraversal(bool use=true)
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool areDimSequencesPreserved(ArrayRef< AffineMap > maps, ArrayRef< ReassociationIndices > dimSequences)
Return true if all sequences of dimensions specified in dimSequences are contiguous in all the ranges...
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence)
Return true if a given sequence of dimensions are contiguous in the range of the specified indexing m...
void populateFoldReshapeOpsByCollapsingPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding tensor.expand_shape operation with its producer generic operation by co...
FailureOr< ElementwiseOpFusionResult > fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand)
llvm::SmallDenseSet< int > getPreservedProducerResults(GenericOp producer, GenericOp consumer, OpOperand *fusedOperand)
Returns a set of indices of the producer's results which would be preserved after the fusion.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
void populateCollapseDimensions(RewritePatternSet &patterns, const GetCollapsableDimensionsFn &controlCollapseDimensions)
Pattern to collapse dimensions in a linalg.generic op.
bool areElementwiseOpsFusable(OpOperand *fusedOperand)
Return true if two linalg.generic operations with producer/consumer relationship through fusedOperand...
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
std::function< SmallVector< ReassociationIndices >(linalg::LinalgOp)> GetCollapsableDimensionsFn
Function type to control generic op dimension collapsing.
void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding (collapsing) tensor_reshape operation with its producer (consumer) gene...
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void populateElementwiseOpsFusionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion)
Patterns for fusing linalg operation on tensors.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify that shapes of the reshaped types using following rule: if a dimension in the collapsed type i...
ArrayRef< int64_t > ReassociationIndicesRef
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, Operation *insertionPoint, DominanceInfo &dominance)
Move definitions of values before an insertion point.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Fuse two linalg.generic operations that have a producer-consumer relationship captured through fusedO...
llvm::DenseMap< Value, Value > replacements
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.