33 #define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
34 #include "mlir/Dialect/Linalg/Passes.h.inc"
60 assert(invProducerResultIndexMap &&
61 "expected producer result indexing map to be invertible");
63 LinalgOp producer = cast<LinalgOp>(producerOpOperand->
getOwner());
65 AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);
73 return t1.
compose(fusedConsumerArgIndexMap);
80 GenericOp producer, GenericOp consumer,
85 for (
auto &op : ops) {
86 for (
auto &opOperand : op->getOpOperands()) {
87 if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
90 indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
93 if (indexingMaps.empty()) {
96 return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
104 indexingMaps, producer.getContext())) !=
AffineMap();
113 GenericOp producer, GenericOp consumer,
OpOperand *fusedOperand) {
114 llvm::SmallDenseSet<int> preservedProducerResults;
118 opOperandsToIgnore.emplace_back(fusedOperand);
120 for (
const auto &producerResult :
llvm::enumerate(producer->getResults())) {
121 auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
122 opOperandsToIgnore.emplace_back(outputOperand);
123 if (producer.payloadUsesValueFromOperand(outputOperand) ||
125 opOperandsToIgnore) ||
126 llvm::any_of(producerResult.value().getUsers(), [&](
Operation *user) {
127 return user != consumer.getOperation();
129 preservedProducerResults.insert(producerResult.index());
132 (void)opOperandsToIgnore.pop_back_val();
135 return preservedProducerResults;
144 auto consumer = dyn_cast<GenericOp>(fusedOperand->
getOwner());
147 if (!producer || !consumer)
153 if (!producer.hasPureTensorSemantics() ||
154 !isa<RankedTensorType>(fusedOperand->
get().
getType()))
159 if (producer.getNumParallelLoops() != producer.getNumLoops())
164 if (!consumer.isDpsInput(fusedOperand))
169 AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
170 if (consumerIndexMap.
getNumResults() != producer.getNumLoops())
176 producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
184 if ((consumer.getNumReductionLoops())) {
185 BitVector coveredDims(consumer.getNumLoops(),
false);
187 auto addToCoveredDims = [&](
AffineMap map) {
188 for (
auto result : map.getResults())
189 if (
auto dimExpr = dyn_cast<AffineDimExpr>(result))
190 coveredDims[dimExpr.getPosition()] =
true;
194 llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
195 Value operand = std::get<0>(pair);
196 if (operand == fusedOperand->
get())
198 AffineMap operandMap = std::get<1>(pair);
199 addToCoveredDims(operandMap);
202 for (
OpOperand *operand : producer.getDpsInputOperands()) {
205 operand, producerResultIndexMap, consumerIndexMap);
206 addToCoveredDims(newIndexingMap);
208 if (!coveredDims.all())
220 unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
222 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
224 Block &producerBlock = producer->getRegion(0).
front();
225 Block &consumerBlock = consumer->getRegion(0).
front();
232 if (producer.hasIndexSemantics()) {
234 unsigned numFusedOpLoops =
235 std::max(producer.getNumLoops(), consumer.getNumLoops());
237 fusedIndices.reserve(numFusedOpLoops);
238 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
239 std::back_inserter(fusedIndices), [&](uint64_t dim) {
240 return rewriter.
create<IndexOp>(producer.getLoc(), dim);
242 for (IndexOp indexOp :
243 llvm::make_early_inc_range(producerBlock.
getOps<IndexOp>())) {
244 Value newIndex = rewriter.
create<affine::AffineApplyOp>(
246 consumerToProducerLoopsMap.
getSubMap(indexOp.getDim()), fusedIndices);
247 mapper.
map(indexOp.getResult(), newIndex);
251 assert(consumer.isDpsInput(fusedOperand) &&
252 "expected producer of input operand");
256 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
263 producerBlock.
getArguments().take_front(producer.getNumDpsInputs()))
264 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
269 .take_front(consumer.getNumDpsInputs())
271 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
275 producerBlock.
getArguments().take_back(producer.getNumDpsInits()))) {
276 if (!preservedProducerResults.count(bbArg.index()))
278 mapper.
map(bbArg.value(), fusedBlock->
addArgument(bbArg.value().getType(),
279 bbArg.value().getLoc()));
284 consumerBlock.
getArguments().take_back(consumer.getNumDpsInits()))
285 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
290 if (!isa<IndexOp>(op))
291 rewriter.
clone(op, mapper);
295 auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.
getTerminator());
296 unsigned producerResultNumber =
297 cast<OpResult>(fusedOperand->
get()).getResultNumber();
299 mapper.
lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
303 if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
304 if (
auto bb = dyn_cast<BlockArgument>(replacement))
305 assert(bb.getOwner() != &producerBlock &&
306 "yielded block argument must have been mapped");
309 "yielded value must have been mapped");
315 rewriter.
clone(op, mapper);
319 auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.
getTerminator());
321 fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
322 consumerYieldOp.getNumOperands());
323 for (
const auto &producerYieldVal :
325 if (preservedProducerResults.count(producerYieldVal.index()))
326 fusedYieldValues.push_back(
329 for (
auto consumerYieldVal : consumerYieldOp.getOperands())
331 rewriter.
create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
335 "Ill-formed GenericOp region");
338 FailureOr<mlir::linalg::ElementwiseOpFusionResult>
342 "expected elementwise operation pre-conditions to pass");
343 auto producerResult = cast<OpResult>(fusedOperand->
get());
344 auto producer = cast<GenericOp>(producerResult.getOwner());
345 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
347 assert(consumer.isDpsInput(fusedOperand) &&
348 "expected producer of input operand");
351 llvm::SmallDenseSet<int> preservedProducerResults =
359 fusedInputOperands.reserve(producer.getNumDpsInputs() +
360 consumer.getNumDpsInputs());
361 fusedOutputOperands.reserve(preservedProducerResults.size() +
362 consumer.getNumDpsInits());
363 fusedResultTypes.reserve(preservedProducerResults.size() +
364 consumer.getNumDpsInits());
365 fusedIndexMaps.reserve(producer->getNumOperands() +
366 consumer->getNumOperands());
369 auto consumerInputs = consumer.getDpsInputOperands();
370 auto *it = llvm::find_if(consumerInputs, [&](
OpOperand *operand) {
371 return operand == fusedOperand;
373 assert(it != consumerInputs.end() &&
"expected to find the consumer operand");
374 for (
OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
375 fusedInputOperands.push_back(opOperand->get());
376 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
380 producer.getIndexingMapMatchingResult(producerResult);
381 for (
OpOperand *opOperand : producer.getDpsInputOperands()) {
382 fusedInputOperands.push_back(opOperand->get());
385 opOperand, producerResultIndexMap,
386 consumer.getMatchingIndexingMap(fusedOperand));
387 fusedIndexMaps.push_back(map);
392 llvm::make_range(std::next(it), consumerInputs.end())) {
393 fusedInputOperands.push_back(opOperand->get());
394 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
398 for (
const auto &opOperand :
llvm::enumerate(producer.getDpsInitsMutable())) {
399 if (!preservedProducerResults.count(opOperand.index()))
402 fusedOutputOperands.push_back(opOperand.value().get());
404 &opOperand.value(), producerResultIndexMap,
405 consumer.getMatchingIndexingMap(fusedOperand));
406 fusedIndexMaps.push_back(map);
407 fusedResultTypes.push_back(opOperand.value().get().getType());
411 for (
OpOperand &opOperand : consumer.getDpsInitsMutable()) {
412 fusedOutputOperands.push_back(opOperand.get());
413 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
414 Type resultType = opOperand.get().getType();
415 if (!isa<MemRefType>(resultType))
416 fusedResultTypes.push_back(resultType);
420 auto fusedOp = rewriter.
create<GenericOp>(
421 consumer.getLoc(), fusedResultTypes, fusedInputOperands,
423 consumer.getIteratorTypes(),
426 if (!fusedOp.getShapesToLoopsMap()) {
432 fusedOp,
"fused op failed loop bound computation check");
438 consumer.getMatchingIndexingMap(fusedOperand);
442 assert(invProducerResultIndexMap &&
443 "expected producer result indexig map to be invertible");
446 invProducerResultIndexMap.
compose(consumerResultIndexMap);
449 rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
450 consumer.getNumLoops(), preservedProducerResults);
454 for (
auto [index, producerResult] :
llvm::enumerate(producer->getResults()))
455 if (preservedProducerResults.count(index))
456 result.
replacements[producerResult] = fusedOp->getResult(resultNum++);
457 for (
auto consumerResult : consumer->getResults())
458 result.
replacements[consumerResult] = fusedOp->getResult(resultNum++);
469 controlFn(std::move(fun)) {}
471 LogicalResult matchAndRewrite(GenericOp genericOp,
474 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
477 if (!controlFn(&opOperand))
480 Operation *producer = opOperand.get().getDefiningOp();
483 FailureOr<ElementwiseOpFusionResult> fusionResult =
485 if (failed(fusionResult))
489 for (
auto [origVal, replacement] : fusionResult->replacements) {
492 return use.
get().getDefiningOp() != producer;
571 linalgOp.getIteratorTypesArray();
572 AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
573 return linalgOp.hasPureTensorSemantics() &&
574 llvm::all_of(linalgOp.getIndexingMaps().getValue(),
576 return cast<AffineMapAttr>(attr)
578 .isProjectedPermutation();
586 class ExpansionInfo {
592 LogicalResult compute(LinalgOp linalgOp,
OpOperand *fusableOpOperand,
596 unsigned getOrigOpNumDims()
const {
return reassociation.size(); }
597 unsigned getExpandedOpNumDims()
const {
return expandedOpNumDims; }
599 return reassociation[i];
602 return expandedShapeMap[i];
615 unsigned expandedOpNumDims;
619 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
624 if (reassociationMaps.empty())
626 AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
630 originalLoopExtent = llvm::map_to_vector(
631 linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
632 [](
Range r) { return r.size; });
634 reassociation.clear();
635 expandedShapeMap.clear();
639 expandedShapeMap.resize(fusedIndexMap.
getNumDims());
641 unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
642 AffineMap foldedDims = reassociationMaps[resultExpr.index()];
645 expandedShape.slice(foldedDims.
getDimPosition(0), numExpandedDims[pos]);
646 expandedShapeMap[pos].assign(shape.begin(), shape.end());
649 for (
unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.
getNumDims()))
650 if (expandedShapeMap[i].empty())
651 expandedShapeMap[i] = {originalLoopExtent[i]};
655 reassociation.reserve(fusedIndexMap.
getNumDims());
657 auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
658 reassociation.emplace_back(seq.begin(), seq.end());
659 sum += numFoldedDim.value();
661 expandedOpNumDims = sum;
669 const ExpansionInfo &expansionInfo) {
672 unsigned pos = cast<AffineDimExpr>(expr).getPosition();
674 llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
675 return builder.getAffineDimExpr(static_cast<unsigned>(v));
677 newExprs.append(expandedExprs.begin(), expandedExprs.end());
686 static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
688 const ExpansionInfo &expansionInfo) {
691 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
693 expansionInfo.getExpandedShapeOfDim(dim);
694 expandedShape.append(dimExpansion.begin(), dimExpansion.end());
697 std::tie(expandedStaticShape, std::ignore) =
700 originalType.getElementType())};
711 const ExpansionInfo &expansionInfo) {
713 unsigned numReshapeDims = 0;
715 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
716 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
718 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
719 reassociation.emplace_back(std::move(indices));
720 numReshapeDims += numExpandedDims;
722 return reassociation;
732 const ExpansionInfo &expansionInfo) {
734 for (IndexOp indexOp :
735 llvm::make_early_inc_range(fusedRegion.
front().
getOps<IndexOp>())) {
737 expansionInfo.getExpandedDims(indexOp.getDim());
738 assert(!expandedDims.empty() &&
"expected valid expansion info");
741 if (expandedDims.size() == 1 &&
742 expandedDims.front() == (int64_t)indexOp.getDim())
749 expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
751 expandedIndices.reserve(expandedDims.size() - 1);
753 expandedDims.drop_front(), std::back_inserter(expandedIndices),
754 [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
756 rewriter.
create<IndexOp>(loc, expandedDims.front()).getResult();
757 for (
auto [expandedShape, expandedIndex] :
758 llvm::zip(expandedDimsShape, expandedIndices)) {
763 rewriter, indexOp.getLoc(), idx + acc * shape,
768 rewriter.
replaceOp(indexOp, newIndexVal);
790 TransposeOp transposeOp,
792 ExpansionInfo &expansionInfo) {
795 auto reassoc = expansionInfo.getExpandedDims(perm);
796 for (int64_t dim : reassoc) {
797 newPerm.push_back(dim);
800 return rewriter.
create<TransposeOp>(transposeOp.getLoc(), expandedInput,
811 expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
813 for (
auto [i, type] :
llvm::enumerate(linalgOp.getIteratorTypesArray()))
814 for (
auto j : expansionInfo.getExpandedDims(i))
815 iteratorTypes[
j] = type;
818 linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
819 expandedOpIndexingMaps, iteratorTypes);
822 Region &originalRegion = linalgOp->getRegion(0);
840 ExpansionInfo &expansionInfo) {
843 .Case<TransposeOp>([&](TransposeOp transposeOp) {
845 expandedOpOperands[0], outputs[0],
848 .Case<FillOp, CopyOp>([&](
Operation *op) {
849 return clone(rewriter, linalgOp, resultTypes,
850 llvm::to_vector(llvm::concat<Value>(
851 llvm::to_vector(expandedOpOperands),
852 llvm::to_vector(outputs))));
856 expandedOpOperands, outputs,
857 expansionInfo, expandedOpIndexingMaps);
864 static std::optional<SmallVector<Value>>
869 "preconditions for fuse operation failed");
875 if (
auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
879 rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
882 expandedShape = expandingReshapeOp.getMixedOutputShape();
883 reassociationIndices = expandingReshapeOp.getReassociationMaps();
884 src = expandingReshapeOp.getSrc();
886 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
887 if (!collapsingReshapeOp)
891 rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
892 reassociationIndices = collapsingReshapeOp.getReassociationMaps();
893 src = collapsingReshapeOp.getSrc();
896 ExpansionInfo expansionInfo;
897 if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
898 reassociationIndices, expandedShape,
903 llvm::map_range(linalgOp.getIndexingMapsArray(), [&](
AffineMap m) {
904 return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
912 expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
913 for (
OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
914 if (opOperand == fusableOpOperand) {
915 expandedOpOperands.push_back(src);
918 if (
auto opOperandType =
919 dyn_cast<RankedTensorType>(opOperand->get().getType())) {
920 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
922 RankedTensorType expandedOperandType;
923 std::tie(expandedOperandShape, expandedOperandType) =
925 if (expandedOperandType != opOperand->get().getType()) {
930 [&](
const Twine &msg) {
933 opOperandType.getShape(), expandedOperandType.getShape(),
937 expandedOpOperands.push_back(rewriter.
create<tensor::ExpandShapeOp>(
938 loc, expandedOperandType, opOperand->get(), reassociation,
939 expandedOperandShape));
943 expandedOpOperands.push_back(opOperand->get());
947 for (
OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
948 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
949 auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
951 RankedTensorType expandedOutputType;
952 std::tie(expandedOutputShape, expandedOutputType) =
954 if (expandedOutputType != opOperand.get().getType()) {
958 [&](
const Twine &msg) {
961 opOperandType.getShape(), expandedOutputType.getShape(),
965 outputs.push_back(rewriter.
create<tensor::ExpandShapeOp>(
966 loc, expandedOutputType, opOperand.get(), reassociation,
967 expandedOutputShape));
969 outputs.push_back(opOperand.get());
976 outputs, expandedOpIndexingMaps, expansionInfo);
980 for (
OpResult opResult : linalgOp->getOpResults()) {
981 int64_t resultNumber = opResult.getResultNumber();
982 if (resultTypes[resultNumber] != opResult.getType()) {
985 linalgOp.getMatchingIndexingMap(
986 linalgOp.getDpsInitOperand(resultNumber)),
988 resultVals.push_back(rewriter.
create<tensor::CollapseShapeOp>(
989 linalgOp.getLoc(), opResult.getType(),
990 fusedOp->
getResult(resultNumber), reassociation));
992 resultVals.push_back(fusedOp->
getResult(resultNumber));
1004 class FoldWithProducerReshapeOpByExpansion
1007 FoldWithProducerReshapeOpByExpansion(
MLIRContext *context,
1011 controlFoldingReshapes(std::move(foldReshapes)) {}
1013 LogicalResult matchAndRewrite(LinalgOp linalgOp,
1015 for (
OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
1016 tensor::CollapseShapeOp reshapeOp =
1017 opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
1024 (!controlFoldingReshapes(opOperand)))
1027 std::optional<SmallVector<Value>> replacementValues =
1029 if (!replacementValues)
1031 rewriter.
replaceOp(linalgOp, *replacementValues);
1041 class FoldPadWithProducerReshapeOpByExpansion
1044 FoldPadWithProducerReshapeOpByExpansion(
MLIRContext *context,
1048 controlFoldingReshapes(std::move(foldReshapes)) {}
1050 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1052 tensor::CollapseShapeOp reshapeOp =
1053 padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
1056 if (!reshapeOp->hasOneUse())
1059 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1061 "fusion blocked by control function");
1067 reshapeOp.getReassociationIndices();
1069 for (
auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
1070 if (reInd.size() != 1 && (l != 0 || h != 0))
1075 RankedTensorType expandedType = reshapeOp.getSrcType();
1076 RankedTensorType paddedType = padOp.getResultType();
1079 if (reInd.size() == 1) {
1080 expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
1082 for (
size_t i = 0; i < reInd.size(); ++i) {
1083 newLow.push_back(padOp.getMixedLowPad()[idx]);
1084 newHigh.push_back(padOp.getMixedHighPad()[idx]);
1089 RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
1090 auto newPadOp = rewriter.
create<tensor::PadOp>(
1091 loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1092 padOp.getConstantPaddingValue(), padOp.getNofold());
1095 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
1106 struct FoldReshapeWithGenericOpByExpansion
1109 FoldReshapeWithGenericOpByExpansion(
MLIRContext *context,
1113 controlFoldingReshapes(std::move(foldReshapes)) {}
1115 LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
1118 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1119 if (!producerResult) {
1121 "source not produced by an operation");
1124 auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
1127 "producer not a generic op");
1132 producer.getDpsInitOperand(producerResult.getResultNumber()))) {
1134 reshapeOp,
"failed preconditions of fusion with producer generic op");
1137 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1139 "fusion blocked by control function");
1142 std::optional<SmallVector<Value>> replacementValues =
1144 producer, reshapeOp,
1145 producer.getDpsInitOperand(producerResult.getResultNumber()),
1147 if (!replacementValues) {
1149 "fusion by expansion failed");
1156 Value reshapeReplacement =
1157 (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
1158 .getResultNumber()];
1159 if (
auto collapseOp =
1160 reshapeReplacement.
getDefiningOp<tensor::CollapseShapeOp>()) {
1161 reshapeReplacement = collapseOp.getSrc();
1163 rewriter.
replaceOp(reshapeOp, reshapeReplacement);
1164 rewriter.
replaceOp(producer, *replacementValues);
1186 "expected projected permutation");
1189 llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
1190 return cast<AffineDimExpr>(indexingMap.
getResults()[pos]).getPosition();
1194 return domainReassociation;
1202 assert(!dimSequence.empty() &&
1203 "expected non-empty list for dimension sequence");
1205 "expected indexing map to be projected permutation");
1207 llvm::SmallDenseSet<unsigned, 4> sequenceElements;
1208 sequenceElements.insert(dimSequence.begin(), dimSequence.end());
1210 unsigned dimSequenceStart = dimSequence[0];
1212 unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
1214 if (dimInMapStart == dimSequenceStart) {
1215 if (expr.index() + dimSequence.size() > indexingMap.
getNumResults())
1218 for (
const auto &dimInSequence :
enumerate(dimSequence)) {
1220 cast<AffineDimExpr>(
1221 indexingMap.
getResult(expr.index() + dimInSequence.index()))
1223 if (dimInMap != dimInSequence.value())
1234 if (sequenceElements.count(dimInMapStart))
1243 return llvm::all_of(maps, [&](
AffineMap map) {
1300 if (!genericOp.hasPureTensorSemantics())
1303 if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](
AffineMap map) {
1304 return map.isProjectedPermutation();
1311 genericOp.getReductionDims(reductionDims);
1313 llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1314 AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1315 auto iteratorTypes = genericOp.getIteratorTypesArray();
1318 assert(!foldedRangeDims.empty() &&
"unexpected empty reassociation");
1321 if (foldedRangeDims.size() == 1)
1329 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1330 return processedIterationDims.count(dim);
1335 utils::IteratorType startIteratorType =
1336 iteratorTypes[foldedIterationSpaceDims[0]];
1340 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1341 return iteratorTypes[dim] != startIteratorType;
1350 bool isContiguous =
false;
1353 if (startDim.value() != foldedIterationSpaceDims[0])
1357 if (startDim.index() + foldedIterationSpaceDims.size() >
1358 reductionDims.size())
1361 isContiguous =
true;
1362 for (
const auto &foldedDim :
1364 if (reductionDims[foldedDim.index() + startDim.index()] !=
1365 foldedDim.value()) {
1366 isContiguous =
false;
1377 if (llvm::any_of(genericOp.getIndexingMapsArray(),
1379 return !isDimSequencePreserved(indexingMap,
1380 foldedIterationSpaceDims);
1384 processedIterationDims.insert(foldedIterationSpaceDims.begin(),
1385 foldedIterationSpaceDims.end());
1386 iterationSpaceReassociation.emplace_back(
1387 std::move(foldedIterationSpaceDims));
1390 return iterationSpaceReassociation;
1395 class CollapsingInfo {
1397 LogicalResult initialize(
unsigned origNumLoops,
1399 llvm::SmallDenseSet<int64_t, 4> processedDims;
1402 if (foldedIterationDim.empty())
1406 for (
auto dim : foldedIterationDim) {
1407 if (dim >= origNumLoops)
1409 if (processedDims.count(dim))
1411 processedDims.insert(dim);
1413 collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1414 foldedIterationDim.end());
1416 if (processedDims.size() > origNumLoops)
1421 for (
auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1422 if (processedDims.count(dim))
1427 llvm::sort(collapsedOpToOrigOpIterationDim,
1429 return lhs[0] < rhs[0];
1431 origOpToCollapsedOpIterationDim.resize(origNumLoops);
1432 for (
const auto &foldedDims :
1434 for (
const auto &dim :
enumerate(foldedDims.value()))
1435 origOpToCollapsedOpIterationDim[dim.value()] =
1436 std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1443 return collapsedOpToOrigOpIterationDim;
1467 return origOpToCollapsedOpIterationDim;
1471 unsigned getCollapsedOpIterationRank()
const {
1472 return collapsedOpToOrigOpIterationDim.size();
1490 const CollapsingInfo &collapsingInfo) {
1493 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1494 assert(!foldedIterDims.empty() &&
1495 "reassociation indices expected to have non-empty sets");
1499 collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1501 return collapsedIteratorTypes;
1508 const CollapsingInfo &collapsingInfo) {
1511 "expected indexing map to be projected permutation");
1513 auto origOpToCollapsedOpMapping =
1514 collapsingInfo.getOrigOpToCollapsedOpMapping();
1516 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1518 if (origOpToCollapsedOpMapping[dim].second != 0)
1522 resultExprs.push_back(
1525 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1526 resultExprs, context);
1533 const CollapsingInfo &collapsingInfo) {
1534 unsigned counter = 0;
1536 auto origOpToCollapsedOpMapping =
1537 collapsingInfo.getOrigOpToCollapsedOpMapping();
1538 auto collapsedOpToOrigOpMapping =
1539 collapsingInfo.getCollapsedOpToOrigOpMapping();
1542 cast<AffineDimExpr>(indexingMap.
getResult(counter)).getPosition();
1546 unsigned numFoldedDims =
1547 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1549 if (origOpToCollapsedOpMapping[dim].second == 0) {
1550 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1551 operandReassociation.emplace_back(range.begin(), range.end());
1553 counter += numFoldedDims;
1555 return operandReassociation;
1561 const CollapsingInfo &collapsingInfo,
1563 AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1571 if (operandReassociation.size() == indexingMap.
getNumResults())
1575 if (isa<MemRefType>(operand.
getType())) {
1577 .
create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
1581 .
create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
1588 Location loc,
Block *block,
const CollapsingInfo &collapsingInfo,
1594 auto indexOps = llvm::to_vector(block->
getOps<linalg::IndexOp>());
1604 for (
auto foldedDims :
1605 enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1608 rewriter.
create<linalg::IndexOp>(loc, foldedDims.index());
1609 for (
auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1612 indexReplacementVals[dim] =
1613 rewriter.
createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
1615 rewriter.
createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
1617 indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1620 for (
auto indexOp : indexOps) {
1621 auto dim = indexOp.getDim();
1622 rewriter.
replaceOp(indexOp, indexReplacementVals[dim]);
1627 const CollapsingInfo &collapsingInfo,
1634 llvm::map_to_vector(op.getDpsInputOperands(), [&](
OpOperand *opOperand) {
1635 return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1640 resultTypes.reserve(op.getNumDpsInits());
1641 outputOperands.reserve(op.getNumDpsInits());
1642 for (
OpOperand &output : op.getDpsInitsMutable()) {
1645 outputOperands.push_back(newOutput);
1648 if (!op.hasPureBufferSemantics())
1649 resultTypes.push_back(newOutput.
getType());
1654 template <
typename OpTy>
1656 const CollapsingInfo &collapsingInfo) {
1664 const CollapsingInfo &collapsingInfo) {
1668 outputOperands, resultTypes);
1671 rewriter, origOp, resultTypes,
1672 llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1679 const CollapsingInfo &collapsingInfo) {
1683 outputOperands, resultTypes);
1685 llvm::map_range(origOp.getIndexingMapsArray(), [&](
AffineMap map) {
1686 return getCollapsedOpIndexingMap(map, collapsingInfo);
1690 origOp.getIteratorTypesArray(), collapsingInfo));
1692 GenericOp collapsedOp = rewriter.
create<linalg::GenericOp>(
1693 origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
1695 Block *origOpBlock = &origOp->getRegion(0).
front();
1696 Block *collapsedOpBlock = &collapsedOp->getRegion(0).
front();
1697 rewriter.
mergeBlocks(origOpBlock, collapsedOpBlock,
1704 if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1716 if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1718 return foldedDims.size() <= 1;
1722 bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1723 if (hasPureBufferSemantics &&
1724 !llvm::all_of(op->getOperands(), [&](
Value operand) ->
bool {
1725 MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
1726 if (!memRefToCollapse)
1729 return memref::CollapseShapeOp::isGuaranteedCollapsible(
1730 memRefToCollapse, foldedIterationDims);
1733 "memref is not guaranteed collapsible");
1735 CollapsingInfo collapsingInfo;
1737 collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1739 op,
"illegal to collapse specified dimensions");
1744 auto opFoldIsConstantValue = [](
OpFoldResult ofr, int64_t value) {
1745 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1746 return cast<IntegerAttr>(attr).getInt() == value;
1749 actual.getSExtValue() == value;
1751 if (!llvm::all_of(loopRanges, [&](
Range range) {
1752 return opFoldIsConstantValue(range.
offset, 0) &&
1753 opFoldIsConstantValue(range.
stride, 1);
1756 op,
"expected all loop ranges to have zero start and unit stride");
1763 llvm::map_to_vector(loopRanges, [](
Range range) {
return range.
size; });
1765 if (collapsedOp.hasIndexSemantics()) {
1770 collapsingInfo, loopBound, rewriter);
1776 for (
const auto &originalResult :
llvm::enumerate(op->getResults())) {
1777 Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1778 auto originalResultType =
1779 cast<ShapedType>(originalResult.value().getType());
1780 auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.
getType());
1781 if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1783 op.getIndexingMapMatchingResult(originalResult.value());
1788 "Expected indexing map to be a projected permutation for collapsing");
1792 if (isa<MemRefType>(collapsedOpResult.
getType())) {
1794 originalResultType.getShape(), originalResultType.getElementType());
1795 result = rewriter.
create<memref::ExpandShapeOp>(
1796 loc, expandShapeResultType, collapsedOpResult, reassociation,
1799 result = rewriter.
create<tensor::ExpandShapeOp>(
1800 loc, originalResultType, collapsedOpResult, reassociation,
1803 results.push_back(result);
1805 results.push_back(collapsedOpResult);
1815 class FoldWithProducerReshapeOpByCollapsing
1819 FoldWithProducerReshapeOpByCollapsing(
MLIRContext *context,
1823 controlFoldingReshapes(std::move(foldReshapes)) {}
1825 LogicalResult matchAndRewrite(GenericOp genericOp,
1827 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
1828 tensor::ExpandShapeOp reshapeOp =
1829 opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1835 reshapeOp.getReassociationIndices());
1836 if (collapsableIterationDims.empty() ||
1837 !controlFoldingReshapes(&opOperand)) {
1842 genericOp, collapsableIterationDims, rewriter);
1843 if (!collapseResult) {
1845 genericOp,
"failed to do the fusion by collapsing transformation");
1848 rewriter.
replaceOp(genericOp, collapseResult->results);
1860 struct FoldReshapeWithGenericOpByCollapsing
1863 FoldReshapeWithGenericOpByCollapsing(
MLIRContext *context,
1867 controlFoldingReshapes(std::move(foldReshapes)) {}
1869 LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
1873 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1874 if (!producerResult) {
1876 "source not produced by an operation");
1880 auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
1883 "producer not a generic op");
1889 producer.getDpsInitOperand(producerResult.getResultNumber()),
1890 reshapeOp.getReassociationIndices());
1891 if (collapsableIterationDims.empty()) {
1893 reshapeOp,
"failed preconditions of fusion with producer generic op");
1896 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1898 "fusion blocked by control function");
1901 std::optional<CollapseResult> collapseResult =
1903 if (!collapseResult) {
1905 producer,
"failed to do the fusion by collapsing transformation");
1908 if (!collapseResult) {
1910 "fusion by expansion failed");
1917 Value reshapeReplacement =
1919 ->results)[cast<OpResult>(reshapeOp.getSrc()).getResultNumber()];
1921 reshapeReplacement.
getDefiningOp<tensor::ExpandShapeOp>()) {
1922 reshapeReplacement = expandOp.getSrc();
1924 rewriter.
replaceOp(reshapeOp, reshapeReplacement);
1925 rewriter.
replaceOp(producer, collapseResult->results);
1933 class FoldPadWithProducerReshapeOpByCollapsing
1936 FoldPadWithProducerReshapeOpByCollapsing(
MLIRContext *context,
1940 controlFoldingReshapes(std::move(foldReshapes)) {}
1942 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1944 tensor::ExpandShapeOp reshapeOp =
1945 padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
1948 if (!reshapeOp->hasOneUse())
1951 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1953 "fusion blocked by control function");
1959 reshapeOp.getReassociationIndices();
1961 for (
auto reInd : reassociations) {
1962 if (reInd.size() == 1)
1964 if (llvm::any_of(reInd, [&](int64_t ind) {
1965 return low[ind] != 0 || high[ind] != 0;
1972 RankedTensorType collapsedType = reshapeOp.getSrcType();
1973 RankedTensorType paddedType = padOp.getResultType();
1977 reshapeOp.getOutputShape(), rewriter));
1981 Location loc = reshapeOp->getLoc();
1985 if (reInd.size() == 1) {
1986 collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
1988 rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
1989 expandedPaddedSizes[reInd[0]] = paddedSize;
1991 newLow.push_back(l);
1992 newHigh.push_back(h);
1995 RankedTensorType collapsedPaddedType =
1996 paddedType.clone(collapsedPaddedShape);
1997 auto newPadOp = rewriter.
create<tensor::PadOp>(
1998 loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1999 padOp.getConstantPaddingValue(), padOp.getNofold());
2002 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
2003 expandedPaddedSizes);
2013 template <
typename LinalgType>
2020 controlCollapseDimension(std::move(collapseDimensions)) {}
2022 LogicalResult matchAndRewrite(LinalgType op,
2025 controlCollapseDimension(op);
2026 if (collapsableIterationDims.empty())
2031 collapsableIterationDims)) {
2033 op,
"specified dimensions cannot be collapsed");
2036 std::optional<CollapseResult> collapseResult =
2038 if (!collapseResult) {
2041 rewriter.
replaceOp(op, collapseResult->results);
2063 LogicalResult matchAndRewrite(GenericOp genericOp,
2065 if (!genericOp.hasPureTensorSemantics())
2067 for (
OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2068 Operation *def = opOperand->get().getDefiningOp();
2069 TypedAttr constantAttr;
2070 auto isScalarOrSplatConstantOp = [&constantAttr](
Operation *def) ->
bool {
2073 if (
matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
2075 splatAttr.
getType().getElementType().isIntOrFloat()) {
2081 IntegerAttr intAttr;
2082 if (
matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
2083 constantAttr = intAttr;
2088 FloatAttr floatAttr;
2089 if (
matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
2090 constantAttr = floatAttr;
2097 auto resultValue = dyn_cast<OpResult>(opOperand->get());
2098 if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
2107 fusedIndexMaps.reserve(genericOp->getNumOperands());
2108 fusedOperands.reserve(genericOp.getNumDpsInputs());
2109 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
2110 for (
OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
2111 if (inputOperand == opOperand)
2113 Value inputValue = inputOperand->get();
2114 fusedIndexMaps.push_back(
2115 genericOp.getMatchingIndexingMap(inputOperand));
2116 fusedOperands.push_back(inputValue);
2117 fusedLocs.push_back(inputValue.
getLoc());
2119 for (
OpOperand &outputOperand : genericOp.getDpsInitsMutable())
2120 fusedIndexMaps.push_back(
2121 genericOp.getMatchingIndexingMap(&outputOperand));
2127 genericOp,
"fused op loop bound computation failed");
2131 Value scalarConstant =
2132 rewriter.
create<arith::ConstantOp>(def->
getLoc(), constantAttr);
2135 auto fusedOp = rewriter.
create<GenericOp>(
2136 rewriter.
getFusedLoc(fusedLocs), genericOp->getResultTypes(),
2140 genericOp.getIteratorTypes(),
2146 Region ®ion = genericOp->getRegion(0);
2149 mapping.
map(entryBlock.
getArgument(opOperand->getOperandNumber()),
2151 Region &fusedRegion = fusedOp->getRegion(0);
2154 rewriter.
replaceOp(genericOp, fusedOp->getResults());
2175 LogicalResult matchAndRewrite(GenericOp op,
2178 bool modifiedOutput =
false;
2180 for (
OpOperand &opOperand : op.getDpsInitsMutable()) {
2181 if (!op.payloadUsesValueFromOperand(&opOperand)) {
2182 Value operandVal = opOperand.get();
2183 auto operandType = dyn_cast<RankedTensorType>(operandVal.
getType());
2192 auto definingOp = operandVal.
getDefiningOp<tensor::EmptyOp>();
2195 modifiedOutput =
true;
2198 Value emptyTensor = rewriter.
create<tensor::EmptyOp>(
2199 loc, mixedSizes, operandType.getElementType());
2200 op->
setOperand(opOperand.getOperandNumber(), emptyTensor);
2203 if (!modifiedOutput) {
2216 LogicalResult matchAndRewrite(GenericOp genericOp,
2218 if (!genericOp.hasPureTensorSemantics())
2220 bool fillFound =
false;
2221 Block &payload = genericOp.getRegion().
front();
2222 for (
OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2223 if (!genericOp.payloadUsesValueFromOperand(opOperand))
2225 FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2229 Value fillVal = fillOp.value();
2231 cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2232 Value convertedVal =
2236 payload.
getArgument(opOperand->getOperandNumber()), convertedVal);
2238 return success(fillFound);
2247 controlFoldingReshapes);
2248 patterns.add<FoldPadWithProducerReshapeOpByExpansion>(
patterns.getContext(),
2249 controlFoldingReshapes);
2251 controlFoldingReshapes);
2258 controlFoldingReshapes);
2259 patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2260 patterns.getContext(), controlFoldingReshapes);
2262 controlFoldingReshapes);
2268 auto *context =
patterns.getContext();
2269 patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2270 patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2271 RemoveOutsDependency>(context);
2279 patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2280 CollapseLinalgDimensions<linalg::CopyOp>>(
2281 patterns.getContext(), controlCollapseDimensions);
2296 struct LinalgElementwiseOpFusionPass
2297 :
public impl::LinalgElementwiseOpFusionPassBase<
2298 LinalgElementwiseOpFusionPass> {
2299 using impl::LinalgElementwiseOpFusionPassBase<
2300 LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2301 void runOnOperation()
override {
2308 Operation *producer = fusedOperand->get().getDefiningOp();
2309 return producer && producer->
hasOneUse();
2318 affine::AffineApplyOp::getCanonicalizationPatterns(
patterns, context);
2319 GenericOp::getCanonicalizationPatterns(
patterns, context);
2320 tensor::ExpandShapeOp::getCanonicalizationPatterns(
patterns, context);
2321 tensor::CollapseShapeOp::getCanonicalizationPatterns(
patterns, context);
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(GenericOp producer, GenericOp consumer, ArrayRef< OpOperand * > opOperandsToIgnore)
OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp, const CollapsingInfo &collapsingInfo)
Clone a LinalgOp to a collapsed version of the same name.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(OpOperand *producerOpOperand, AffineMap producerResultIndexMap, AffineMap fusedConsumerArgIndexMap)
Append to fusedOpIndexingMapAttrs the indexing maps for the operands of the producer to use in the fu...
static SmallVector< ReassociationIndices > getOperandReassociation(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Return the reassociation indices to use to collapse the operand when the iteration space of a generic...
static Operation * createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp, Value expandedInput, Value output, ExpansionInfo &expansionInfo)
static std::tuple< SmallVector< OpFoldResult >, RankedTensorType > getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the shape and type of the operand/result to use in the expanded op given the type in the origi...
static SmallVector< utils::IteratorType > getCollapsedOpIteratorTypes(ArrayRef< utils::IteratorType > iteratorTypes, const CollapsingInfo &collapsingInfo)
Get the iterator types for the collapsed operation given the original iterator types and collapsed di...
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, Location loc, Region &fusedRegion, const ExpansionInfo &expansionInfo)
Update the body of an expanded linalg operation having index semantics.
static void generateCollapsedIndexingRegion(Location loc, Block *block, const CollapsingInfo &collapsingInfo, ArrayRef< OpFoldResult > loopRange, RewriterBase &rewriter)
Modify the linalg.index operations in the original generic op, to its value in the collapsed operatio...
static Operation * createExpandedGenericOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > &expandedOpOperands, ArrayRef< Value > outputs, ExpansionInfo &expansionInfo, ArrayRef< AffineMap > expandedOpIndexingMaps)
GenericOp cloneToCollapsedOp< GenericOp >(RewriterBase &rewriter, GenericOp origOp, const CollapsingInfo &collapsingInfo)
Collapse a GenericOp
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp, OpOperand *fusableOpOperand)
Conditions for folding a structured linalg operation with a reshape op by expanding the iteration spa...
void collapseOperandsAndResults(LinalgOp op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter, SmallVectorImpl< Value > &inputOperands, SmallVectorImpl< Value > &outputOperands, SmallVectorImpl< Type > &resultTypes)
static Operation * createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > expandedOpOperands, ArrayRef< Value > outputs, ArrayRef< AffineMap > expandedOpIndexingMaps, ExpansionInfo &expansionInfo)
static ReassociationIndices getDomainReassociation(AffineMap indexingMap, ReassociationIndicesRef rangeReassociation)
For a given list of indices in the range of the indexingMap that are folded, return the indices of th...
static SmallVector< ReassociationIndices > getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, ArrayRef< ReassociationIndices > reassociation)
static std::optional< SmallVector< Value > > fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, OpOperand *fusableOpOperand, PatternRewriter &rewriter)
Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op and a generic op as expl...
static void generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand, unsigned nloops, llvm::SmallDenseSet< int > &preservedProducerResults)
Generate the region of the fused tensor operation.
static SmallVector< ReassociationIndices > getReassociationForExpansion(AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Returns the reassociation maps to use in the tensor.expand_shape operation to convert the operands of...
static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Compute the indexing map in the collapsed op that corresponds to the given indexingMap of the origina...
LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter)
LinalgOp cloneToCollapsedOp< LinalgOp >(RewriterBase &rewriter, LinalgOp origOp, const CollapsingInfo &collapsingInfo)
Collapse any LinalgOp that does not require any specialization such as indexing_maps,...
static AffineMap getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the indexing map to use in the expanded op for a given the indexingMap of the original operati...
static Value getCollapsedOpOperand(Location loc, LinalgOp op, OpOperand *opOperand, const CollapsingInfo &collapsingInfo, OpBuilder &builder)
Get the new value to use for a given OpOperand in the collapsed operation.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e., if all element values are the same.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class allows control over how the GreedyPatternRewriteDriver works.
bool useTopDownTraversal
This specifies the order of initial traversal that populates the rewriter's worklist.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool areDimSequencesPreserved(ArrayRef< AffineMap > maps, ArrayRef< ReassociationIndices > dimSequences)
Return true if all sequences of dimensions specified in dimSequences are contiguous in all the ranges...
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence)
Return true if a given sequence of dimensions are contiguous in the range of the specified indexing m...
void populateFoldReshapeOpsByCollapsingPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding tensor.expand_shape operation with its producer generic operation by co...
FailureOr< ElementwiseOpFusionResult > fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand)
llvm::SmallDenseSet< int > getPreservedProducerResults(GenericOp producer, GenericOp consumer, OpOperand *fusedOperand)
Returns a set of indices of the producer's results which would be preserved after the fusion.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
void populateCollapseDimensions(RewritePatternSet &patterns, const GetCollapsableDimensionsFn &controlCollapseDimensions)
Pattern to collapse dimensions in a linalg.generic op.
bool areElementwiseOpsFusable(OpOperand *fusedOperand)
Return true if two linalg.generic operations with producer/consumer relationship through fusedOperand...
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
std::function< SmallVector< ReassociationIndices >(linalg::LinalgOp)> GetCollapsableDimensionsFn
Function type to control generic op dimension collapsing.
void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding (collapsing) tensor_reshape operation with its producer (consumer) gene...
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void populateElementwiseOpsFusionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion)
Patterns for fusing linalg operation on tensors.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify that shapes of the reshaped types using following rule: if a dimension in the collapsed type i...
ArrayRef< int64_t > ReassociationIndicesRef
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, Operation *insertionPoint, DominanceInfo &dominance)
Move definitions of values before an insertion point.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Fuse two linalg.generic operations that have a producer-consumer relationship captured through fusedO...
llvm::DenseMap< Value, Value > replacements
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.