33 #define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
34 #include "mlir/Dialect/Linalg/Passes.h.inc"
60 assert(invProducerResultIndexMap &&
61 "expected producer result indexing map to be invertible");
63 LinalgOp producer = cast<LinalgOp>(producerOpOperand->
getOwner());
65 AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);
73 return t1.
compose(fusedConsumerArgIndexMap);
80 GenericOp producer, GenericOp consumer,
85 for (
auto &op : ops) {
86 for (
auto &opOperand : op->getOpOperands()) {
87 if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
90 indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
93 if (indexingMaps.empty()) {
96 return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
104 indexingMaps, producer.getContext())) !=
AffineMap();
113 GenericOp producer, GenericOp consumer,
OpOperand *fusedOperand) {
114 llvm::SmallDenseSet<int> preservedProducerResults;
118 opOperandsToIgnore.emplace_back(fusedOperand);
120 for (
const auto &producerResult :
llvm::enumerate(producer->getResults())) {
121 auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
122 opOperandsToIgnore.emplace_back(outputOperand);
123 if (producer.payloadUsesValueFromOperand(outputOperand) ||
125 opOperandsToIgnore) ||
126 llvm::any_of(producerResult.value().getUsers(), [&](
Operation *user) {
127 return user != consumer.getOperation();
129 preservedProducerResults.insert(producerResult.index());
132 (void)opOperandsToIgnore.pop_back_val();
135 return preservedProducerResults;
144 auto consumer = dyn_cast<GenericOp>(fusedOperand->
getOwner());
147 if (!producer || !consumer)
153 if (!producer.hasPureTensorSemantics() ||
154 !isa<RankedTensorType>(fusedOperand->
get().
getType()))
159 if (producer.getNumParallelLoops() != producer.getNumLoops())
164 if (!consumer.isDpsInput(fusedOperand))
169 AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
170 if (consumerIndexMap.
getNumResults() != producer.getNumLoops())
175 auto producerResult = cast<OpResult>(fusedOperand->
get());
177 producer.getIndexingMapMatchingResult(producerResult);
185 if ((consumer.getNumReductionLoops())) {
186 BitVector coveredDims(consumer.getNumLoops(),
false);
188 auto addToCoveredDims = [&](
AffineMap map) {
189 for (
auto result : map.getResults())
190 if (
auto dimExpr = dyn_cast<AffineDimExpr>(result))
191 coveredDims[dimExpr.getPosition()] =
true;
195 llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
196 Value operand = std::get<0>(pair);
197 if (operand == fusedOperand->
get())
199 AffineMap operandMap = std::get<1>(pair);
200 addToCoveredDims(operandMap);
203 for (
OpOperand *operand : producer.getDpsInputOperands()) {
206 operand, producerResultIndexMap, consumerIndexMap);
207 addToCoveredDims(newIndexingMap);
209 if (!coveredDims.all())
221 unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
223 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
225 Block &producerBlock = producer->getRegion(0).
front();
226 Block &consumerBlock = consumer->getRegion(0).
front();
233 if (producer.hasIndexSemantics()) {
235 unsigned numFusedOpLoops = fusedOp.getNumLoops();
237 fusedIndices.reserve(numFusedOpLoops);
238 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
239 std::back_inserter(fusedIndices), [&](uint64_t dim) {
240 return IndexOp::create(rewriter, producer.getLoc(), dim);
242 for (IndexOp indexOp :
243 llvm::make_early_inc_range(producerBlock.
getOps<IndexOp>())) {
244 Value newIndex = affine::AffineApplyOp::create(
245 rewriter, producer.getLoc(),
246 consumerToProducerLoopsMap.
getSubMap(indexOp.getDim()), fusedIndices);
247 mapper.
map(indexOp.getResult(), newIndex);
251 assert(consumer.isDpsInput(fusedOperand) &&
252 "expected producer of input operand");
256 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
263 producerBlock.
getArguments().take_front(producer.getNumDpsInputs()))
264 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
269 .take_front(consumer.getNumDpsInputs())
271 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
275 producerBlock.
getArguments().take_back(producer.getNumDpsInits()))) {
276 if (!preservedProducerResults.count(bbArg.index()))
278 mapper.
map(bbArg.value(), fusedBlock->
addArgument(bbArg.value().getType(),
279 bbArg.value().getLoc()));
284 consumerBlock.
getArguments().take_back(consumer.getNumDpsInits()))
285 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
290 if (!isa<IndexOp>(op))
291 rewriter.
clone(op, mapper);
295 auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.
getTerminator());
296 unsigned producerResultNumber =
297 cast<OpResult>(fusedOperand->
get()).getResultNumber();
299 mapper.
lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
303 if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
304 if (
auto bb = dyn_cast<BlockArgument>(replacement))
305 assert(bb.getOwner() != &producerBlock &&
306 "yielded block argument must have been mapped");
309 "yielded value must have been mapped");
315 rewriter.
clone(op, mapper);
319 auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.
getTerminator());
321 fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
322 consumerYieldOp.getNumOperands());
323 for (
const auto &producerYieldVal :
325 if (preservedProducerResults.count(producerYieldVal.index()))
326 fusedYieldValues.push_back(
329 for (
auto consumerYieldVal : consumerYieldOp.getOperands())
331 YieldOp::create(rewriter, fusedOp.getLoc(), fusedYieldValues);
335 "Ill-formed GenericOp region");
338 FailureOr<mlir::linalg::ElementwiseOpFusionResult>
342 "expected elementwise operation pre-conditions to pass");
343 auto producerResult = cast<OpResult>(fusedOperand->
get());
344 auto producer = cast<GenericOp>(producerResult.getOwner());
345 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
347 assert(consumer.isDpsInput(fusedOperand) &&
348 "expected producer of input operand");
351 llvm::SmallDenseSet<int> preservedProducerResults =
359 fusedInputOperands.reserve(producer.getNumDpsInputs() +
360 consumer.getNumDpsInputs());
361 fusedOutputOperands.reserve(preservedProducerResults.size() +
362 consumer.getNumDpsInits());
363 fusedResultTypes.reserve(preservedProducerResults.size() +
364 consumer.getNumDpsInits());
365 fusedIndexMaps.reserve(producer->getNumOperands() +
366 consumer->getNumOperands());
369 auto consumerInputs = consumer.getDpsInputOperands();
370 auto *it = llvm::find_if(consumerInputs, [&](
OpOperand *operand) {
371 return operand == fusedOperand;
373 assert(it != consumerInputs.end() &&
"expected to find the consumer operand");
374 for (
OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
375 fusedInputOperands.push_back(opOperand->get());
376 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
380 producer.getIndexingMapMatchingResult(producerResult);
381 for (
OpOperand *opOperand : producer.getDpsInputOperands()) {
382 fusedInputOperands.push_back(opOperand->get());
385 opOperand, producerResultIndexMap,
386 consumer.getMatchingIndexingMap(fusedOperand));
387 fusedIndexMaps.push_back(map);
392 llvm::make_range(std::next(it), consumerInputs.end())) {
393 fusedInputOperands.push_back(opOperand->get());
394 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
398 for (
const auto &opOperand :
llvm::enumerate(producer.getDpsInitsMutable())) {
399 if (!preservedProducerResults.count(opOperand.index()))
402 fusedOutputOperands.push_back(opOperand.value().get());
404 &opOperand.value(), producerResultIndexMap,
405 consumer.getMatchingIndexingMap(fusedOperand));
406 fusedIndexMaps.push_back(map);
407 fusedResultTypes.push_back(opOperand.value().get().getType());
411 for (
OpOperand &opOperand : consumer.getDpsInitsMutable()) {
412 fusedOutputOperands.push_back(opOperand.get());
413 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
414 Type resultType = opOperand.get().getType();
415 if (!isa<MemRefType>(resultType))
416 fusedResultTypes.push_back(resultType);
420 auto fusedOp = GenericOp::create(
421 rewriter, consumer.getLoc(), fusedResultTypes, fusedInputOperands,
423 consumer.getIteratorTypes(),
426 if (!fusedOp.getShapesToLoopsMap()) {
432 fusedOp,
"fused op failed loop bound computation check");
438 consumer.getMatchingIndexingMap(fusedOperand);
442 assert(invProducerResultIndexMap &&
443 "expected producer result indexig map to be invertible");
446 invProducerResultIndexMap.
compose(consumerResultIndexMap);
449 rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
450 consumer.getNumLoops(), preservedProducerResults);
454 for (
auto [index, producerResult] :
llvm::enumerate(producer->getResults()))
455 if (preservedProducerResults.count(index))
456 result.
replacements[producerResult] = fusedOp->getResult(resultNum++);
457 for (
auto consumerResult : consumer->getResults())
458 result.
replacements[consumerResult] = fusedOp->getResult(resultNum++);
469 controlFn(std::move(fun)) {}
471 LogicalResult matchAndRewrite(GenericOp genericOp,
474 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
477 if (!controlFn(&opOperand))
480 Operation *producer = opOperand.get().getDefiningOp();
483 FailureOr<ElementwiseOpFusionResult> fusionResult =
489 for (
auto [origVal, replacement] : fusionResult->replacements) {
492 return use.
get().getDefiningOp() != producer;
571 linalgOp.getIteratorTypesArray();
572 AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
573 return linalgOp.hasPureTensorSemantics() &&
574 llvm::all_of(linalgOp.getIndexingMaps().getValue(),
576 return cast<AffineMapAttr>(attr)
578 .isProjectedPermutation();
586 class ExpansionInfo {
592 LogicalResult compute(LinalgOp linalgOp,
OpOperand *fusableOpOperand,
596 unsigned getOrigOpNumDims()
const {
return reassociation.size(); }
597 unsigned getExpandedOpNumDims()
const {
return expandedOpNumDims; }
599 return reassociation[i];
602 return expandedShapeMap[i];
615 unsigned expandedOpNumDims;
619 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
624 if (reassociationMaps.empty())
626 AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
630 originalLoopExtent = llvm::map_to_vector(
631 linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
632 [](
Range r) { return r.size; });
634 reassociation.clear();
635 expandedShapeMap.clear();
639 expandedShapeMap.resize(fusedIndexMap.
getNumDims());
641 unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
642 AffineMap foldedDims = reassociationMaps[resultExpr.index()];
645 expandedShape.slice(foldedDims.
getDimPosition(0), numExpandedDims[pos]);
646 expandedShapeMap[pos].assign(shape.begin(), shape.end());
649 for (
unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.
getNumDims()))
650 if (expandedShapeMap[i].empty())
651 expandedShapeMap[i] = {originalLoopExtent[i]};
655 reassociation.reserve(fusedIndexMap.
getNumDims());
657 auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
658 reassociation.emplace_back(seq.begin(), seq.end());
659 sum += numFoldedDim.value();
661 expandedOpNumDims = sum;
669 const ExpansionInfo &expansionInfo) {
672 unsigned pos = cast<AffineDimExpr>(expr).getPosition();
674 llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
675 return builder.getAffineDimExpr(static_cast<unsigned>(v));
677 newExprs.append(expandedExprs.begin(), expandedExprs.end());
686 static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
688 const ExpansionInfo &expansionInfo) {
691 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
693 expansionInfo.getExpandedShapeOfDim(dim);
694 expandedShape.append(dimExpansion.begin(), dimExpansion.end());
697 std::tie(expandedStaticShape, std::ignore) =
700 originalType.getElementType())};
711 const ExpansionInfo &expansionInfo) {
713 unsigned numReshapeDims = 0;
715 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
716 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
718 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
719 reassociation.emplace_back(std::move(indices));
720 numReshapeDims += numExpandedDims;
722 return reassociation;
732 const ExpansionInfo &expansionInfo) {
734 for (IndexOp indexOp :
735 llvm::make_early_inc_range(fusedRegion.
front().
getOps<IndexOp>())) {
737 expansionInfo.getExpandedDims(indexOp.getDim());
738 assert(!expandedDims.empty() &&
"expected valid expansion info");
741 if (expandedDims.size() == 1 &&
742 expandedDims.front() == (int64_t)indexOp.getDim())
749 expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
751 expandedIndices.reserve(expandedDims.size() - 1);
753 expandedDims.drop_front(), std::back_inserter(expandedIndices),
754 [&](int64_t dim) { return IndexOp::create(rewriter, loc, dim); });
756 IndexOp::create(rewriter, loc, expandedDims.front()).getResult();
757 for (
auto [expandedShape, expandedIndex] :
758 llvm::zip(expandedDimsShape, expandedIndices)) {
763 rewriter, indexOp.getLoc(), idx + acc * shape,
768 rewriter.
replaceOp(indexOp, newIndexVal);
790 TransposeOp transposeOp,
792 ExpansionInfo &expansionInfo) {
795 auto reassoc = expansionInfo.getExpandedDims(perm);
796 for (int64_t dim : reassoc) {
797 newPerm.push_back(dim);
800 return TransposeOp::create(rewriter, transposeOp.getLoc(), expandedInput,
811 expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
813 for (
auto [i, type] :
llvm::enumerate(linalgOp.getIteratorTypesArray()))
814 for (
auto j : expansionInfo.getExpandedDims(i))
815 iteratorTypes[
j] = type;
817 Operation *fused = GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes,
818 expandedOpOperands, outputs,
819 expandedOpIndexingMaps, iteratorTypes);
822 Region &originalRegion = linalgOp->getRegion(0);
840 ExpansionInfo &expansionInfo) {
843 .Case<TransposeOp>([&](TransposeOp transposeOp) {
845 expandedOpOperands[0], outputs[0],
848 .Case<FillOp, CopyOp>([&](
Operation *op) {
849 return clone(rewriter, linalgOp, resultTypes,
850 llvm::to_vector(llvm::concat<Value>(
851 llvm::to_vector(expandedOpOperands),
852 llvm::to_vector(outputs))));
856 expandedOpOperands, outputs,
857 expansionInfo, expandedOpIndexingMaps);
864 static std::optional<SmallVector<Value>>
869 "preconditions for fuse operation failed");
875 if (
auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
879 rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
882 expandedShape = expandingReshapeOp.getMixedOutputShape();
883 reassociationIndices = expandingReshapeOp.getReassociationMaps();
884 src = expandingReshapeOp.getSrc();
886 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
887 if (!collapsingReshapeOp)
891 rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
892 reassociationIndices = collapsingReshapeOp.getReassociationMaps();
893 src = collapsingReshapeOp.getSrc();
896 ExpansionInfo expansionInfo;
897 if (
failed(expansionInfo.compute(linalgOp, fusableOpOperand,
898 reassociationIndices, expandedShape,
903 llvm::map_range(linalgOp.getIndexingMapsArray(), [&](
AffineMap m) {
904 return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
912 expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
913 for (
OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
914 if (opOperand == fusableOpOperand) {
915 expandedOpOperands.push_back(src);
918 if (
auto opOperandType =
919 dyn_cast<RankedTensorType>(opOperand->get().getType())) {
920 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
922 RankedTensorType expandedOperandType;
923 std::tie(expandedOperandShape, expandedOperandType) =
925 if (expandedOperandType != opOperand->get().getType()) {
930 [&](
const Twine &msg) {
933 opOperandType.getShape(), expandedOperandType.getShape(),
937 expandedOpOperands.push_back(tensor::ExpandShapeOp::create(
938 rewriter, loc, expandedOperandType, opOperand->get(), reassociation,
939 expandedOperandShape));
943 expandedOpOperands.push_back(opOperand->get());
947 for (
OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
948 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
949 auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
951 RankedTensorType expandedOutputType;
952 std::tie(expandedOutputShape, expandedOutputType) =
954 if (expandedOutputType != opOperand.get().getType()) {
958 [&](
const Twine &msg) {
961 opOperandType.getShape(), expandedOutputType.getShape(),
965 outputs.push_back(tensor::ExpandShapeOp::create(
966 rewriter, loc, expandedOutputType, opOperand.get(), reassociation,
967 expandedOutputShape));
969 outputs.push_back(opOperand.get());
976 outputs, expandedOpIndexingMaps, expansionInfo);
980 for (
OpResult opResult : linalgOp->getOpResults()) {
981 int64_t resultNumber = opResult.getResultNumber();
982 if (resultTypes[resultNumber] != opResult.getType()) {
985 linalgOp.getMatchingIndexingMap(
986 linalgOp.getDpsInitOperand(resultNumber)),
988 resultVals.push_back(tensor::CollapseShapeOp::create(
989 rewriter, linalgOp.getLoc(), opResult.
getType(),
990 fusedOp->
getResult(resultNumber), reassociation));
992 resultVals.push_back(fusedOp->
getResult(resultNumber));
1004 class FoldWithProducerReshapeOpByExpansion
1007 FoldWithProducerReshapeOpByExpansion(
MLIRContext *context,
1011 controlFoldingReshapes(std::move(foldReshapes)) {}
1013 LogicalResult matchAndRewrite(LinalgOp linalgOp,
1015 for (
OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
1016 tensor::CollapseShapeOp reshapeOp =
1017 opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
1024 (!controlFoldingReshapes(opOperand)))
1027 std::optional<SmallVector<Value>> replacementValues =
1029 if (!replacementValues)
1031 rewriter.
replaceOp(linalgOp, *replacementValues);
1041 class FoldPadWithProducerReshapeOpByExpansion
1044 FoldPadWithProducerReshapeOpByExpansion(
MLIRContext *context,
1048 controlFoldingReshapes(std::move(foldReshapes)) {}
1050 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1052 tensor::CollapseShapeOp reshapeOp =
1053 padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
1056 if (!reshapeOp->hasOneUse())
1059 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1061 "fusion blocked by control function");
1067 reshapeOp.getReassociationIndices();
1069 for (
auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
1070 if (reInd.size() != 1 && (l != 0 || h != 0))
1075 RankedTensorType expandedType = reshapeOp.getSrcType();
1076 RankedTensorType paddedType = padOp.getResultType();
1079 if (reInd.size() == 1) {
1080 expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
1082 for (
size_t i = 0; i < reInd.size(); ++i) {
1083 newLow.push_back(padOp.getMixedLowPad()[idx]);
1084 newHigh.push_back(padOp.getMixedHighPad()[idx]);
1089 RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
1090 auto newPadOp = tensor::PadOp::create(
1091 rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1092 padOp.getConstantPaddingValue(), padOp.getNofold());
1095 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
1106 struct FoldReshapeWithGenericOpByExpansion
1109 FoldReshapeWithGenericOpByExpansion(
MLIRContext *context,
1113 controlFoldingReshapes(std::move(foldReshapes)) {}
1115 LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
1118 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1119 if (!producerResult) {
1121 "source not produced by an operation");
1124 auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
1127 "producer not a generic op");
1132 producer.getDpsInitOperand(producerResult.getResultNumber()))) {
1134 reshapeOp,
"failed preconditions of fusion with producer generic op");
1137 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1139 "fusion blocked by control function");
1142 std::optional<SmallVector<Value>> replacementValues =
1144 producer, reshapeOp,
1145 producer.getDpsInitOperand(producerResult.getResultNumber()),
1147 if (!replacementValues) {
1149 "fusion by expansion failed");
1156 Value reshapeReplacement =
1157 (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
1158 .getResultNumber()];
1159 if (
auto collapseOp =
1160 reshapeReplacement.
getDefiningOp<tensor::CollapseShapeOp>()) {
1161 reshapeReplacement = collapseOp.getSrc();
1163 rewriter.
replaceOp(reshapeOp, reshapeReplacement);
1164 rewriter.
replaceOp(producer, *replacementValues);
1186 "expected projected permutation");
1189 llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
1190 return cast<AffineDimExpr>(indexingMap.
getResults()[pos]).getPosition();
1194 return domainReassociation;
1202 assert(!dimSequence.empty() &&
1203 "expected non-empty list for dimension sequence");
1205 "expected indexing map to be projected permutation");
1207 llvm::SmallDenseSet<unsigned, 4> sequenceElements;
1208 sequenceElements.insert_range(dimSequence);
1210 unsigned dimSequenceStart = dimSequence[0];
1212 unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
1214 if (dimInMapStart == dimSequenceStart) {
1215 if (expr.index() + dimSequence.size() > indexingMap.
getNumResults())
1218 for (
const auto &dimInSequence :
enumerate(dimSequence)) {
1220 cast<AffineDimExpr>(
1221 indexingMap.
getResult(expr.index() + dimInSequence.index()))
1223 if (dimInMap != dimInSequence.value())
1234 if (sequenceElements.count(dimInMapStart))
1243 return llvm::all_of(maps, [&](
AffineMap map) {
1300 if (!genericOp.hasPureTensorSemantics())
1303 if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](
AffineMap map) {
1304 return map.isProjectedPermutation();
1311 genericOp.getReductionDims(reductionDims);
1313 llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1314 AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1315 auto iteratorTypes = genericOp.getIteratorTypesArray();
1318 assert(!foldedRangeDims.empty() &&
"unexpected empty reassociation");
1321 if (foldedRangeDims.size() == 1)
1329 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1330 return processedIterationDims.count(dim);
1335 utils::IteratorType startIteratorType =
1336 iteratorTypes[foldedIterationSpaceDims[0]];
1340 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1341 return iteratorTypes[dim] != startIteratorType;
1350 bool isContiguous =
false;
1353 if (startDim.value() != foldedIterationSpaceDims[0])
1357 if (startDim.index() + foldedIterationSpaceDims.size() >
1358 reductionDims.size())
1361 isContiguous =
true;
1362 for (
const auto &foldedDim :
1364 if (reductionDims[foldedDim.index() + startDim.index()] !=
1365 foldedDim.value()) {
1366 isContiguous =
false;
1377 if (llvm::any_of(genericOp.getIndexingMapsArray(),
1379 return !isDimSequencePreserved(indexingMap,
1380 foldedIterationSpaceDims);
1384 processedIterationDims.insert_range(foldedIterationSpaceDims);
1385 iterationSpaceReassociation.emplace_back(
1386 std::move(foldedIterationSpaceDims));
1389 return iterationSpaceReassociation;
1394 class CollapsingInfo {
1396 LogicalResult initialize(
unsigned origNumLoops,
1398 llvm::SmallDenseSet<int64_t, 4> processedDims;
1401 if (foldedIterationDim.empty())
1405 for (
auto dim : foldedIterationDim) {
1406 if (dim >= origNumLoops)
1408 if (processedDims.count(dim))
1410 processedDims.insert(dim);
1412 collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1413 foldedIterationDim.end());
1415 if (processedDims.size() > origNumLoops)
1420 for (
auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1421 if (processedDims.count(dim))
1426 llvm::sort(collapsedOpToOrigOpIterationDim,
1428 return lhs[0] < rhs[0];
1430 origOpToCollapsedOpIterationDim.resize(origNumLoops);
1431 for (
const auto &foldedDims :
1433 for (
const auto &dim :
enumerate(foldedDims.value()))
1434 origOpToCollapsedOpIterationDim[dim.value()] =
1435 std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1442 return collapsedOpToOrigOpIterationDim;
1466 return origOpToCollapsedOpIterationDim;
1470 unsigned getCollapsedOpIterationRank()
const {
1471 return collapsedOpToOrigOpIterationDim.size();
1489 const CollapsingInfo &collapsingInfo) {
1492 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1493 assert(!foldedIterDims.empty() &&
1494 "reassociation indices expected to have non-empty sets");
1498 collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1500 return collapsedIteratorTypes;
1507 const CollapsingInfo &collapsingInfo) {
1510 "expected indexing map to be projected permutation");
1512 auto origOpToCollapsedOpMapping =
1513 collapsingInfo.getOrigOpToCollapsedOpMapping();
1515 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1517 if (origOpToCollapsedOpMapping[dim].second != 0)
1521 resultExprs.push_back(
1524 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1525 resultExprs, context);
1532 const CollapsingInfo &collapsingInfo) {
1533 unsigned counter = 0;
1535 auto origOpToCollapsedOpMapping =
1536 collapsingInfo.getOrigOpToCollapsedOpMapping();
1537 auto collapsedOpToOrigOpMapping =
1538 collapsingInfo.getCollapsedOpToOrigOpMapping();
1541 cast<AffineDimExpr>(indexingMap.
getResult(counter)).getPosition();
1545 unsigned numFoldedDims =
1546 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1548 if (origOpToCollapsedOpMapping[dim].second == 0) {
1549 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1550 operandReassociation.emplace_back(range.begin(), range.end());
1552 counter += numFoldedDims;
1554 return operandReassociation;
1560 const CollapsingInfo &collapsingInfo,
1562 AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1570 if (operandReassociation.size() == indexingMap.
getNumResults())
1574 if (isa<MemRefType>(operand.
getType())) {
1575 return memref::CollapseShapeOp::create(builder, loc, operand,
1576 operandReassociation)
1579 return tensor::CollapseShapeOp::create(builder, loc, operand,
1580 operandReassociation)
1587 Location loc,
Block *block,
const CollapsingInfo &collapsingInfo,
1593 auto indexOps = llvm::to_vector(block->
getOps<linalg::IndexOp>());
1603 for (
auto foldedDims :
1604 enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1607 linalg::IndexOp::create(rewriter, loc, foldedDims.index());
1608 for (
auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1611 indexReplacementVals[dim] =
1612 rewriter.
createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
1614 rewriter.
createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
1616 indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1619 for (
auto indexOp : indexOps) {
1620 auto dim = indexOp.getDim();
1621 rewriter.
replaceOp(indexOp, indexReplacementVals[dim]);
1626 const CollapsingInfo &collapsingInfo,
1633 llvm::map_to_vector(op.getDpsInputOperands(), [&](
OpOperand *opOperand) {
1634 return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1639 resultTypes.reserve(op.getNumDpsInits());
1640 outputOperands.reserve(op.getNumDpsInits());
1641 for (
OpOperand &output : op.getDpsInitsMutable()) {
1644 outputOperands.push_back(newOutput);
1647 if (!op.hasPureBufferSemantics())
1648 resultTypes.push_back(newOutput.
getType());
1653 template <
typename OpTy>
1655 const CollapsingInfo &collapsingInfo) {
1663 const CollapsingInfo &collapsingInfo) {
1667 outputOperands, resultTypes);
1670 rewriter, origOp, resultTypes,
1671 llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1678 const CollapsingInfo &collapsingInfo) {
1682 outputOperands, resultTypes);
1684 llvm::map_range(origOp.getIndexingMapsArray(), [&](
AffineMap map) {
1685 return getCollapsedOpIndexingMap(map, collapsingInfo);
1689 origOp.getIteratorTypesArray(), collapsingInfo));
1691 GenericOp collapsedOp = linalg::GenericOp::create(
1692 rewriter, origOp.getLoc(), resultTypes, inputOperands, outputOperands,
1693 indexingMaps, iteratorTypes,
1695 Block *origOpBlock = &origOp->getRegion(0).
front();
1696 Block *collapsedOpBlock = &collapsedOp->getRegion(0).
front();
1697 rewriter.
mergeBlocks(origOpBlock, collapsedOpBlock,
1703 const CollapsingInfo &collapsingInfo,
1705 if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1717 if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1719 return foldedDims.size() <= 1;
1723 CollapsingInfo collapsingInfo;
1725 collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1727 op,
"illegal to collapse specified dimensions");
1730 bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1731 if (hasPureBufferSemantics &&
1732 !llvm::all_of(op->getOpOperands(), [&](
OpOperand &opOperand) ->
bool {
1733 MemRefType memRefToCollapse =
1734 dyn_cast<MemRefType>(opOperand.get().getType());
1735 if (!memRefToCollapse)
1738 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1739 SmallVector<ReassociationIndices> operandReassociation =
1740 getOperandReassociation(indexingMap, collapsingInfo);
1741 return memref::CollapseShapeOp::isGuaranteedCollapsible(
1742 memRefToCollapse, operandReassociation);
1745 "memref is not guaranteed collapsible");
1749 auto opFoldIsConstantValue = [](
OpFoldResult ofr, int64_t value) {
1750 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1751 return cast<IntegerAttr>(attr).getInt() == value;
1754 actual.getSExtValue() == value;
1756 if (!llvm::all_of(loopRanges, [&](
Range range) {
1757 return opFoldIsConstantValue(range.
offset, 0) &&
1758 opFoldIsConstantValue(range.
stride, 1);
1761 op,
"expected all loop ranges to have zero start and unit stride");
1768 llvm::map_to_vector(loopRanges, [](
Range range) {
return range.
size; });
1770 if (collapsedOp.hasIndexSemantics()) {
1775 collapsingInfo, loopBound, rewriter);
1781 for (
const auto &originalResult :
llvm::enumerate(op->getResults())) {
1782 Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1783 auto originalResultType =
1784 cast<ShapedType>(originalResult.value().getType());
1785 auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.
getType());
1786 if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1788 op.getIndexingMapMatchingResult(originalResult.value());
1793 "Expected indexing map to be a projected permutation for collapsing");
1797 if (isa<MemRefType>(collapsedOpResult.
getType())) {
1799 originalResultType.getShape(), originalResultType.getElementType());
1800 result = memref::ExpandShapeOp::create(
1801 rewriter, loc, expandShapeResultType, collapsedOpResult,
1802 reassociation, resultShape);
1804 result = tensor::ExpandShapeOp::create(
1805 rewriter, loc, originalResultType, collapsedOpResult, reassociation,
1808 results.push_back(result);
1810 results.push_back(collapsedOpResult);
1820 class FoldWithProducerReshapeOpByCollapsing
1824 FoldWithProducerReshapeOpByCollapsing(
MLIRContext *context,
1828 controlFoldingReshapes(std::move(foldReshapes)) {}
1830 LogicalResult matchAndRewrite(GenericOp genericOp,
1832 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
1833 tensor::ExpandShapeOp reshapeOp =
1834 opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1840 reshapeOp.getReassociationIndices());
1841 if (collapsableIterationDims.empty() ||
1842 !controlFoldingReshapes(&opOperand)) {
1847 genericOp, collapsableIterationDims, rewriter);
1848 if (!collapseResult) {
1850 genericOp,
"failed to do the fusion by collapsing transformation");
1853 rewriter.
replaceOp(genericOp, collapseResult->results);
1865 struct FoldReshapeWithGenericOpByCollapsing
1868 FoldReshapeWithGenericOpByCollapsing(
MLIRContext *context,
1872 controlFoldingReshapes(std::move(foldReshapes)) {}
1874 LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
1878 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1879 if (!producerResult) {
1881 "source not produced by an operation");
1885 auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
1888 "producer not a generic op");
1894 producer.getDpsInitOperand(producerResult.getResultNumber()),
1895 reshapeOp.getReassociationIndices());
1896 if (collapsableIterationDims.empty()) {
1898 reshapeOp,
"failed preconditions of fusion with producer generic op");
1901 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1903 "fusion blocked by control function");
1909 std::optional<CollapseResult> collapseResult =
1911 if (!collapseResult) {
1913 producer,
"failed to do the fusion by collapsing transformation");
1916 rewriter.
replaceOp(producer, collapseResult->results);
1924 class FoldPadWithProducerReshapeOpByCollapsing
1927 FoldPadWithProducerReshapeOpByCollapsing(
MLIRContext *context,
1931 controlFoldingReshapes(std::move(foldReshapes)) {}
1933 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1935 tensor::ExpandShapeOp reshapeOp =
1936 padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
1939 if (!reshapeOp->hasOneUse())
1942 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1944 "fusion blocked by control function");
1950 reshapeOp.getReassociationIndices();
1952 for (
auto reInd : reassociations) {
1953 if (reInd.size() == 1)
1955 if (llvm::any_of(reInd, [&](int64_t ind) {
1956 return low[ind] != 0 || high[ind] != 0;
1963 RankedTensorType collapsedType = reshapeOp.getSrcType();
1964 RankedTensorType paddedType = padOp.getResultType();
1968 reshapeOp.getOutputShape(), rewriter));
1972 Location loc = reshapeOp->getLoc();
1976 if (reInd.size() == 1) {
1977 collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
1979 rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
1980 expandedPaddedSizes[reInd[0]] = paddedSize;
1982 newLow.push_back(l);
1983 newHigh.push_back(h);
1986 RankedTensorType collapsedPaddedType =
1987 paddedType.clone(collapsedPaddedShape);
1988 auto newPadOp = tensor::PadOp::create(
1989 rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1990 padOp.getConstantPaddingValue(), padOp.getNofold());
1993 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
1994 expandedPaddedSizes);
2004 template <
typename LinalgType>
2011 controlCollapseDimension(std::move(collapseDimensions)) {}
2013 LogicalResult matchAndRewrite(LinalgType op,
2016 controlCollapseDimension(op);
2017 if (collapsableIterationDims.empty())
2022 collapsableIterationDims)) {
2024 op,
"specified dimensions cannot be collapsed");
2027 std::optional<CollapseResult> collapseResult =
2029 if (!collapseResult) {
2032 rewriter.
replaceOp(op, collapseResult->results);
2054 LogicalResult matchAndRewrite(GenericOp genericOp,
2056 if (!genericOp.hasPureTensorSemantics())
2058 for (
OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2059 Operation *def = opOperand->get().getDefiningOp();
2060 TypedAttr constantAttr;
2061 auto isScalarOrSplatConstantOp = [&constantAttr](
Operation *def) ->
bool {
2064 if (
matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
2066 splatAttr.
getType().getElementType().isIntOrFloat()) {
2072 IntegerAttr intAttr;
2073 if (
matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
2074 constantAttr = intAttr;
2079 FloatAttr floatAttr;
2080 if (
matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
2081 constantAttr = floatAttr;
2088 auto resultValue = dyn_cast<OpResult>(opOperand->get());
2089 if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
2098 fusedIndexMaps.reserve(genericOp->getNumOperands());
2099 fusedOperands.reserve(genericOp.getNumDpsInputs());
2100 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
2101 for (
OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
2102 if (inputOperand == opOperand)
2104 Value inputValue = inputOperand->get();
2105 fusedIndexMaps.push_back(
2106 genericOp.getMatchingIndexingMap(inputOperand));
2107 fusedOperands.push_back(inputValue);
2108 fusedLocs.push_back(inputValue.
getLoc());
2110 for (
OpOperand &outputOperand : genericOp.getDpsInitsMutable())
2111 fusedIndexMaps.push_back(
2112 genericOp.getMatchingIndexingMap(&outputOperand));
2118 genericOp,
"fused op loop bound computation failed");
2122 Value scalarConstant =
2123 arith::ConstantOp::create(rewriter, def->
getLoc(), constantAttr);
2127 GenericOp::create(rewriter, rewriter.
getFusedLoc(fusedLocs),
2128 genericOp->getResultTypes(),
2132 genericOp.getIteratorTypes(),
2138 Region ®ion = genericOp->getRegion(0);
2141 mapping.
map(entryBlock.
getArgument(opOperand->getOperandNumber()),
2143 Region &fusedRegion = fusedOp->getRegion(0);
2146 rewriter.
replaceOp(genericOp, fusedOp->getResults());
2167 LogicalResult matchAndRewrite(GenericOp op,
2170 bool modifiedOutput =
false;
2172 for (
OpOperand &opOperand : op.getDpsInitsMutable()) {
2173 if (!op.payloadUsesValueFromOperand(&opOperand)) {
2174 Value operandVal = opOperand.get();
2175 auto operandType = dyn_cast<RankedTensorType>(operandVal.
getType());
2184 auto definingOp = operandVal.
getDefiningOp<tensor::EmptyOp>();
2187 modifiedOutput =
true;
2190 Value emptyTensor = tensor::EmptyOp::create(
2191 rewriter, loc, mixedSizes, operandType.getElementType());
2192 op->setOperand(opOperand.getOperandNumber(), emptyTensor);
2195 if (!modifiedOutput) {
2208 LogicalResult matchAndRewrite(GenericOp genericOp,
2210 if (!genericOp.hasPureTensorSemantics())
2212 bool fillFound =
false;
2213 Block &payload = genericOp.getRegion().
front();
2214 for (
OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2215 if (!genericOp.payloadUsesValueFromOperand(opOperand))
2217 FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2221 Value fillVal = fillOp.value();
2223 cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2224 Value convertedVal =
2228 payload.
getArgument(opOperand->getOperandNumber()), convertedVal);
2230 return success(fillFound);
2239 controlFoldingReshapes);
2240 patterns.add<FoldPadWithProducerReshapeOpByExpansion>(
patterns.getContext(),
2241 controlFoldingReshapes);
2243 controlFoldingReshapes);
2250 controlFoldingReshapes);
2251 patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2252 patterns.getContext(), controlFoldingReshapes);
2254 controlFoldingReshapes);
2260 auto *context =
patterns.getContext();
2261 patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2262 patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2263 RemoveOutsDependency>(context);
2271 patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2272 CollapseLinalgDimensions<linalg::CopyOp>>(
2273 patterns.getContext(), controlCollapseDimensions);
2288 struct LinalgElementwiseOpFusionPass
2289 :
public impl::LinalgElementwiseOpFusionPassBase<
2290 LinalgElementwiseOpFusionPass> {
2291 using impl::LinalgElementwiseOpFusionPassBase<
2292 LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2293 void runOnOperation()
override {
2300 Operation *producer = fusedOperand->get().getDefiningOp();
2301 return producer && producer->
hasOneUse();
2310 affine::AffineApplyOp::getCanonicalizationPatterns(
patterns, context);
2311 GenericOp::getCanonicalizationPatterns(
patterns, context);
2312 tensor::ExpandShapeOp::getCanonicalizationPatterns(
patterns, context);
2313 tensor::CollapseShapeOp::getCanonicalizationPatterns(
patterns, context);
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(GenericOp producer, GenericOp consumer, ArrayRef< OpOperand * > opOperandsToIgnore)
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(OpOperand *producerOpOperand, AffineMap producerResultIndexMap, AffineMap fusedConsumerArgIndexMap)
Append to fusedOpIndexingMapAttrs the indexing maps for the operands of the producer to use in the fused op.
static SmallVector< ReassociationIndices > getOperandReassociation(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Return the reassociation indices to use to collapse the operand when the iteration space of a generic op is collapsed.
static Operation * createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp, Value expandedInput, Value output, ExpansionInfo &expansionInfo)
static std::tuple< SmallVector< OpFoldResult >, RankedTensorType > getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the shape and type of the operand/result to use in the expanded op given the type in the original op.
static SmallVector< utils::IteratorType > getCollapsedOpIteratorTypes(ArrayRef< utils::IteratorType > iteratorTypes, const CollapsingInfo &collapsingInfo)
Get the iterator types for the collapsed operation given the original iterator types and collapsed dimensions.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, Location loc, Region &fusedRegion, const ExpansionInfo &expansionInfo)
Update the body of an expanded linalg operation having index semantics.
static void generateCollapsedIndexingRegion(Location loc, Block *block, const CollapsingInfo &collapsingInfo, ArrayRef< OpFoldResult > loopRange, RewriterBase &rewriter)
Modify the linalg.index operations in the original generic op, to its value in the collapsed operation.
static Operation * createExpandedGenericOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > &expandedOpOperands, ArrayRef< Value > outputs, ExpansionInfo &expansionInfo, ArrayRef< AffineMap > expandedOpIndexingMaps)
GenericOp cloneToCollapsedOp< GenericOp >(RewriterBase &rewriter, GenericOp origOp, const CollapsingInfo &collapsingInfo)
Collapse a GenericOp
static LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter)
static void collapseOperandsAndResults(LinalgOp op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter, SmallVectorImpl< Value > &inputOperands, SmallVectorImpl< Value > &outputOperands, SmallVectorImpl< Type > &resultTypes)
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp, OpOperand *fusableOpOperand)
Conditions for folding a structured linalg operation with a reshape op by expanding the iteration space dimensionality for tensor operations.
static Operation * createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, ArrayRef< Value > expandedOpOperands, ArrayRef< Value > outputs, ArrayRef< AffineMap > expandedOpIndexingMaps, ExpansionInfo &expansionInfo)
static ReassociationIndices getDomainReassociation(AffineMap indexingMap, ReassociationIndicesRef rangeReassociation)
For a given list of indices in the range of the indexingMap that are folded, return the indices of the corresponding domain dimensions.
static SmallVector< ReassociationIndices > getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, ArrayRef< ReassociationIndices > reassociation)
static OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp, const CollapsingInfo &collapsingInfo)
Clone a LinalgOp to a collapsed version of same name.
static std::optional< SmallVector< Value > > fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, OpOperand *fusableOpOperand, PatternRewriter &rewriter)
Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op and a generic op as expl...
static void generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand, unsigned nloops, llvm::SmallDenseSet< int > &preservedProducerResults)
Generate the region of the fused tensor operation.
static SmallVector< ReassociationIndices > getReassociationForExpansion(AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Returns the reassociation maps to use in the tensor.expand_shape operation to convert the operands of the original operation to operands of the expanded operation.
static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Compute the indexing map in the collapsed op that corresponds to the given indexingMap of the original operation.
LinalgOp cloneToCollapsedOp< LinalgOp >(RewriterBase &rewriter, LinalgOp origOp, const CollapsingInfo &collapsingInfo)
Collapse any LinalgOp that does not require any specialization such as indexing_maps, iterator_types, etc.
static AffineMap getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the indexing map to use in the expanded op for a given indexingMap of the original operation.
static Value getCollapsedOpOperand(Location loc, LinalgOp op, OpOperand *opOperand, const CollapsingInfo &collapsingInfo, OpBuilder &builder)
Get the new value to use for a given OpOperand in the collapsed operation.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block excluding the terminator operation at the end.
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e. whether all of its element values are the same.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteConfig & setUseTopDownTraversal(bool use=true)
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65535 (the most useful match possible).
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool areDimSequencesPreserved(ArrayRef< AffineMap > maps, ArrayRef< ReassociationIndices > dimSequences)
Return true if all sequences of dimensions specified in dimSequences are contiguous in all the ranges...
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence)
Return true if a given sequence of dimensions are contiguous in the range of the specified indexing map.
void populateFoldReshapeOpsByCollapsingPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding tensor.expand_shape operation with its producer generic operation by collapsing the dimensions of the generic op.
FailureOr< ElementwiseOpFusionResult > fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand)
llvm::SmallDenseSet< int > getPreservedProducerResults(GenericOp producer, GenericOp consumer, OpOperand *fusedOperand)
Returns a set of indices of the producer's results which would be preserved after the fusion.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
void populateCollapseDimensions(RewritePatternSet &patterns, const GetCollapsableDimensionsFn &controlCollapseDimensions)
Pattern to collapse dimensions in a linalg.generic op.
bool areElementwiseOpsFusable(OpOperand *fusedOperand)
Return true if two linalg.generic operations with producer/consumer relationship through fusedOperand can be fused using elementwise op fusion.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
std::function< SmallVector< ReassociationIndices >(linalg::LinalgOp)> GetCollapsableDimensionsFn
Function type to control generic op dimension collapsing.
void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding (collapsing) tensor_reshape operation with its producer (consumer) generic operation by expanding the dimensionality of the loop in the generic op.
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void populateElementwiseOpsFusionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion)
Patterns for fusing linalg operation on tensors.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor.collapse_shape ops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify that shapes of the reshaped types using following rule: if a dimension in the collapsed type i...
ArrayRef< int64_t > ReassociationIndicesRef
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, Operation *insertionPoint, DominanceInfo &dominance)
Move definitions of values before an insertion point.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Fuse two linalg.generic operations that have a producer-consumer relationship captured through fusedO...
llvm::DenseMap< Value, Value > replacements
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.