#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
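/// Compute the indexing map a producer operand should use inside the fused
/// op: compose the producer operand's map with the inverse of the producer's
/// result map, then with the consumer's map for the fused operand.
/// Worked example: with producer result map (d0, d1) -> (d1, d0), producer
/// operand map (d0, d1) -> (d0), and an identity consumer map, the operand's
/// map in the fused op becomes (d0, d1) -> (d1).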
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap maps producer loops -> producer operand tensor indices.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // t1 maps producer result tensor indices -> producer operand indices;
  // composing with the consumer's map yields fused-loop -> operand indices.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);
  return t1.compose(fusedConsumerArgIndexMap);
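/// Returns true if the given `opOperandsToIgnore` of the producer and
/// consumer can be dropped: the concatenation of the remaining indexing maps
/// must still be invertible so the fused op's loop bounds stay computable.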
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {
  SmallVector<AffineMap> indexingMaps;
  SmallVector<GenericOp> ops = {producer, consumer};
  for (auto &op : ops) {
    for (auto &opOperand : op->getOpOperands()) {
      if (llvm::is_contained(opOperandsToIgnore, &opOperand))
        continue;
      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
    }
  }
  if (indexingMaps.empty()) {
    // Without indexing maps the operands can only be dropped when neither op
    // has loops.
    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
  }
  // inversePermutation returns an empty map if the concatenated maps are not
  // invertible.
  return inversePermutation(concatAffineMaps(
             indexingMaps, producer.getContext())) != AffineMap();
}
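/// Returns the set of indices of the producer's results that must be
/// preserved after fusion: results whose init is read by the payload, whose
/// operand cannot be dropped, or which have users other than the consumer.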
llvm::SmallDenseSet<int>
mlir::linalg::getPreservedProducerResults(GenericOp producer,
                                          GenericOp consumer,
                                          OpOperand *fusedOperand) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  llvm::SmallVector<OpOperand *> opOperandsToIgnore;

  // The fusedOperand is removed during fusion.
  opOperandsToIgnore.emplace_back(fusedOperand);

  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    opOperandsToIgnore.emplace_back(outputOperand);
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
                                                  opOperandsToIgnore) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());
      // The operand cannot be dropped: remove it from the ignore list.
      (void)opOperandsToIgnore.pop_back_val();
    }
  }
  return preservedProducerResults;
}
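/// Return true if two linalg.generic operations with a producer/consumer
/// relationship through `fusedOperand` can be fused elementwise: both must be
/// generic ops on tensors, the producer fully parallel, the consumer's map
/// for the operand invertible on the producer's loops, and no loop-bound
/// information may be lost for reduction consumers.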
  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // The producer must have pure tensor semantics and the fused operand must
  // be a ranked tensor.
  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // The producer must have all-"parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // The number of results of the consumer's indexing map must match the
  // number of producer loops.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  auto producerResult = cast<OpResult>(fusedOperand->get());
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);

  // Ensure the fusion does not remove size information needed to compute the
  // loop bounds: for reduction consumers, every loop dimension must still be
  // covered by some remaining operand's indexing map after fusion.
  if ((consumer.getNumReductionLoops())) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }
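/// Generate the region of the fused tensor operation: splice the producer's
/// payload into the consumer's at the fused operand, remapping block
/// arguments and linalg.index operations into the fused iteration space.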
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  // Map producer linalg.index ops into the fused loops using
  // `consumerToProducerLoopsMap`.
  if (producer.hasIndexSemantics()) {
    unsigned numFusedOpLoops = fusedOp.getNumLoops();
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return IndexOp::create(rewriter, producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = affine::AffineApplyOp::create(
          rewriter, producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // Consumer input block arguments up to the fused operand.
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Splice in the producer's input block arguments.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Remaining consumer input block arguments.
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Preserved producer output block arguments.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // Consumer output block arguments.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Clone all producer operations except the yield and index ops.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }

  // Forward the producer's yielded value to the consumer block argument of
  // the fused operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity check: a value not present in the mapper must be defined outside.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);

  // Clone the consumer's operations.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // Build the final yield: preserved producer yields, then consumer yields.
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  YieldOp::create(rewriter, fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}
FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // Find the results of the producer that still have uses after fusion.
  llvm::SmallDenseSet<int> preservedProducerResults =
      mlir::linalg::getPreservedProducerResults(producer, consumer,
                                                fusedOperand);

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // Consumer input operands up to the fused operand come first.
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // Then the producer's input operands, with their indexing maps rewritten
  // into the coordinates of the fused op.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // Then the remaining consumer input operands.
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // The preserved producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;
    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // The consumer outputs.
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = GenericOp::create(
      rewriter, consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps; cancel the rewrite.
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops:
  // consumer loop -> tensor index -> producer loop.
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}
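/// Pattern to fuse a linalg.generic with one of its elementwise producers,
/// gated by a user-provided ControlFusionFn.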
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand) || !controlFn(&opOperand))
        continue;
      Operation *producer = opOperand.get().getDefiningOp();
      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");
      // Only replace uses outside the producer itself.
      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          return use.get().getDefiningOp() != producer;
        });
      }
      rewriter.eraseOp(genericOp);
      return success();
    }
    return failure();
  }
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  // Fusable only if all indexing maps are projected permutations and the
  // fused operand is not a scalar.
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         operandMap.getNumResults() > 0;
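/// Information needed to expand a generic operation to fold a reshape with
/// it: the reassociation from original to expanded loop dimensions and the
/// expanded shape of each original dimension.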
class ExpansionInfo {
public:
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<OpFoldResult> expandedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }

private:
  SmallVector<ReassociationIndices> reassociation;
  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
  SmallVector<OpFoldResult> originalLoopExtent;
  unsigned expandedOpNumDims;
};
LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<OpFoldResult> expandedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);
  originalLoopExtent = llvm::map_to_vector(
      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
      [](Range r) { return r.size; });

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<OpFoldResult> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions keep their original extent.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute the reassociation from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    auto expandedExprs = llvm::to_vector(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}

/// Return the shape and type of the operand/result to use in the expanded op,
/// given its type in the original op.
static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
                        const ExpansionInfo &expansionInfo) {
  SmallVector<OpFoldResult> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    ArrayRef<OpFoldResult> dimExpansion =
        expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  SmallVector<int64_t> expandedStaticShape;
  std::tie(expandedStaticShape, std::ignore) =
      decomposeMixedValues(expandedShape);
  return {expandedShape, RankedTensorType::get(expandedStaticShape,
                                               originalType.getElementType())};
}
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}
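/// Update the body of an expanded linalg operation that has index semantics:
/// each original linalg.index value is rebuilt from the expanded indices by
/// linearization, idx = idx_outer * size_inner + idx_inner along the expanded
/// dimensions.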
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<OpFoldResult> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return IndexOp::create(rewriter, loc, dim); });
    OpFoldResult newIndex =
        IndexOp::create(rewriter, loc, expandedDims.front()).getResult();
    for (auto [expandedShape, expandedIndex] :
         llvm::zip(expandedDimsShape, expandedIndices)) {
      AffineExpr idx, acc, shape;
      bindDims(rewriter.getContext(), idx, acc);
      bindSymbols(rewriter.getContext(), shape);
      newIndex = affine::makeComposedFoldedAffineApply(
          rewriter, indexOp.getLoc(), idx + acc * shape,
          ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
    }
    Value newIndexVal =
        getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
    rewriter.replaceOp(indexOp, newIndexVal);
  }
}
static Operation *createExpandedTransposeOp(PatternRewriter &rewriter,
                                            TransposeOp transposeOp,
                                            Value expandedInput, Value output,
                                            ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> newPerm;
  for (int64_t perm : transposeOp.getPermutation()) {
    auto reassoc = expansionInfo.getExpandedDims(perm);
    for (int64_t dim : reassoc) {
      newPerm.push_back(dim);
    }
  }
  return TransposeOp::create(rewriter, transposeOp.getLoc(), expandedInput,
                             output, newPerm);
}
static Operation *createExpandedGenericOp(
    PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
    ArrayRef<Value> &expandedOpOperands, ArrayRef<Value> outputs,
    ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
  // The iterator types of the expanded op are all parallel by default.
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);

  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
    for (auto j : expansionInfo.getExpandedDims(i))
      iteratorTypes[j] = type;

  Operation *fused = GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes,
                                       expandedOpOperands, outputs,
                                       expandedOpIndexingMaps, iteratorTypes);
  Region &fusedRegion = fused->getRegion(0);
  Region &originalRegion = linalgOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
                                expansionInfo);
  return fused;
}
static Operation *createExpandedOp(
    PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
    ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
    ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo) {
  return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
      .Case<TransposeOp>([&](TransposeOp transposeOp) {
        return createExpandedTransposeOp(rewriter, transposeOp,
                                         expandedOpOperands[0], outputs[0],
                                         expansionInfo);
      })
      .Case<FillOp, CopyOp>([&](Operation *op) {
        return clone(rewriter, linalgOp, resultTypes,
                     llvm::to_vector(llvm::concat<Value>(
                         llvm::to_vector(expandedOpOperands),
                         llvm::to_vector(outputs))));
      })
      .Default([&](Operation *op) {
        return createExpandedGenericOp(rewriter, linalgOp, resultTypes,
                                       expandedOpOperands, outputs,
                                       expansionInfo, expandedOpIndexingMaps);
      });
}
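/// Implements the fusion of a tensor.collapse_shape or tensor.expand_shape op
/// and a generic op by expanding the generic op's iteration space.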
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");

  Location loc = linalgOp.getLoc();
  SmallVector<OpFoldResult> expandedShape;
  SmallVector<AffineMap> reassociationIndices;
  Value src;
  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
    // Move the dynamic output-shape values before `linalgOp` to keep SSA
    // dominance valid.
    if (failed(moveValueDefinitions(
            rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
      return std::nullopt;

    expandedShape = expandingReshapeOp.getMixedOutputShape();
    reassociationIndices = expandingReshapeOp.getReassociationMaps();
    src = expandingReshapeOp.getSrc();
  } else {
    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
    if (!collapsingReshapeOp)
      return std::nullopt;

    expandedShape = tensor::getMixedSizes(
        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
    src = collapsingReshapeOp.getSrc();
  }

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
                                   reassociationIndices, expandedShape,
                                   rewriter)))
    return std::nullopt;

  SmallVector<AffineMap> expandedOpIndexingMaps = llvm::to_vector(
      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Expand the input operands, inserting expand_shape ops where needed.
  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(src);
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
      SmallVector<OpFoldResult> expandedOperandShape;
      RankedTensorType expandedOperandType;
      std::tie(expandedOperandShape, expandedOperandType) =
          getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(linalgOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(tensor::ExpandShapeOp::create(
            rewriter, loc, expandedOperandType, opOperand->get(), reassociation,
            expandedOperandShape));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  // Expand the output operands similarly.
  SmallVector<Value> outputs;
  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
    SmallVector<OpFoldResult> expandedOutputShape;
    RankedTensorType expandedOutputType;
    std::tie(expandedOutputShape, expandedOutputType) =
        getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(linalgOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(tensor::ExpandShapeOp::create(
          rewriter, loc, expandedOutputType, opOperand.get(), reassociation,
          expandedOutputShape));
    } else {
      outputs.push_back(opOperand.get());
    }
  }

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  Operation *fusedOp =
      createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
                       outputs, expandedOpIndexingMaps, expansionInfo);

  // Reshape the result values back to their original shape where needed.
  SmallVector<Value> resultVals;
  for (OpResult opResult : linalgOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              linalgOp.getMatchingIndexingMap(
                  linalgOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(tensor::CollapseShapeOp::create(
          rewriter, linalgOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  return resultVals;
}
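/// Pattern to fold a linalg op with its producer tensor.collapse_shape by
/// expanding the linalg op's iteration space.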
class FoldWithProducerReshapeOpByExpansion
    : public OpInterfaceRewritePattern<LinalgOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(linalgOp, *replacementValues);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::CollapseShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    // Only unit-dim reassociations may carry non-zero padding.
    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
      if (reInd.size() != 1 && (l != 0 || h != 0))
        return failure();
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType expandedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      if (reInd.size() == 1) {
        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
      }
      for (size_t i = 0; i < reInd.size(); ++i) {
        newLow.push_back(padOp.getMixedLowPad()[idx]);
        newHigh.push_back(padOp.getMixedHighPad()[idx]);
      }
    }

    Location loc = padOp->getLoc();
    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
    auto newPadOp = tensor::PadOp::create(
        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
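/// Pattern to fold a tensor.expand_shape op with its producer generic op by
/// expanding the iteration space of the producer.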
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {
  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are
    // met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp, "producer not a generic op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    // The replacements have the same type as the results of the original
    // generic op, so the consumer reshape can be replaced by the source of
    // the collapse_shape that defines its replacement.
    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
      rewriter.eraseOp(collapseOp);
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
1186 "expected projected permutation");
1189 llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
1190 return cast<AffineDimExpr>(indexingMap.
getResults()[pos]).getPosition();
1194 return domainReassociation;
1202 assert(!dimSequence.empty() &&
1203 "expected non-empty list for dimension sequence");
1205 "expected indexing map to be projected permutation");
1207 llvm::SmallDenseSet<unsigned, 4> sequenceElements;
1208 sequenceElements.insert_range(dimSequence);
1210 unsigned dimSequenceStart = dimSequence[0];
1212 unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
1214 if (dimInMapStart == dimSequenceStart) {
1215 if (expr.index() + dimSequence.size() > indexingMap.
getNumResults())
1218 for (
const auto &dimInSequence :
enumerate(dimSequence)) {
1220 cast<AffineDimExpr>(
1221 indexingMap.
getResult(expr.index() + dimInSequence.index()))
1223 if (dimInMap != dimInSequence.value())
1234 if (sequenceElements.count(dimInMapStart))
1243 return llvm::all_of(maps, [&](
AffineMap map) {
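/// For a given operand of a generic op and a reassociation of that operand's
/// dimensions, return the sequences of iteration-space dimensions that can be
/// collapsed together.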
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
  // Some basic checks for this fusion to be valid.
  if (!genericOp.hasPureTensorSemantics())
    return {};

  if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
        return map.isProjectedPermutation();
      })) {
    return {};
  }

  // Compute all the loops with "reduction" iterator types.
  SmallVector<unsigned> reductionDims;
  genericOp.getReductionDims(reductionDims);

  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
  auto iteratorTypes = genericOp.getIteratorTypesArray();
  SmallVector<ReassociationIndices> iterationSpaceReassociation;
  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");

    // Ignore dims that are not folded.
    if (foldedRangeDims.size() == 1)
      continue;

    ReassociationIndices foldedIterationSpaceDims =
        getDomainReassociation(indexingMap, foldedRangeDims);

    // The folded iteration dims must not contain already processed dims.
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return processedIterationDims.count(dim);
        }))
      continue;

    // All folded iterator types must be all parallel or all reduction.
    utils::IteratorType startIteratorType =
        iteratorTypes[foldedIterationSpaceDims[0]];
    if (!isParallelIterator(startIteratorType) &&
        !isReductionIterator(startIteratorType))
      continue;
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return iteratorTypes[dim] != startIteratorType;
        }))
      continue;

    // Folded "reduction" dimensions must appear contiguously and in-order in
    // the list of reduction dims.
    if (isReductionIterator(startIteratorType)) {
      bool isContiguous = false;
      for (const auto &startDim : llvm::enumerate(reductionDims)) {
        if (startDim.value() != foldedIterationSpaceDims[0])
          continue;
        if (startDim.index() + foldedIterationSpaceDims.size() >
            reductionDims.size())
          break;
        isContiguous = true;
        for (const auto &foldedDim :
             llvm::enumerate(foldedIterationSpaceDims)) {
          if (reductionDims[foldedDim.index() + startDim.index()] !=
              foldedDim.value()) {
            isContiguous = false;
            break;
          }
        }
        break;
      }
      if (!isContiguous)
        continue;
    }

    // The sequence must be preserved in all indexing maps.
    if (llvm::any_of(genericOp.getIndexingMapsArray(),
                     [&](AffineMap indexingMap) {
                       return !isDimSequencePreserved(indexingMap,
                                                      foldedIterationSpaceDims);
                     }))
      continue;

    processedIterationDims.insert_range(foldedIterationSpaceDims);
    iterationSpaceReassociation.emplace_back(
        std::move(foldedIterationSpaceDims));
  }

  return iterationSpaceReassociation;
}
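/// CollapsingInfo stores the bidirectional mapping between the iteration
/// dimensions of the original op and those of the collapsed op.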
class CollapsingInfo {
public:
  LogicalResult initialize(unsigned origNumLoops,
                           ArrayRef<ReassociationIndices> foldedIterationDims) {
    llvm::SmallDenseSet<int64_t, 4> processedDims;
    // Find all the dims that are folded.
    for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
      if (foldedIterationDim.empty())
        continue;
      // Out-of-range and repeated dims are illegal.
      for (auto dim : foldedIterationDim) {
        if (dim >= origNumLoops)
          return failure();
        if (processedDims.count(dim))
          return failure();
        processedDims.insert(dim);
      }
      collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
                                                   foldedIterationDim.end());
    }
    if (processedDims.size() > origNumLoops)
      return failure();

    // Add all the preserved dims of the original op as single elements.
    for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
      if (processedDims.count(dim))
        continue;
      collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
    }

    llvm::sort(collapsedOpToOrigOpIterationDim,
               [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
                 return lhs[0] < rhs[0];
               });
    origOpToCollapsedOpIterationDim.resize(origNumLoops);
    for (const auto &foldedDims :
         llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
      for (const auto &dim : enumerate(foldedDims.value()))
        origOpToCollapsedOpIterationDim[dim.value()] =
            std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
    }
    return success();
  }

  // Return the mapping from collapsed loop domain to original loop domain.
  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
    return collapsedOpToOrigOpIterationDim;
  }

  // Return the mapping from original loop domain to collapsed loop domain.
  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
    return origOpToCollapsedOpIterationDim;
  }

  // Return the collapsed op iteration rank.
  unsigned getCollapsedOpIterationRank() const {
    return collapsedOpToOrigOpIterationDim.size();
  }

private:
  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
};
static SmallVector<utils::IteratorType>
getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
                            const CollapsingInfo &collapsingInfo) {
  SmallVector<utils::IteratorType> collapsedIteratorTypes;
  for (ReassociationIndicesRef foldedIterDims :
       collapsingInfo.getCollapsedOpToOrigOpMapping()) {
    assert(!foldedIterDims.empty() &&
           "reassociation indices expected to have non-empty sets");
    // Pick the iterator type of the first folded dim: all folded dims were
    // already checked to share it.
    collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
  }
  return collapsedIteratorTypes;
}
static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap,
                                           const CollapsingInfo &collapsingInfo) {
  MLIRContext *context = indexingMap.getContext();
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");
  SmallVector<AffineExpr> resultExprs;
  auto origOpToCollapsedOpMapping =
      collapsingInfo.getOrigOpToCollapsedOpMapping();
  for (auto expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    // If the dim is not the first of the collapsed dims, do nothing.
    if (origOpToCollapsedOpMapping[dim].second != 0)
      continue;
    // Otherwise use the iteration dimension of the collapsed op.
    resultExprs.push_back(
        getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
  }
  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
                        resultExprs, context);
}
static SmallVector<ReassociationIndices>
getOperandReassociation(AffineMap indexingMap,
                        const CollapsingInfo &collapsingInfo) {
  unsigned counter = 0;
  SmallVector<ReassociationIndices> operandReassociation;
  auto origOpToCollapsedOpMapping =
      collapsingInfo.getOrigOpToCollapsedOpMapping();
  auto collapsedOpToOrigOpMapping =
      collapsingInfo.getCollapsedOpToOrigOpMapping();
  while (counter < indexingMap.getNumResults()) {
    unsigned dim =
        cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
    // The next `numFoldedDims` dims are guaranteed to be collapsed together,
    // per the collapsed-op to original-op mapping.
    unsigned numFoldedDims =
        collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
            .size();
    if (origOpToCollapsedOpMapping[dim].second == 0) {
      auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
      operandReassociation.emplace_back(range.begin(), range.end());
    }
    counter += numFoldedDims;
  }
  return operandReassociation;
}
static Value getCollapsedOpOperand(Location loc, LinalgOp op,
                                   OpOperand *opOperand,
                                   const CollapsingInfo &collapsingInfo,
                                   OpBuilder &builder) {
  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
  SmallVector<ReassociationIndices> operandReassociation =
      getOperandReassociation(indexingMap, collapsingInfo);

  // If the number of reassociation entries matches the number of results of
  // the indexing map, nothing to do for this operand.
  Value operand = opOperand->get();
  if (operandReassociation.size() == indexingMap.getNumResults())
    return operand;

  // Insert a reshape to collapse the dimensions.
  if (isa<MemRefType>(operand.getType())) {
    return memref::CollapseShapeOp::create(builder, loc, operand,
                                           operandReassociation)
        .getResult();
  }
  return tensor::CollapseShapeOp::create(builder, loc, operand,
                                         operandReassociation)
      .getResult();
}
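/// Modify the linalg.index operations of the original op to their values in
/// the collapsed operation: each original index is recovered from a collapsed
/// index by successive remainder (arith.remsi) and division (arith.divsi) by
/// the sizes of the folded loops.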
static void generateCollapsedIndexingRegion(
    Location loc, Block *block, const CollapsingInfo &collapsingInfo,
    ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointToStart(block);

  // Collect all the original index ops.
  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());

  // For each folded dimension list, resolve the original induction variables
  // from the folded one, e.g. for i_folded = (i0 * d1 + i1) * d2 + i2:
  //   i2 = i_folded % d2, i1 = (i_folded / d2) % d1, i0 = i_folded / (d1 * d2).
  llvm::DenseMap<unsigned, Value> indexReplacementVals;
  for (auto foldedDims :
       enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
    ReassociationIndicesRef foldedDimsRef(foldedDims.value());
    Value newIndexVal =
        linalg::IndexOp::create(rewriter, loc, foldedDims.index());
    for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
      Value loopDim =
          getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
      indexReplacementVals[dim] =
          rewriter.createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
      newIndexVal =
          rewriter.createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
    }
    indexReplacementVals[foldedDims.value().front()] = newIndexVal;
  }

  for (auto indexOp : indexOps) {
    auto dim = indexOp.getDim();
    rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
  }
}
void collapseOperandsAndResults(LinalgOp op,
                                const CollapsingInfo &collapsingInfo,
                                RewriterBase &rewriter,
                                SmallVectorImpl<Value> &inputOperands,
                                SmallVectorImpl<Value> &outputOperands,
                                SmallVectorImpl<Type> &resultTypes) {
  Location loc = op->getLoc();
  inputOperands =
      llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
        return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
                                     rewriter);
      });

  // Get the output operands and result types.
  resultTypes.reserve(op.getNumDpsInits());
  outputOperands.reserve(op.getNumDpsInits());
  for (OpOperand &output : op.getDpsInitsMutable()) {
    Value newOutput =
        getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
    outputOperands.push_back(newOutput);
    // With buffer semantics the init operands are memrefs and the op has no
    // results.
    if (!op.hasPureBufferSemantics())
      resultTypes.push_back(newOutput.getType());
  }
}
/// Clone a `LinalgOp` to a collapsed version of the same name.
template <typename OpTy>
OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
                        const CollapsingInfo &collapsingInfo) {
  return nullptr;
}

/// Collapse any `LinalgOp` that does not require any specialization of its
/// indexing_maps, iterator_types, etc.
template <>
LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
                                      const CollapsingInfo &collapsingInfo) {
  SmallVector<Value> inputOperands, outputOperands;
  SmallVector<Type> resultTypes;
  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
                             outputOperands, resultTypes);
  return cast<LinalgOp>(clone(
      rewriter, origOp, resultTypes,
      llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands))));
}

/// Collapse a `GenericOp`.
template <>
GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
                                        GenericOp origOp,
                                        const CollapsingInfo &collapsingInfo) {
  SmallVector<Value> inputOperands, outputOperands;
  SmallVector<Type> resultTypes;
  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
                             outputOperands, resultTypes);
  SmallVector<AffineMap> indexingMaps(
      llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
        return getCollapsedOpIndexingMap(map, collapsingInfo);
      }));

  SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
      origOp.getIteratorTypesArray(), collapsingInfo));

  GenericOp collapsedOp = linalg::GenericOp::create(
      rewriter, origOp.getLoc(), resultTypes, inputOperands, outputOperands,
      indexingMaps, iteratorTypes,
      [](OpBuilder &builder, Location loc, ValueRange args) {});
  Block *origOpBlock = &origOp->getRegion(0).front();
  Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
                       collapsedOpBlock->getArguments());
  return collapsedOp;
}
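/// Collapses dimensions of a linalg.generic/linalg.copy operation according
/// to `foldedIterationDims`, expanding the results back afterwards if needed.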
LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
                           RewriterBase &rewriter) {
  if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
    return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
  }
  return cloneToCollapsedOp(rewriter, op, collapsingInfo);
}

FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
    LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
    RewriterBase &rewriter) {
  // Bail on trivial no-op cases.
  if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
      llvm::all_of(foldedIterationDims,
                   [](ReassociationIndicesRef foldedDims) {
                     return foldedDims.size() <= 1;
                   }))
    return failure();

  CollapsingInfo collapsingInfo;
  if (failed(
          collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
    return rewriter.notifyMatchFailure(
        op, "illegal to collapse specified dimensions");
  }

  bool hasPureBufferSemantics = op.hasPureBufferSemantics();
  if (hasPureBufferSemantics &&
      !llvm::all_of(op->getOpOperands(), [&](OpOperand &opOperand) -> bool {
        MemRefType memRefToCollapse =
            dyn_cast<MemRefType>(opOperand.get().getType());
        if (!memRefToCollapse)
          return true;

        AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
        SmallVector<ReassociationIndices> operandReassociation =
            getOperandReassociation(indexingMap, collapsingInfo);
        return memref::CollapseShapeOp::isGuaranteedCollapsible(
            memRefToCollapse, operandReassociation);
      }))
    return rewriter.notifyMatchFailure(op,
                                       "memref is not guaranteed collapsible");

  // Bail on non-canonical ranges.
  SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
  auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
      return cast<IntegerAttr>(attr).getInt() == value;
    llvm::APInt actual;
    return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) &&
           actual.getSExtValue() == value;
  };
  if (!llvm::all_of(loopRanges, [&](Range range) {
        return opFoldIsConstantValue(range.offset, 0) &&
               opFoldIsConstantValue(range.stride, 1);
      })) {
    return rewriter.notifyMatchFailure(
        op, "expected all loop ranges to have zero start and unit stride");
  }

  LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);

  Location loc = op->getLoc();
  SmallVector<OpFoldResult> loopBound =
      llvm::map_to_vector(loopRanges, [](Range range) { return range.size; });

  if (collapsedOp.hasIndexSemantics()) {
    // Rewrite the index operations in terms of the collapsed loops.
    generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
                                    collapsingInfo, loopBound, rewriter);
  }

  // Insert expanding reshapes for the results to recover the original result
  // types.
  SmallVector<Value> results;
  for (const auto &originalResult : llvm::enumerate(op->getResults())) {
    Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
    auto originalResultType =
        cast<ShapedType>(originalResult.value().getType());
    auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
    if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
      AffineMap indexingMap =
          op.getIndexingMapMatchingResult(originalResult.value());
      SmallVector<ReassociationIndices> reassociation =
          getOperandReassociation(indexingMap, collapsingInfo);
      assert(indexingMap.isProjectedPermutation() &&
             "Expected indexing map to be a projected permutation for "
             "collapsing");
      SmallVector<OpFoldResult> resultShape =
          applyPermutationMap(indexingMap, ArrayRef(loopBound));
      Value result;
      if (isa<MemRefType>(collapsedOpResult.getType())) {
        MemRefType expandShapeResultType = MemRefType::get(
            originalResultType.getShape(),
            originalResultType.getElementType());
        result = memref::ExpandShapeOp::create(
            rewriter, loc, expandShapeResultType, collapsedOpResult,
            reassociation, resultShape);
      } else {
        result = tensor::ExpandShapeOp::create(
            rewriter, loc, originalResultType, collapsedOpResult,
            reassociation, resultShape);
      }
      results.push_back(result);
    } else {
      results.push_back(collapsedOpResult);
    }
  }
  return CollapseResult{results, collapsedOp};
}
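/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
/// collapsing the dimensions of the generic op.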
class FoldWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<GenericOp> {
public:
  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
                                        ControlFusionFn foldReshapes,
                                        PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      tensor::ExpandShapeOp reshapeOp =
          opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
      if (!reshapeOp)
        continue;

      SmallVector<ReassociationIndices> collapsableIterationDims =
          getCollapsableIterationSpaceDims(genericOp, &opOperand,
                                           reshapeOp.getReassociationIndices());
      if (collapsableIterationDims.empty() ||
          !controlFoldingReshapes(&opOperand)) {
        continue;
      }

      std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
          genericOp, collapsableIterationDims, rewriter);
      if (!collapseResult) {
        return rewriter.notifyMatchFailure(
            genericOp, "failed to do the fusion by collapsing transformation");
      }

      rewriter.replaceOp(genericOp, collapseResult->results);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
struct FoldReshapeWithGenericOpByCollapsing
    : public OpRewritePattern<tensor::CollapseShapeOp> {
  FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by collapsing are
    // met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp, "producer not a generic op");
    }

    SmallVector<ReassociationIndices> collapsableIterationDims =
        getCollapsableIterationSpaceDims(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            reshapeOp.getReassociationIndices());
    if (collapsableIterationDims.empty()) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<CollapseResult> collapseResult =
        collapseOpIterationDims(producer, collapsableIterationDims, rewriter);
    if (!collapseResult) {
      return rewriter.notifyMatchFailure(
          producer, "failed to do the fusion by collapsing transformation");
    }

    rewriter.replaceOp(producer, collapseResult->results);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
class FoldPadWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
                                           ControlFusionFn foldReshapes,
                                           PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::ExpandShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    // Expanded groups may not carry non-zero padding.
    for (auto reInd : reassociations) {
      if (reInd.size() == 1)
        continue;
      if (llvm::any_of(reInd, [&](int64_t ind) {
            return low[ind] != 0 || high[ind] != 0;
          }))
        return failure();
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType collapsedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
    SmallVector<OpFoldResult> expandedPaddedSizes(
        getMixedValues(reshapeOp.getStaticOutputShape(),
                       reshapeOp.getOutputShape(), rewriter));
    AffineExpr d0, d1, d2;
    bindDims(rewriter.getContext(), d0, d1, d2);
    auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
    Location loc = reshapeOp->getLoc();
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      OpFoldResult l = padOp.getMixedLowPad()[idx];
      OpFoldResult h = padOp.getMixedHighPad()[idx];
      if (reInd.size() == 1) {
        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
            rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
        expandedPaddedSizes[reInd[0]] = paddedSize;
      }
      newLow.push_back(l);
      newHigh.push_back(h);
    }

    RankedTensorType collapsedPaddedType =
        paddedType.clone(collapsedPaddedShape);
    auto newPadOp = tensor::PadOp::create(
        rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
        expandedPaddedSizes);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
template <typename LinalgType>
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
public:
  CollapseLinalgDimensions(MLIRContext *context,
                           GetCollapsableDimensionsFn collapseDimensions,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<LinalgType>(context, benefit),
        controlCollapseDimension(std::move(collapseDimensions)) {}

  LogicalResult matchAndRewrite(LinalgType op,
                                PatternRewriter &rewriter) const override {
    SmallVector<ReassociationIndices> collapsableIterationDims =
        controlCollapseDimension(op);
    if (collapsableIterationDims.empty())
      return failure();

    // Check that the specified list of dimensions to collapse is valid.
    if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
                                  collapsableIterationDims)) {
      return rewriter.notifyMatchFailure(
          op, "specified dimensions cannot be collapsed");
    }

    std::optional<CollapseResult> collapseResult =
        collapseOpIterationDims(op, collapsableIterationDims, rewriter);
    if (!collapseResult) {
      return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
    }
    rewriter.replaceOp(op, collapseResult->results);
    return success();
  }

private:
  GetCollapsableDimensionsFn controlCollapseDimension;
};
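/// Pattern to fold a generic op with a splat constant or scalar constant
/// input: the constant is inlined into the payload and the operand dropped.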
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasPureTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      TypedAttr constantAttr;
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          DenseElementsAttr splatAttr;
          if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue<TypedAttr>();
            return true;
          }
        }
        {
          IntegerAttr intAttr;
          if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          FloatAttr floatAttr;
          if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = dyn_cast<OpResult>(opOperand->get());
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
        continue;

      // The fused op keeps the operands and indexing maps of the generic op
      // with the constant operand dropped.
      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(genericOp->getNumOperands());
      fusedOperands.reserve(genericOp.getNumDpsInputs());
      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
      for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(
            genericOp.getMatchingIndexingMap(inputOperand));
        fusedOperands.push_back(inputValue);
        fusedLocs.push_back(inputValue.getLoc());
      }
      for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
        fusedIndexMaps.push_back(
            genericOp.getMatchingIndexingMap(&outputOperand));

      // Check that the shapes-to-loops map remains computable.
      if (!inversePermutation(
              concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Create a scalar constant from the splat constant.
      Value scalarConstant =
          arith::ConstantOp::create(rewriter, def->getLoc(), constantAttr);

      auto fusedOp =
          GenericOp::create(rewriter, rewriter.getFusedLoc(fusedLocs),
                            genericOp->getResultTypes(),
                            /*inputs=*/fusedOperands,
                            /*outputs=*/genericOp.getOutputs(),
                            rewriter.getAffineMapArrayAttr(fusedIndexMaps),
                            genericOp.getIteratorTypes(),
                            /*doc=*/nullptr,
                            /*library_call=*/nullptr);

      // Map the dropped block argument to the scalar constant.
      Region &region = genericOp->getRegion(0);
      Block &entryBlock = *region.begin();
      IRMapping mapping;
      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
                  scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(0);
      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                 mapping);
      rewriter.replaceOp(genericOp, fusedOp->getResults());
      return success();
    }
    return failure();
  }
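/// Pattern to replace unused init operands of a generic op on tensors with
/// tensor.empty, removing a false dependency on the init's producer.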
  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.startOpModification(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    for (OpOperand &opOperand : op.getDpsInitsMutable()) {
      if (!op.payloadUsesValueFromOperand(&opOperand)) {
        Value operandVal = opOperand.get();
        auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
        if (!operandType)
          continue;

        // If the out is sparse, leave it to the sparsifier.
        if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
          continue;

        // If the out is already an empty op, nothing to do.
        auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
        if (definingOp)
          continue;
        modifiedOutput = true;
        SmallVector<OpFoldResult> mixedSizes =
            tensor::getMixedSizes(rewriter, loc, operandVal);
        Value emptyTensor = tensor::EmptyOp::create(
            rewriter, loc, mixedSizes, operandType.getElementType());
        op->setOperand(opOperand.getOperandNumber(), emptyTensor);
      }
    }
    if (!modifiedOutput) {
      rewriter.cancelOpModification(op);
      return failure();
    }
    rewriter.finalizeOpModification(op);
    return success();
  }
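/// Pattern to fold a linalg.fill producer into the payload of a consuming
/// generic op by substituting the fill value for the block argument.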
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasPureTensorSemantics())
      return failure();
    bool fillFound = false;
    Block &payload = genericOp.getRegion().front();
    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
      if (!genericOp.payloadUsesValueFromOperand(opOperand))
        continue;
      FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
      if (!fillOp)
        continue;
      fillFound = true;
      Value fillVal = fillOp.value();
      auto resultType =
          cast<RankedTensorType>(fillOp.result().getType()).getElementType();
      Value convertedVal =
          convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
                               /*isUnsignedCast=*/false);
      rewriter.replaceAllUsesWith(
          payload.getArgument(opOperand->getOperandNumber()), convertedVal);
    }
    return success(fillFound);
  }
void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                    controlFoldingReshapes);
  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                        controlFoldingReshapes);
  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
}

void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
      patterns.getContext(), controlFoldingReshapes);
  patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
                                                     controlFoldingReshapes);
}
void mlir::linalg::populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlElementwiseOpsFusion) {
  auto *context = patterns.getContext();
  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
               RemoveOutsDependency>(context);
  // Add the patterns that clean up dead operands and results.
  populateEraseUnusedOperandsAndResultsPatterns(patterns);
}

void mlir::linalg::populateCollapseDimensions(
    RewritePatternSet &patterns,
    const GetCollapsableDimensionsFn &controlCollapseDimensions) {
  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
               CollapseLinalgDimensions<linalg::CopyOp>>(
      patterns.getContext(), controlCollapseDimensions);
}
struct LinalgElementwiseOpFusionPass
    : public impl::LinalgElementwiseOpFusionPassBase<
          LinalgElementwiseOpFusionPass> {
  using impl::LinalgElementwiseOpFusionPassBase<
      LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    RewritePatternSet patterns(context);

    // By default, fuse only when the producer has a single use.
    ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
      Operation *producer = fusedOperand->get().getDefiningOp();
      return producer && producer->hasOneUse();
    };

    // Add elementwise op fusion and reshape folding patterns.
    populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
    populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);

    // General canonicalization patterns.
    affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
    GenericOp::getCanonicalizationPatterns(patterns, context);
    tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
    tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);

    // Use top-down traversal for compile-time reasons.
    GreedyRewriteConfig config;
    config.setUseTopDownTraversal(true);
    (void)applyPatternsGreedily(op, std::move(patterns), config);
  }
};
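// A minimal usage sketch (assuming a func::FuncOp `funcOp` and an MLIRContext
// `ctx`; the single-use control function mirrors the pass's default and is
// illustrative):
//
//   RewritePatternSet patterns(ctx);
//   linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
//     Operation *producer = fusedOperand->get().getDefiningOp();
//     return producer && producer->hasOneUse();
//   };
//   linalg::populateElementwiseOpsFusionPatterns(patterns, controlFn);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     funcOp.emitError("fusion patterns failed to converge");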