30 #define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
31 #include "mlir/Dialect/Linalg/Passes.h.inc"
57 assert(invProducerResultIndexMap &&
58 "expected producer result indexing map to be invertible");
60 LinalgOp producer = cast<LinalgOp>(producerOpOperand->
getOwner());
62 AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);
70 return t1.
compose(fusedConsumerArgIndexMap);
75 llvm::SmallDenseSet<int>
78 llvm::SmallDenseSet<int> preservedProducerResults;
79 for (
const auto &producerResult :
llvm::enumerate(producer->getResults())) {
80 auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
81 if (producer.payloadUsesValueFromOperand(outputOperand) ||
82 !producer.canOpOperandsBeDropped(outputOperand) ||
83 llvm::any_of(producerResult.value().getUsers(), [&](
Operation *user) {
84 return user != consumer.getOperation();
86 preservedProducerResults.insert(producerResult.index());
89 return preservedProducerResults;
98 auto consumer = dyn_cast<GenericOp>(fusedOperand->
getOwner());
101 if (!producer || !consumer)
107 if (!producer.hasPureTensorSemantics() ||
108 !isa<RankedTensorType>(fusedOperand->
get().
getType()))
113 if (producer.getNumParallelLoops() != producer.getNumLoops())
118 if (!consumer.isDpsInput(fusedOperand))
123 AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
124 if (consumerIndexMap.
getNumResults() != producer.getNumLoops())
130 producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
138 if ((consumer.getNumReductionLoops())) {
139 BitVector coveredDims(consumer.getNumLoops(),
false);
141 auto addToCoveredDims = [&](
AffineMap map) {
142 for (
auto result : map.getResults())
143 if (
auto dimExpr = dyn_cast<AffineDimExpr>(result))
144 coveredDims[dimExpr.getPosition()] =
true;
148 llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
149 Value operand = std::get<0>(pair);
150 if (operand == fusedOperand->
get())
152 AffineMap operandMap = std::get<1>(pair);
153 addToCoveredDims(operandMap);
156 for (
OpOperand *operand : producer.getDpsInputOperands()) {
159 operand, producerResultIndexMap, consumerIndexMap);
160 addToCoveredDims(newIndexingMap);
162 if (!coveredDims.all())
174 unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
176 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
178 Block &producerBlock = producer->getRegion(0).
front();
179 Block &consumerBlock = consumer->getRegion(0).
front();
186 if (producer.hasIndexSemantics()) {
188 unsigned numFusedOpLoops =
189 std::max(producer.getNumLoops(), consumer.getNumLoops());
191 fusedIndices.reserve(numFusedOpLoops);
192 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
193 std::back_inserter(fusedIndices), [&](uint64_t dim) {
194 return rewriter.
create<IndexOp>(producer.getLoc(), dim);
196 for (IndexOp indexOp :
197 llvm::make_early_inc_range(producerBlock.
getOps<IndexOp>())) {
198 Value newIndex = rewriter.
create<affine::AffineApplyOp>(
200 consumerToProducerLoopsMap.
getSubMap(indexOp.getDim()), fusedIndices);
201 mapper.
map(indexOp.getResult(), newIndex);
205 assert(consumer.isDpsInput(fusedOperand) &&
206 "expected producer of input operand");
210 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
217 producerBlock.
getArguments().take_front(producer.getNumDpsInputs()))
218 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
223 .take_front(consumer.getNumDpsInputs())
225 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
229 producerBlock.
getArguments().take_back(producer.getNumDpsInits()))) {
230 if (!preservedProducerResults.count(bbArg.index()))
232 mapper.
map(bbArg.value(), fusedBlock->
addArgument(bbArg.value().getType(),
233 bbArg.value().getLoc()));
238 consumerBlock.
getArguments().take_back(consumer.getNumDpsInits()))
239 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
244 if (!isa<IndexOp>(op))
245 rewriter.
clone(op, mapper);
249 auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.
getTerminator());
250 unsigned producerResultNumber =
251 cast<OpResult>(fusedOperand->
get()).getResultNumber();
253 mapper.
lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
257 if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
258 if (
auto bb = dyn_cast<BlockArgument>(replacement))
259 assert(bb.getOwner() != &producerBlock &&
260 "yielded block argument must have been mapped");
263 "yielded value must have been mapped");
269 rewriter.
clone(op, mapper);
273 auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.
getTerminator());
275 fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
276 consumerYieldOp.getNumOperands());
277 for (
const auto &producerYieldVal :
279 if (preservedProducerResults.count(producerYieldVal.index()))
280 fusedYieldValues.push_back(
283 for (
auto consumerYieldVal : consumerYieldOp.getOperands())
285 rewriter.
create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
289 "Ill-formed GenericOp region");
292 FailureOr<mlir::linalg::ElementwiseOpFusionResult>
296 "expected elementwise operation pre-conditions to pass");
297 auto producerResult = cast<OpResult>(fusedOperand->
get());
298 auto producer = cast<GenericOp>(producerResult.getOwner());
299 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
301 assert(consumer.isDpsInput(fusedOperand) &&
302 "expected producer of input operand");
304 llvm::SmallDenseSet<int> preservedProducerResults =
312 fusedInputOperands.reserve(producer.getNumDpsInputs() +
313 consumer.getNumDpsInputs());
314 fusedOutputOperands.reserve(preservedProducerResults.size() +
315 consumer.getNumDpsInits());
316 fusedResultTypes.reserve(preservedProducerResults.size() +
317 consumer.getNumDpsInits());
318 fusedIndexMaps.reserve(producer->getNumOperands() +
319 consumer->getNumOperands());
322 auto consumerInputs = consumer.getDpsInputOperands();
323 auto *it = llvm::find_if(consumerInputs, [&](
OpOperand *operand) {
324 return operand == fusedOperand;
326 assert(it != consumerInputs.end() &&
"expected to find the consumer operand");
327 for (
OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
328 fusedInputOperands.push_back(opOperand->get());
329 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
333 producer.getIndexingMapMatchingResult(producerResult);
334 for (
OpOperand *opOperand : producer.getDpsInputOperands()) {
335 fusedInputOperands.push_back(opOperand->get());
338 opOperand, producerResultIndexMap,
339 consumer.getMatchingIndexingMap(fusedOperand));
340 fusedIndexMaps.push_back(map);
345 llvm::make_range(std::next(it), consumerInputs.end())) {
346 fusedInputOperands.push_back(opOperand->get());
347 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
351 for (
const auto &opOperand :
llvm::enumerate(producer.getDpsInitsMutable())) {
352 if (!preservedProducerResults.count(opOperand.index()))
355 fusedOutputOperands.push_back(opOperand.value().get());
357 &opOperand.value(), producerResultIndexMap,
358 consumer.getMatchingIndexingMap(fusedOperand));
359 fusedIndexMaps.push_back(map);
360 fusedResultTypes.push_back(opOperand.value().get().getType());
364 for (
OpOperand &opOperand : consumer.getDpsInitsMutable()) {
365 fusedOutputOperands.push_back(opOperand.get());
366 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
367 Type resultType = opOperand.get().getType();
368 if (!isa<MemRefType>(resultType))
369 fusedResultTypes.push_back(resultType);
373 auto fusedOp = rewriter.
create<GenericOp>(
374 consumer.getLoc(), fusedResultTypes, fusedInputOperands,
376 consumer.getIteratorTypes(),
379 if (!fusedOp.getShapesToLoopsMap()) {
385 fusedOp,
"fused op failed loop bound computation check");
391 consumer.getMatchingIndexingMap(fusedOperand);
// The inverse of the producer result indexing map is composed with the
// consumer result indexing map right after this check, so it must be valid.
// The message spelling now matches the identical assertion used earlier in
// this file ("indexing", previously misspelled "indexig").
395 assert(invProducerResultIndexMap &&
396 "expected producer result indexing map to be invertible");
399 invProducerResultIndexMap.
compose(consumerResultIndexMap);
402 rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
403 consumer.getNumLoops(), preservedProducerResults);
407 for (
auto [index, producerResult] :
llvm::enumerate(producer->getResults()))
408 if (preservedProducerResults.count(index))
409 result.
replacements[producerResult] = fusedOp->getResult(resultNum++);
410 for (
auto consumerResult : consumer->getResults())
411 result.
replacements[consumerResult] = fusedOp->getResult(resultNum++);
422 controlFn(std::move(fun)) {}
424 LogicalResult matchAndRewrite(GenericOp genericOp,
427 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
430 if (!controlFn(&opOperand))
433 Operation *producer = opOperand.get().getDefiningOp();
436 FailureOr<ElementwiseOpFusionResult> fusionResult =
438 if (failed(fusionResult))
442 for (
auto [origVal, replacement] : fusionResult->replacements) {
445 return use.
get().getDefiningOp() != producer;
525 linalgOp.getIteratorTypesArray();
526 AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
527 return linalgOp.hasPureTensorSemantics() &&
528 llvm::all_of(linalgOp.getIndexingMaps().getValue(),
530 return cast<AffineMapAttr>(attr)
532 .isProjectedPermutation();
536 return isParallelIterator(
537 iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
544 class ExpansionInfo {
550 LogicalResult compute(LinalgOp linalgOp,
OpOperand *fusableOpOperand,
/// Returns the number of entries currently stored in `reassociation`.
/// NOTE(review): compute() reserves `fusedIndexMap.getNumDims()` entries and
/// appends one group per folded-dimension count, so this presumably equals
/// the loop count of the original (unexpanded) op — confirm against compute().
555 unsigned getOrigOpNumDims()
const {
return reassociation.size(); }
/// Returns `expandedOpNumDims`, which compute() sets to the running total
/// (`sum`) of the per-dimension folded counts it accumulates.
/// NOTE(review): only meaningful after a compute() call; the member has no
/// visible initializer in this view — confirm.
556 unsigned getExpandedOpNumDims()
const {
return expandedOpNumDims; }
558 return reassociation[i];
561 return expandedShapeMap[i];
574 unsigned expandedOpNumDims;
578 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
584 if (reassociationMaps.empty())
586 AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
589 originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
591 reassociation.clear();
592 expandedShapeMap.clear();
596 expandedShapeMap.resize(fusedIndexMap.
getNumDims());
598 unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
599 AffineMap foldedDims = reassociationMaps[resultExpr.index()];
602 expandedShape.slice(foldedDims.
getDimPosition(0), numExpandedDims[pos]);
603 expandedShapeMap[pos].assign(shape.begin(), shape.end());
606 for (
unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.
getNumDims()))
607 if (expandedShapeMap[i].empty())
608 expandedShapeMap[i] = {originalLoopExtent[i]};
612 reassociation.reserve(fusedIndexMap.
getNumDims());
614 auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
615 reassociation.emplace_back(seq.begin(), seq.end());
616 sum += numFoldedDim.value();
618 expandedOpNumDims = sum;
631 const ExpansionInfo &expansionInfo,
633 if (!linalgOp.hasIndexSemantics())
635 for (
unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
637 if (expandedShape.size() == 1)
639 for (int64_t shape : expandedShape.drop_front()) {
640 if (ShapedType::isDynamic(shape)) {
642 linalgOp,
"cannot expand due to index semantics and dynamic dims");
653 const ExpansionInfo &expansionInfo) {
656 unsigned pos = cast<AffineDimExpr>(expr).getPosition();
658 llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
659 return builder.getAffineDimExpr(static_cast<unsigned>(v));
661 newExprs.append(expandedExprs.begin(), expandedExprs.end());
672 const ExpansionInfo &expansionInfo) {
675 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
676 auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
677 expandedShape.append(dimExpansion.begin(), dimExpansion.end());
690 const ExpansionInfo &expansionInfo) {
692 unsigned numReshapeDims = 0;
694 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
695 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
697 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
698 reassociation.emplace_back(std::move(indices));
699 numReshapeDims += numExpandedDims;
701 return reassociation;
711 const ExpansionInfo &expansionInfo) {
713 for (IndexOp indexOp :
714 llvm::make_early_inc_range(fusedRegion.
front().
getOps<IndexOp>())) {
716 expansionInfo.getExpandedDims(indexOp.getDim());
717 assert(!expandedDims.empty() &&
"expected valid expansion info");
720 if (expandedDims.size() == 1 &&
721 expandedDims.front() == (int64_t)indexOp.getDim())
728 expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
730 expandedIndices.reserve(expandedDims.size() - 1);
732 expandedDims.drop_front(), std::back_inserter(expandedIndices),
733 [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
734 Value newIndex = rewriter.
create<IndexOp>(loc, expandedDims.front());
735 for (
auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
736 assert(!ShapedType::isDynamic(std::get<0>(it)));
739 newIndex = rewriter.
create<affine::AffineApplyOp>(
740 indexOp.getLoc(), idx + acc * std::get<0>(it),
751 const ExpansionInfo &expansionInfo,
753 for (
unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
755 if (expandedShape.size() == 1)
757 bool foundDynamic =
false;
758 for (int64_t shape : expandedShape) {
759 if (!ShapedType::isDynamic(shape))
763 linalgOp,
"cannot infer expanded shape with multiple dynamic "
764 "dims in the same reassociation group");
775 static std::optional<SmallVector<Value>>
780 "preconditions for fuse operation failed");
784 auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
785 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
786 bool isExpanding = (expandingReshapeOp !=
nullptr);
787 RankedTensorType expandedType = isExpanding
788 ? expandingReshapeOp.getResultType()
789 : collapsingReshapeOp.getSrcType();
790 RankedTensorType collapsedType = isExpanding
791 ? expandingReshapeOp.getSrcType()
792 : collapsingReshapeOp.getResultType();
794 ExpansionInfo expansionInfo;
795 if (failed(expansionInfo.compute(
796 linalgOp, fusableOpOperand,
797 isExpanding ? expandingReshapeOp.getReassociationMaps()
798 : collapsingReshapeOp.getReassociationMaps(),
799 expandedType.getShape(), collapsedType.getShape(), rewriter)))
811 llvm::map_range(linalgOp.getIndexingMapsArray(), [&](
AffineMap m) {
812 return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
820 expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
821 for (
OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
822 if (opOperand == fusableOpOperand) {
823 expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
824 : collapsingReshapeOp.getSrc());
827 if (
auto opOperandType =
828 dyn_cast<RankedTensorType>(opOperand->get().getType())) {
829 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
830 RankedTensorType expandedOperandType =
832 if (expandedOperandType != opOperand->get().getType()) {
837 [&](
const Twine &msg) {
840 opOperandType.getShape(), expandedOperandType.getShape(),
844 expandedOpOperands.push_back(rewriter.
create<tensor::ExpandShapeOp>(
845 loc, expandedOperandType, opOperand->get(), reassociation));
849 expandedOpOperands.push_back(opOperand->get());
853 for (
OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
854 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
855 auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
856 RankedTensorType expandedOutputType =
858 if (expandedOutputType != opOperand.get().getType()) {
862 [&](
const Twine &msg) {
865 opOperandType.getShape(), expandedOutputType.getShape(),
869 outputs.push_back(rewriter.
create<tensor::ExpandShapeOp>(
870 loc, expandedOutputType, opOperand.get(), reassociation));
872 outputs.push_back(opOperand.get());
878 expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
879 for (
auto [i, type] :
llvm::enumerate(linalgOp.getIteratorTypesArray()))
880 for (
auto j : expansionInfo.getExpandedDims(i))
881 iteratorTypes[
j] = type;
885 rewriter.
create<GenericOp>(linalgOp.getLoc(), resultTypes,
886 expandedOpOperands, outputs,
887 expandedOpIndexingMaps, iteratorTypes);
888 Region &fusedRegion = fusedOp->getRegion(0);
889 Region &originalRegion = linalgOp->getRegion(0);
898 for (
OpResult opResult : linalgOp->getOpResults()) {
899 int64_t resultNumber = opResult.getResultNumber();
900 if (resultTypes[resultNumber] != opResult.getType()) {
903 linalgOp.getMatchingIndexingMap(
904 linalgOp.getDpsInitOperand(resultNumber)),
906 resultVals.push_back(rewriter.
create<tensor::CollapseShapeOp>(
907 linalgOp.getLoc(), opResult.getType(),
908 fusedOp->getResult(resultNumber), reassociation));
910 resultVals.push_back(fusedOp->getResult(resultNumber));
922 class FoldWithProducerReshapeOpByExpansion
925 FoldWithProducerReshapeOpByExpansion(
MLIRContext *context,
929 controlFoldingReshapes(std::move(foldReshapes)) {}
931 LogicalResult matchAndRewrite(LinalgOp linalgOp,
933 for (
OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
934 tensor::CollapseShapeOp reshapeOp =
935 opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
942 (!controlFoldingReshapes(opOperand)))
945 std::optional<SmallVector<Value>> replacementValues =
947 if (!replacementValues)
949 rewriter.
replaceOp(linalgOp, *replacementValues);
959 class FoldPadWithProducerReshapeOpByExpansion
962 FoldPadWithProducerReshapeOpByExpansion(
MLIRContext *context,
966 controlFoldingReshapes(std::move(foldReshapes)) {}
968 LogicalResult matchAndRewrite(tensor::PadOp padOp,
970 tensor::CollapseShapeOp reshapeOp =
971 padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
974 if (!reshapeOp->hasOneUse())
977 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
979 "fusion blocked by control function");
985 reshapeOp.getReassociationIndices();
987 for (
auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
988 if (reInd.size() != 1 && (l != 0 || h != 0))
993 RankedTensorType expandedType = reshapeOp.getSrcType();
994 RankedTensorType paddedType = padOp.getResultType();
997 if (reInd.size() == 1) {
998 expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
1000 for (
size_t i = 0; i < reInd.size(); ++i) {
1001 newLow.push_back(padOp.getMixedLowPad()[idx]);
1002 newHigh.push_back(padOp.getMixedHighPad()[idx]);
1007 RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
1008 auto newPadOp = rewriter.
create<tensor::PadOp>(
1009 loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1010 padOp.getConstantPaddingValue(), padOp.getNofold());
1013 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
1024 struct FoldReshapeWithGenericOpByExpansion
1027 FoldReshapeWithGenericOpByExpansion(
MLIRContext *context,
1031 controlFoldingReshapes(std::move(foldReshapes)) {}
1033 LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
1036 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1037 if (!producerResult) {
1039 "source not produced by an operation");
1042 auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
1045 "producer not a generic op");
1050 producer.getDpsInitOperand(producerResult.getResultNumber()))) {
1052 reshapeOp,
"failed preconditions of fusion with producer generic op");
1055 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1057 "fusion blocked by control function");
1060 std::optional<SmallVector<Value>> replacementValues =
1062 producer, reshapeOp,
1063 producer.getDpsInitOperand(producerResult.getResultNumber()),
1065 if (!replacementValues) {
1067 "fusion by expansion failed");
1074 Value reshapeReplacement =
1075 (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
1076 .getResultNumber()];
1077 if (
auto collapseOp =
1078 reshapeReplacement.
getDefiningOp<tensor::CollapseShapeOp>()) {
1079 reshapeReplacement = collapseOp.getSrc();
1081 rewriter.
replaceOp(reshapeOp, reshapeReplacement);
1082 rewriter.
replaceOp(producer, *replacementValues);
1104 "expected projected permutation");
1107 llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
1108 return cast<AffineDimExpr>(indexingMap.
getResults()[pos]).getPosition();
1112 return domainReassociation;
1120 assert(!dimSequence.empty() &&
1121 "expected non-empty list for dimension sequence");
1123 "expected indexing map to be projected permutation");
1125 llvm::SmallDenseSet<unsigned, 4> sequenceElements;
1126 sequenceElements.insert(dimSequence.begin(), dimSequence.end());
1128 unsigned dimSequenceStart = dimSequence[0];
1130 unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
1132 if (dimInMapStart == dimSequenceStart) {
1133 if (expr.index() + dimSequence.size() > indexingMap.
getNumResults())
1136 for (
const auto &dimInSequence :
enumerate(dimSequence)) {
1138 cast<AffineDimExpr>(
1139 indexingMap.
getResult(expr.index() + dimInSequence.index()))
1141 if (dimInMap != dimInSequence.value())
1152 if (sequenceElements.count(dimInMapStart))
1161 return llvm::all_of(maps, [&](
AffineMap map) {
1218 if (!genericOp.hasPureTensorSemantics() || genericOp.getNumDpsInits() != 1)
1221 if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](
AffineMap map) {
1222 return map.isProjectedPermutation();
1229 genericOp.getReductionDims(reductionDims);
1231 llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1232 AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1233 auto iteratorTypes = genericOp.getIteratorTypesArray();
1236 assert(!foldedRangeDims.empty() &&
"unexpected empty reassociation");
1239 if (foldedRangeDims.size() == 1)
1247 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1248 return processedIterationDims.count(dim);
1253 utils::IteratorType startIteratorType =
1254 iteratorTypes[foldedIterationSpaceDims[0]];
1258 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1259 return iteratorTypes[dim] != startIteratorType;
1268 bool isContiguous =
false;
1271 if (startDim.value() != foldedIterationSpaceDims[0])
1275 if (startDim.index() + foldedIterationSpaceDims.size() >
1276 reductionDims.size())
1279 isContiguous =
true;
1280 for (
const auto &foldedDim :
1282 if (reductionDims[foldedDim.index() + startDim.index()] !=
1283 foldedDim.value()) {
1284 isContiguous =
false;
1295 if (llvm::any_of(genericOp.getIndexingMapsArray(),
1297 return !isDimSequencePreserved(indexingMap,
1298 foldedIterationSpaceDims);
1302 processedIterationDims.insert(foldedIterationSpaceDims.begin(),
1303 foldedIterationSpaceDims.end());
1304 iterationSpaceReassociation.emplace_back(
1305 std::move(foldedIterationSpaceDims));
1308 return iterationSpaceReassociation;
1313 class CollapsingInfo {
1315 LogicalResult initialize(
unsigned origNumLoops,
1317 llvm::SmallDenseSet<int64_t, 4> processedDims;
1320 if (foldedIterationDim.empty())
1324 for (
auto dim : foldedIterationDim) {
1325 if (dim >= origNumLoops)
1327 if (processedDims.count(dim))
1329 processedDims.insert(dim);
1331 collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1332 foldedIterationDim.end());
1334 if (processedDims.size() > origNumLoops)
1339 for (
auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1340 if (processedDims.count(dim))
1345 llvm::sort(collapsedOpToOrigOpIterationDim,
1347 return lhs[0] < rhs[0];
1349 origOpToCollapsedOpIterationDim.resize(origNumLoops);
1350 for (
const auto &foldedDims :
1352 for (
const auto &dim :
enumerate(foldedDims.value()))
1353 origOpToCollapsedOpIterationDim[dim.value()] =
1354 std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1361 return collapsedOpToOrigOpIterationDim;
1385 return origOpToCollapsedOpIterationDim;
1389 unsigned getCollapsedOpIterationRank()
const {
1390 return collapsedOpToOrigOpIterationDim.size();
1408 const CollapsingInfo &collapsingInfo) {
1411 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1412 assert(!foldedIterDims.empty() &&
1413 "reassociation indices expected to have non-empty sets");
1417 collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1419 return collapsedIteratorTypes;
1426 const CollapsingInfo &collapsingInfo) {
1429 "expected indexing map to be projected permutation");
1431 auto origOpToCollapsedOpMapping =
1432 collapsingInfo.getOrigOpToCollapsedOpMapping();
1434 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1436 if (origOpToCollapsedOpMapping[dim].second != 0)
1440 resultExprs.push_back(
1443 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1444 resultExprs, context);
1451 const CollapsingInfo &collapsingInfo) {
1452 unsigned counter = 0;
1454 auto origOpToCollapsedOpMapping =
1455 collapsingInfo.getOrigOpToCollapsedOpMapping();
1456 auto collapsedOpToOrigOpMapping =
1457 collapsingInfo.getCollapsedOpToOrigOpMapping();
1460 cast<AffineDimExpr>(indexingMap.
getResult(counter)).getPosition();
1464 unsigned numFoldedDims =
1465 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1467 if (origOpToCollapsedOpMapping[dim].second == 0) {
1468 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1469 operandReassociation.emplace_back(range.begin(), range.end());
1471 counter += numFoldedDims;
1473 return operandReassociation;
1479 const CollapsingInfo &collapsingInfo,
1481 AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1489 if (operandReassociation.size() == indexingMap.
getNumResults())
1493 if (isa<MemRefType>(operand.
getType())) {
1495 .
create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
1499 .
create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
1506 const CollapsingInfo &collapsingInfo,
1513 auto indexOps = llvm::to_vector(block->
getOps<linalg::IndexOp>());
1523 for (
auto foldedDims :
1524 enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1527 rewriter.
create<linalg::IndexOp>(loc, foldedDims.index());
1528 for (
auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1529 indexReplacementVals[dim] =
1530 rewriter.
create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
1532 rewriter.
create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
1534 indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1537 for (
auto indexOp : indexOps) {
1538 auto dim = indexOp.getDim();
1539 rewriter.
replaceOp(indexOp, indexReplacementVals[dim]);
1544 const CollapsingInfo &collapsingInfo,
1551 llvm::map_to_vector(op.getDpsInputOperands(), [&](
OpOperand *opOperand) {
1552 return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1557 resultTypes.reserve(op.getNumDpsInits());
1558 outputOperands.reserve(op.getNumDpsInits());
1559 for (
OpOperand &output : op.getDpsInitsMutable()) {
1562 outputOperands.push_back(newOutput);
1565 if (!op.hasPureBufferSemantics())
1566 resultTypes.push_back(newOutput.
getType());
1571 template <
typename OpTy>
1573 const CollapsingInfo &collapsingInfo) {
1581 const CollapsingInfo &collapsingInfo) {
1585 outputOperands, resultTypes);
1588 rewriter, origOp, resultTypes,
1589 llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1596 const CollapsingInfo &collapsingInfo) {
1600 outputOperands, resultTypes);
1602 llvm::map_range(origOp.getIndexingMapsArray(), [&](
AffineMap map) {
1603 return getCollapsedOpIndexingMap(map, collapsingInfo);
1607 origOp.getIteratorTypesArray(), collapsingInfo));
1609 GenericOp collapsedOp = rewriter.
create<linalg::GenericOp>(
1610 origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
1612 Block *origOpBlock = &origOp->getRegion(0).
front();
1613 Block *collapsedOpBlock = &collapsedOp->getRegion(0).
front();
1614 rewriter.
mergeBlocks(origOpBlock, collapsedOpBlock,
1621 if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1633 if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1635 return foldedDims.size() <= 1;
1639 bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1640 if (hasPureBufferSemantics &&
1642 MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
1643 if (!memRefToCollapse)
1646 return memref::CollapseShapeOp::isGuaranteedCollapsible(
1647 memRefToCollapse, foldedIterationDims);
1650 "memref is not guaranteed collapsible");
1652 CollapsingInfo collapsingInfo;
1654 collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1656 op,
"illegal to collapse specified dimensions");
1661 auto opFoldIsConstantValue = [](
OpFoldResult ofr, int64_t value) {
1662 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1663 return cast<IntegerAttr>(attr).getInt() == value;
1666 actual.getSExtValue() == value;
1668 if (!llvm::all_of(loopRanges, [&](
Range range) {
1669 return opFoldIsConstantValue(range.
offset, 0) &&
1670 opFoldIsConstantValue(range.
stride, 1);
1673 op,
"expected all loop ranges to have zero start and unit stride");
1679 if (collapsedOp.hasIndexSemantics()) {
1684 llvm::map_to_vector(loopRanges, [&](
Range range) {
1688 collapsingInfo, loopBound, rewriter);
1695 Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1696 auto originalResultType =
1697 cast<ShapedType>(originalResult.value().getType());
1698 auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.
getType());
1699 if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1701 op.getIndexingMapMatchingResult(originalResult.value());
1705 if (isa<MemRefType>(collapsedOpResult.
getType())) {
1707 originalResultType.getShape(), originalResultType.getElementType());
1708 result = rewriter.
create<memref::ExpandShapeOp>(
1709 loc, expandShapeResultType, collapsedOpResult, reassociation);
1711 result = rewriter.
create<tensor::ExpandShapeOp>(
1712 loc, originalResultType, collapsedOpResult, reassociation);
1714 results.push_back(result);
1716 results.push_back(collapsedOpResult);
1726 class FoldWithProducerReshapeOpByCollapsing
1729 FoldWithProducerReshapeOpByCollapsing(
MLIRContext *context,
1733 controlFoldingReshapes(std::move(foldReshapes)) {}
1735 LogicalResult matchAndRewrite(GenericOp genericOp,
1737 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
1738 tensor::ExpandShapeOp reshapeOp =
1739 opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1745 reshapeOp.getReassociationIndices());
1746 if (collapsableIterationDims.empty() ||
1747 !controlFoldingReshapes(&opOperand)) {
1752 genericOp, collapsableIterationDims, rewriter);
1753 if (!collapseResult) {
1755 genericOp,
"failed to do the fusion by collapsing transformation");
1758 rewriter.
replaceOp(genericOp, collapseResult->results);
1768 class FoldPadWithProducerReshapeOpByCollapsing
1771 FoldPadWithProducerReshapeOpByCollapsing(
MLIRContext *context,
1775 controlFoldingReshapes(std::move(foldReshapes)) {}
1777 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1779 tensor::ExpandShapeOp reshapeOp =
1780 padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
1783 if (!reshapeOp->hasOneUse())
1786 if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1788 "fusion blocked by control function");
1794 reshapeOp.getReassociationIndices();
1796 for (
auto reInd : reassociations) {
1797 if (reInd.size() == 1)
1799 if (llvm::any_of(reInd, [&](int64_t ind) {
1800 return low[ind] != 0 || high[ind] != 0;
1807 RankedTensorType collapsedType = reshapeOp.getSrcType();
1808 RankedTensorType paddedType = padOp.getResultType();
1812 reshapeOp.getOutputShape(), rewriter));
1816 Location loc = reshapeOp->getLoc();
1820 if (reInd.size() == 1) {
1821 collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
1823 rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
1824 expandedPaddedSizes[reInd[0]] = paddedSize;
1826 newLow.push_back(l);
1827 newHigh.push_back(h);
1830 RankedTensorType collapsedPaddedType =
1831 paddedType.clone(collapsedPaddedShape);
1832 auto newPadOp = rewriter.
create<tensor::PadOp>(
1833 loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1834 padOp.getConstantPaddingValue(), padOp.getNofold());
1837 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
1838 expandedPaddedSizes);
1848 template <
typename LinalgType>
1855 controlCollapseDimension(std::move(collapseDimensions)) {}
1857 LogicalResult matchAndRewrite(LinalgType op,
1860 controlCollapseDimension(op);
1861 if (collapsableIterationDims.empty())
1866 collapsableIterationDims)) {
1868 op,
"specified dimensions cannot be collapsed");
1871 std::optional<CollapseResult> collapseResult =
1873 if (!collapseResult) {
1876 rewriter.
replaceOp(op, collapseResult->results);
1898 LogicalResult matchAndRewrite(GenericOp genericOp,
1900 if (!genericOp.hasPureTensorSemantics())
1902 for (
OpOperand *opOperand : genericOp.getDpsInputOperands()) {
1903 Operation *def = opOperand->get().getDefiningOp();
1904 TypedAttr constantAttr;
1905 auto isScalarOrSplatConstantOp = [&constantAttr](
Operation *def) ->
bool {
1908 if (
matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1910 splatAttr.
getType().getElementType().isIntOrFloat()) {
1916 IntegerAttr intAttr;
1917 if (
matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1918 constantAttr = intAttr;
1923 FloatAttr floatAttr;
1924 if (
matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1925 constantAttr = floatAttr;
1932 auto resultValue = dyn_cast<OpResult>(opOperand->get());
1933 if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
1942 fusedIndexMaps.reserve(genericOp->getNumOperands());
1943 fusedOperands.reserve(genericOp.getNumDpsInputs());
1944 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
1945 for (
OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
1946 if (inputOperand == opOperand)
1948 Value inputValue = inputOperand->get();
1949 fusedIndexMaps.push_back(
1950 genericOp.getMatchingIndexingMap(inputOperand));
1951 fusedOperands.push_back(inputValue);
1952 fusedLocs.push_back(inputValue.
getLoc());
1954 for (
OpOperand &outputOperand : genericOp.getDpsInitsMutable())
1955 fusedIndexMaps.push_back(
1956 genericOp.getMatchingIndexingMap(&outputOperand));
1961 genericOp,
"fused op loop bound computation failed");
1965 Value scalarConstant =
1966 rewriter.
create<arith::ConstantOp>(def->
getLoc(), constantAttr);
1969 auto fusedOp = rewriter.
create<GenericOp>(
1970 rewriter.
getFusedLoc(fusedLocs), genericOp->getResultTypes(),
1974 genericOp.getIteratorTypes(),
1980 Region &region = genericOp->getRegion(0);
1983 mapping.
map(entryBlock.
getArgument(opOperand->getOperandNumber()),
1985 Region &fusedRegion = fusedOp->getRegion(0);
1988 rewriter.
replaceOp(genericOp, fusedOp->getResults());
2009 LogicalResult matchAndRewrite(GenericOp op,
2012 bool modifiedOutput =
false;
2014 for (
OpOperand &opOperand : op.getDpsInitsMutable()) {
2015 if (!op.payloadUsesValueFromOperand(&opOperand)) {
2016 Value operandVal = opOperand.get();
2017 auto operandType = dyn_cast<RankedTensorType>(operandVal.
getType());
2026 auto definingOp = operandVal.
getDefiningOp<tensor::EmptyOp>();
2029 modifiedOutput =
true;
2032 Value emptyTensor = rewriter.
create<tensor::EmptyOp>(
2033 loc, mixedSizes, operandType.getElementType());
2034 op->
setOperand(opOperand.getOperandNumber(), emptyTensor);
2037 if (!modifiedOutput) {
2050 LogicalResult matchAndRewrite(GenericOp genericOp,
2052 if (!genericOp.hasPureTensorSemantics())
2054 bool fillFound =
false;
2055 Block &payload = genericOp.getRegion().
front();
2056 for (
OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2057 if (!genericOp.payloadUsesValueFromOperand(opOperand))
2059 FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2063 Value fillVal = fillOp.value();
2065 cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2066 Value convertedVal =
2070 payload.
getArgument(opOperand->getOperandNumber()), convertedVal);
2072 return success(fillFound);
2080 patterns.
add<FoldReshapeWithGenericOpByExpansion>(patterns.
getContext(),
2081 controlFoldingReshapes);
2082 patterns.
add<FoldPadWithProducerReshapeOpByExpansion>(patterns.
getContext(),
2083 controlFoldingReshapes);
2084 patterns.
add<FoldWithProducerReshapeOpByExpansion>(patterns.
getContext(),
2085 controlFoldingReshapes);
2091 patterns.
add<FoldWithProducerReshapeOpByCollapsing>(patterns.
getContext(),
2092 controlFoldingReshapes);
2093 patterns.
add<FoldPadWithProducerReshapeOpByCollapsing>(
2094 patterns.
getContext(), controlFoldingReshapes);
2101 patterns.
add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2102 patterns.
add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2103 RemoveOutsDependency>(context);
2111 patterns.
add<CollapseLinalgDimensions<linalg::GenericOp>,
2112 CollapseLinalgDimensions<linalg::CopyOp>>(
2113 patterns.
getContext(), controlCollapseDimensions);
2128 struct LinalgElementwiseOpFusionPass
2129 :
public impl::LinalgElementwiseOpFusionPassBase<
2130 LinalgElementwiseOpFusionPass> {
2131 using impl::LinalgElementwiseOpFusionPassBase<
2132 LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2133 void runOnOperation()
override {
2140 Operation *producer = fusedOperand->get().getDefiningOp();
2141 return producer && producer->
hasOneUse();
2149 affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
2150 GenericOp::getCanonicalizationPatterns(patterns, context);
2151 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
2152 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp, const CollapsingInfo &collapsingInfo)
Clone a LinalgOp to a collapsed version of same name.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(OpOperand *producerOpOperand, AffineMap producerResultIndexMap, AffineMap fusedConsumerArgIndexMap)
Append to fusedOpIndexingMapAttrs the indexing maps for the operands of the producer to use in the fu...
static SmallVector< ReassociationIndices > getOperandReassociation(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Return the reassociation indices to use to collapse the operand when the iteration space of a generic...
static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp, const ExpansionInfo &expansionInfo, PatternRewriter &rewriter)
Expanding the body of a linalg operation requires adaptations of the accessed loop indices.
static SmallVector< utils::IteratorType > getCollapsedOpIteratorTypes(ArrayRef< utils::IteratorType > iteratorTypes, const CollapsingInfo &collapsingInfo)
Get the iterator types for the collapsed operation given the original iterator types and collapsed di...
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, Location loc, Region &fusedRegion, const ExpansionInfo &expansionInfo)
Update the body of an expanded linalg operation having index semantics.
GenericOp cloneToCollapsedOp< GenericOp >(RewriterBase &rewriter, GenericOp origOp, const CollapsingInfo &collapsingInfo)
Collapse a GenericOp
void generateCollapsedIndexingRegion(Location loc, Block *block, const CollapsingInfo &collapsingInfo, ValueRange loopRange, RewriterBase &rewriter)
Modify the linalg.index operations in the original generic op, to its value in the collapsed operatio...
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp, OpOperand *fusableOpOperand)
Conditions for folding a structured linalg operation with a reshape op by expanding the iteration spa...
static LogicalResult validateDynamicDimExpansion(LinalgOp linalgOp, const ExpansionInfo &expansionInfo, PatternRewriter &rewriter)
Checks if a single dynamic dimension expanded into multiple dynamic dimensions.
void collapseOperandsAndResults(LinalgOp op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter, SmallVectorImpl< Value > &inputOperands, SmallVectorImpl< Value > &outputOperands, SmallVectorImpl< Type > &resultTypes)
static ReassociationIndices getDomainReassociation(AffineMap indexingMap, ReassociationIndicesRef rangeReassociation)
For a given list of indices in the range of the indexingMap that are folded, return the indices of th...
static SmallVector< ReassociationIndices > getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, ArrayRef< ReassociationIndices > reassociation)
static std::optional< SmallVector< Value > > fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, OpOperand *fusableOpOperand, PatternRewriter &rewriter)
Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op and a generic op as expl...
static void generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand, unsigned nloops, llvm::SmallDenseSet< int > &preservedProducerResults)
Generate the region of the fused tensor operation.
static SmallVector< ReassociationIndices > getReassociationForExpansion(AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Returns the reassociation maps to use in the tensor.expand_shape operation to convert the operands of...
static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Compute the indexing map in the collapsed op that corresponds to the given indexingMap of the origina...
LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter)
static RankedTensorType getExpandedType(RankedTensorType originalType, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the type of the operand/result to use in the expanded op given the type in the original op.
LinalgOp cloneToCollapsedOp< LinalgOp >(RewriterBase &rewriter, LinalgOp origOp, const CollapsingInfo &collapsingInfo)
Collapse any LinalgOp that does not require any specialization such as indexing_maps,...
static AffineMap getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the indexing map to use in the expanded op for the given indexingMap of the original operati...
static Value getCollapsedOpOperand(Location loc, LinalgOp op, OpOperand *opOperand, const CollapsingInfo &collapsingInfo, OpBuilder &builder)
Get the new value to use for a given OpOperand in the collapsed operation.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class allows control over how the GreedyPatternRewriteDriver works.
bool useTopDownTraversal
This specifies the order of initial traversal that populates the rewriter's worklist.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
bool hasOneUse()
Returns true if this operation has exactly one use.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
operand_range getOperands()
Returns an iterator on the underlying Values.
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool areDimSequencesPreserved(ArrayRef< AffineMap > maps, ArrayRef< ReassociationIndices > dimSequences)
Return true if all sequences of dimensions specified in dimSequences are contiguous in all the ranges...
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence)
Return true if a given sequence of dimensions is contiguous in the range of the specified indexing m...
void populateFoldReshapeOpsByCollapsingPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding tensor.expand_shape operation with its producer generic operation by co...
FailureOr< ElementwiseOpFusionResult > fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand)
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
void populateCollapseDimensions(RewritePatternSet &patterns, const GetCollapsableDimensionsFn &controlCollapseDimensions)
Pattern to collapse dimensions in a linalg.generic op.
bool areElementwiseOpsFusable(OpOperand *fusedOperand)
Return true if two linalg.generic operations with producer/consumer relationship through fusedOperand...
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
std::function< SmallVector< ReassociationIndices >(linalg::LinalgOp)> GetCollapsableDimensionsFn
Function type to control generic op dimension collapsing.
void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding (collapsing) tensor_reshape operation with its producer (consumer) gene...
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void populateElementwiseOpsFusionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion)
Patterns for fusing linalg operation on tensors.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify the shapes of the reshaped types using the following rule: if a dimension in the collapsed type i...
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
ArrayRef< int64_t > ReassociationIndicesRef
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Fuse two linalg.generic operations that have a producer-consumer relationship captured through fusedO...
llvm::DenseMap< Value, Value > replacements
static llvm::SmallDenseSet< int > getPreservedProducerResults(GenericOp producer, GenericOp consumer)
Returns a set of indices of the producer's results which would be preserved after the fusion.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.