30 #define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMS
31 #define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSION
32 #include "mlir/Dialect/Linalg/Passes.h.inc"
58 assert(invProducerResultIndexMap &&
59 "expected producer result indexing map to be invertible");
61 LinalgOp producer = cast<LinalgOp>(producerOpOperand->
getOwner());
63 AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);
71 return t1.
compose(fusedConsumerArgIndexMap);
80 auto consumer = dyn_cast<GenericOp>(fusedOperand->
getOwner());
83 if (!producer || !consumer)
89 if (!producer.hasTensorSemantics() ||
90 !isa<RankedTensorType>(fusedOperand->
get().
getType()))
95 if (producer.getNumParallelLoops() != producer.getNumLoops())
100 if (!consumer.isDpsInput(fusedOperand))
105 AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
106 if (consumerIndexMap.
getNumResults() != producer.getNumLoops())
112 producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
120 if ((consumer.getNumReductionLoops())) {
121 BitVector coveredDims(consumer.getNumLoops(),
false);
123 auto addToCoveredDims = [&](
AffineMap map) {
124 for (
auto result : map.getResults())
126 coveredDims[dimExpr.getPosition()] =
true;
130 llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
131 Value operand = std::get<0>(pair);
132 if (operand == fusedOperand->
get())
134 AffineMap operandMap = std::get<1>(pair);
135 addToCoveredDims(operandMap);
138 for (
OpOperand *operand : producer.getDpsInputOperands()) {
141 operand, producerResultIndexMap, consumerIndexMap);
142 addToCoveredDims(newIndexingMap);
144 if (!coveredDims.all())
156 unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
158 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
160 Block &producerBlock = producer->getRegion(0).
front();
161 Block &consumerBlock = consumer->getRegion(0).
front();
163 fusedOp.getRegion().push_back(fusedBlock);
170 if (producer.hasIndexSemantics()) {
172 unsigned numFusedOpLoops =
173 std::max(producer.getNumLoops(), consumer.getNumLoops());
175 fusedIndices.reserve(numFusedOpLoops);
176 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
177 std::back_inserter(fusedIndices), [&](uint64_t dim) {
178 return rewriter.
create<IndexOp>(producer.getLoc(), dim);
180 for (IndexOp indexOp :
181 llvm::make_early_inc_range(producerBlock.
getOps<IndexOp>())) {
182 Value newIndex = rewriter.
create<affine::AffineApplyOp>(
184 consumerToProducerLoopsMap.
getSubMap(indexOp.getDim()), fusedIndices);
185 mapper.
map(indexOp.getResult(), newIndex);
189 assert(consumer.isDpsInput(fusedOperand) &&
190 "expected producer of input operand");
194 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
201 producerBlock.
getArguments().take_front(producer.getNumDpsInputs()))
202 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
207 .take_front(consumer.getNumDpsInputs())
209 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
213 producerBlock.
getArguments().take_back(producer.getNumDpsInits()))) {
214 if (!preservedProducerResults.count(bbArg.index()))
216 mapper.
map(bbArg.value(), fusedBlock->
addArgument(bbArg.value().getType(),
217 bbArg.value().getLoc()));
222 consumerBlock.
getArguments().take_back(consumer.getNumDpsInits()))
223 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
228 if (!isa<IndexOp>(op))
229 rewriter.
clone(op, mapper);
233 auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.
getTerminator());
234 unsigned producerResultNumber =
235 cast<OpResult>(fusedOperand->
get()).getResultNumber();
237 mapper.
lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
241 if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
242 if (
auto bb = dyn_cast<BlockArgument>(replacement))
243 assert(bb.getOwner() != &producerBlock &&
244 "yielded block argument must have been mapped");
247 "yielded value must have been mapped");
253 rewriter.
clone(op, mapper);
257 auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.
getTerminator());
259 fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
260 consumerYieldOp.getNumOperands());
261 for (
const auto &producerYieldVal :
263 if (preservedProducerResults.count(producerYieldVal.index()))
264 fusedYieldValues.push_back(
267 for (
auto consumerYieldVal : consumerYieldOp.getOperands())
269 rewriter.
create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
273 "Ill-formed GenericOp region");
280 "expected elementwise operation pre-conditions to pass");
281 auto producerResult = cast<OpResult>(fusedOperand->
get());
282 auto producer = cast<GenericOp>(producerResult.getOwner());
283 auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
285 assert(consumer.isDpsInput(fusedOperand) &&
286 "expected producer of input operand");
288 llvm::SmallDenseSet<int> preservedProducerResults;
289 for (
const auto &producerResult :
llvm::enumerate(producer->getResults())) {
290 auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
291 if (producer.payloadUsesValueFromOperand(outputOperand) ||
292 !producer.canOpOperandsBeDropped(outputOperand) ||
293 llvm::any_of(producerResult.value().getUsers(), [&](
Operation *user) {
294 return user != consumer.getOperation();
296 preservedProducerResults.insert(producerResult.index());
304 fusedInputOperands.reserve(producer.getNumDpsInputs() +
305 consumer.getNumDpsInputs());
306 fusedOutputOperands.reserve(preservedProducerResults.size() +
307 consumer.getNumDpsInits());
308 fusedResultTypes.reserve(preservedProducerResults.size() +
309 consumer.getNumDpsInits());
310 fusedIndexMaps.reserve(producer->getNumOperands() +
311 consumer->getNumOperands());
314 auto consumerInputs = consumer.getDpsInputOperands();
315 auto *it = llvm::find_if(consumerInputs, [&](
OpOperand *operand) {
316 return operand == fusedOperand;
318 assert(it != consumerInputs.end() &&
"expected to find the consumer operand");
319 for (
OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
320 fusedInputOperands.push_back(opOperand->get());
321 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
325 producer.getIndexingMapMatchingResult(producerResult);
326 for (
OpOperand *opOperand : producer.getDpsInputOperands()) {
327 fusedInputOperands.push_back(opOperand->get());
330 opOperand, producerResultIndexMap,
331 consumer.getMatchingIndexingMap(fusedOperand));
332 fusedIndexMaps.push_back(map);
337 llvm::make_range(std::next(it), consumerInputs.end())) {
338 fusedInputOperands.push_back(opOperand->get());
339 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
343 for (
const auto &opOperand :
llvm::enumerate(producer.getDpsInitsMutable())) {
344 if (!preservedProducerResults.count(opOperand.index()))
347 fusedOutputOperands.push_back(opOperand.value().get());
349 &opOperand.value(), producerResultIndexMap,
350 consumer.getMatchingIndexingMap(fusedOperand));
351 fusedIndexMaps.push_back(map);
352 fusedResultTypes.push_back(opOperand.value().get().getType());
356 for (
OpOperand &opOperand : consumer.getDpsInitsMutable()) {
357 fusedOutputOperands.push_back(opOperand.get());
358 fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
359 Type resultType = opOperand.get().getType();
360 if (!isa<MemRefType>(resultType))
361 fusedResultTypes.push_back(resultType);
365 auto fusedOp = rewriter.
create<GenericOp>(
366 consumer.getLoc(), fusedResultTypes, fusedInputOperands,
368 consumer.getIteratorTypes(),
371 if (!fusedOp.getShapesToLoopsMap()) {
377 fusedOp,
"fused op failed loop bound computation check");
383 consumer.getMatchingIndexingMap(fusedOperand);
387 assert(invProducerResultIndexMap &&
388 "expected producer result indexig map to be invertible");
391 invProducerResultIndexMap.
compose(consumerResultIndexMap);
394 rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
395 consumer.getNumLoops(), preservedProducerResults);
399 for (
auto [index, producerResult] :
llvm::enumerate(producer->getResults()))
400 if (preservedProducerResults.count(index))
401 result.
replacements[producerResult] = fusedOp->getResult(resultNum++);
402 for (
auto consumerResult : consumer->getResults())
403 result.
replacements[consumerResult] = fusedOp->getResult(resultNum++);
414 controlFn(std::move(fun)) {}
419 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
422 if (!controlFn(&opOperand))
430 Operation *producer = opOperand.get().getDefiningOp();
439 for (
auto [origVal, replacement] : fusionResult->replacements) {
442 return use.
get().getDefiningOp() != producer;
521 return genericOp.hasTensorSemantics() &&
522 llvm::all_of(genericOp.getIndexingMaps().getValue(),
524 return cast<AffineMapAttr>(attr)
526 .isProjectedPermutation();
528 genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() >
536 class ExpansionInfo {
547 unsigned getOrigOpNumDims()
const {
return reassociation.size(); }
548 unsigned getExpandedOpNumDims()
const {
return expandedOpNumDims; }
550 return reassociation[i];
553 return expandedShapeMap[i];
566 unsigned expandedOpNumDims;
576 if (reassociationMaps.empty())
578 AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
581 originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
583 reassociation.clear();
584 expandedShapeMap.clear();
588 expandedShapeMap.resize(fusedIndexMap.
getNumDims());
590 unsigned pos = resultExpr.value().cast<
AffineDimExpr>().getPosition();
591 AffineMap foldedDims = reassociationMaps[resultExpr.index()];
594 expandedShape.slice(foldedDims.
getDimPosition(0), numExpandedDims[pos]);
595 expandedShapeMap[pos].assign(shape.begin(), shape.end());
598 for (
unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.
getNumDims()))
599 if (expandedShapeMap[i].empty())
600 expandedShapeMap[i] = {originalLoopExtent[i]};
604 reassociation.reserve(fusedIndexMap.
getNumDims());
606 auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
607 reassociation.emplace_back(seq.begin(), seq.end());
608 sum += numFoldedDim.value();
610 expandedOpNumDims = sum;
623 const ExpansionInfo &expansionInfo,
625 if (!genericOp.hasIndexSemantics())
627 for (
unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
629 if (expandedShape.size() == 1)
631 for (int64_t shape : expandedShape.drop_front()) {
632 if (ShapedType::isDynamic(shape)) {
634 genericOp,
"cannot expand due to index semantics and dynamic dims");
645 const ExpansionInfo &expansionInfo) {
650 llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
651 return builder.getAffineDimExpr(static_cast<unsigned>(v));
653 newExprs.append(expandedExprs.begin(), expandedExprs.end());
664 const ExpansionInfo &expansionInfo) {
668 auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
669 expandedShape.append(dimExpansion.begin(), dimExpansion.end());
682 const ExpansionInfo &expansionInfo) {
684 unsigned numReshapeDims = 0;
687 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
689 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
690 reassociation.emplace_back(std::move(indices));
691 numReshapeDims += numExpandedDims;
693 return reassociation;
703 const ExpansionInfo &expansionInfo) {
705 for (IndexOp indexOp :
706 llvm::make_early_inc_range(fusedRegion.
front().
getOps<IndexOp>())) {
708 expansionInfo.getExpandedDims(indexOp.getDim());
709 assert(!expandedDims.empty() &&
"expected valid expansion info");
712 if (expandedDims.size() == 1 &&
713 expandedDims.front() == (int64_t)indexOp.getDim())
720 expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
722 expandedIndices.reserve(expandedDims.size() - 1);
724 expandedDims.drop_front(), std::back_inserter(expandedIndices),
725 [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
726 Value newIndex = rewriter.
create<IndexOp>(loc, expandedDims.front());
727 for (
auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
728 assert(!ShapedType::isDynamic(std::get<0>(it)));
731 newIndex = rewriter.
create<affine::AffineApplyOp>(
732 indexOp.getLoc(), idx + acc * std::get<0>(it),
742 static std::optional<SmallVector<Value>>
747 "preconditions for fuse operation failed");
749 auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
750 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
751 bool isExpanding = (expandingReshapeOp !=
nullptr);
752 RankedTensorType expandedType = isExpanding
753 ? expandingReshapeOp.getResultType()
754 : collapsingReshapeOp.getSrcType();
755 RankedTensorType collapsedType = isExpanding
756 ? expandingReshapeOp.getSrcType()
757 : collapsingReshapeOp.getResultType();
759 ExpansionInfo expansionInfo;
760 if (
failed(expansionInfo.compute(
761 genericOp, fusableOpOperand,
762 isExpanding ? expandingReshapeOp.getReassociationMaps()
763 : collapsingReshapeOp.getReassociationMaps(),
764 expandedType.getShape(), collapsedType.getShape(), rewriter)))
771 llvm::map_range(genericOp.getIndexingMapsArray(), [&](
AffineMap m) {
772 return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
780 expandedOpOperands.reserve(genericOp.getNumDpsInputs());
781 for (
OpOperand *opOperand : genericOp.getDpsInputOperands()) {
782 if (opOperand == fusableOpOperand) {
783 expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
784 : collapsingReshapeOp.getSrc());
787 if (
auto opOperandType =
788 dyn_cast<RankedTensorType>(opOperand->get().getType())) {
789 AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
790 RankedTensorType expandedOperandType =
792 if (expandedOperandType != opOperand->get().getType()) {
797 [&](
const Twine &msg) {
800 opOperandType.getShape(), expandedOperandType.getShape(),
804 expandedOpOperands.push_back(rewriter.
create<tensor::ExpandShapeOp>(
805 genericOp.getLoc(), expandedOperandType, opOperand->get(),
810 expandedOpOperands.push_back(opOperand->get());
815 for (
OpOperand &opOperand : genericOp.getDpsInitsMutable()) {
816 AffineMap indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
817 auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
818 RankedTensorType expandedOutputType =
820 if (expandedOutputType != opOperand.get().getType()) {
824 [&](
const Twine &msg) {
827 opOperandType.getShape(), expandedOutputType.getShape(),
831 outputs.push_back(rewriter.
create<tensor::ExpandShapeOp>(
832 genericOp.getLoc(), expandedOutputType, opOperand.get(),
835 outputs.push_back(opOperand.get());
841 expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
845 rewriter.
create<GenericOp>(genericOp.getLoc(), resultTypes,
846 expandedOpOperands, outputs,
847 expandedOpIndexingMaps, iteratorTypes);
848 Region &fusedRegion = fusedOp->getRegion(0);
849 Region &originalRegion = genericOp->getRegion(0);
858 for (
OpResult opResult : genericOp->getOpResults()) {
859 int64_t resultNumber = opResult.getResultNumber();
860 if (resultTypes[resultNumber] != opResult.getType()) {
863 genericOp.getMatchingIndexingMap(
864 genericOp.getDpsInitOperand(resultNumber)),
866 resultVals.push_back(rewriter.
create<tensor::CollapseShapeOp>(
867 genericOp.getLoc(), opResult.getType(),
868 fusedOp->getResult(resultNumber), reassociation));
870 resultVals.push_back(fusedOp->getResult(resultNumber));
882 class FoldWithProducerReshapeOpByExpansion
885 FoldWithProducerReshapeOpByExpansion(
MLIRContext *context,
889 controlFoldingReshapes(std::move(foldReshapes)) {}
893 for (
OpOperand *opOperand : genericOp.getDpsInputOperands()) {
894 tensor::CollapseShapeOp reshapeOp =
895 opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
902 (!controlFoldingReshapes(opOperand)))
905 std::optional<SmallVector<Value>> replacementValues =
907 if (!replacementValues)
909 rewriter.
replaceOp(genericOp, *replacementValues);
921 struct FoldReshapeWithGenericOpByExpansion
924 FoldReshapeWithGenericOpByExpansion(
MLIRContext *context,
928 controlFoldingReshapes(std::move(foldReshapes)) {}
930 LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
933 auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
934 if (!producerResult) {
936 "source not produced by an operation");
939 auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
942 "producer not a generic op");
947 producer.getDpsInitOperand(producerResult.getResultNumber()))) {
949 reshapeOp,
"failed preconditions of fusion with producer generic op");
952 if (!controlFoldingReshapes(&reshapeOp.getSrcMutable()[0])) {
954 "fusion blocked by control function");
957 std::optional<SmallVector<Value>> replacementValues =
960 producer.getDpsInitOperand(producerResult.getResultNumber()),
962 if (!replacementValues) {
964 "fusion by expansion failed");
971 Value reshapeReplacement =
972 (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
974 if (
auto collapseOp =
975 reshapeReplacement.
getDefiningOp<tensor::CollapseShapeOp>()) {
976 reshapeReplacement = collapseOp.getSrc();
978 rewriter.
replaceOp(reshapeOp, reshapeReplacement);
979 rewriter.
replaceOp(producer, *replacementValues);
1001 "expected projected permutation");
1004 llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
1006 .cast<AffineDimExpr>()
1011 return domainReassociation;
1019 assert(!dimSequence.empty() &&
1020 "expected non-empty list for dimension sequence");
1022 "expected indexing map to be projected permutation");
1024 llvm::SmallDenseSet<unsigned, 4> sequenceElements;
1025 sequenceElements.insert(dimSequence.begin(), dimSequence.end());
1027 unsigned dimSequenceStart = dimSequence[0];
1029 unsigned dimInMapStart = expr.value().cast<
AffineDimExpr>().getPosition();
1031 if (dimInMapStart == dimSequenceStart) {
1032 if (expr.index() + dimSequence.size() > indexingMap.
getNumResults())
1035 for (
const auto &dimInSequence :
enumerate(dimSequence)) {
1037 indexingMap.
getResult(expr.index() + dimInSequence.index())
1040 if (dimInMap != dimInSequence.value())
1051 if (sequenceElements.count(dimInMapStart))
1060 return llvm::all_of(maps, [&](
AffineMap map) {
1117 if (!genericOp.hasTensorSemantics() || genericOp.getNumDpsInits() != 1)
1120 if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](
AffineMap map) {
1121 return map.isProjectedPermutation();
1128 genericOp.getReductionDims(reductionDims);
1130 llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1131 AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1132 auto iteratorTypes = genericOp.getIteratorTypesArray();
1135 assert(!foldedRangeDims.empty() &&
"unexpected empty reassociation");
1138 if (foldedRangeDims.size() == 1)
1146 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1147 return processedIterationDims.count(dim);
1152 utils::IteratorType startIteratorType =
1153 iteratorTypes[foldedIterationSpaceDims[0]];
1157 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1158 return iteratorTypes[dim] != startIteratorType;
1167 bool isContiguous =
false;
1170 if (startDim.value() != foldedIterationSpaceDims[0])
1174 if (startDim.index() + foldedIterationSpaceDims.size() >
1175 reductionDims.size())
1178 isContiguous =
true;
1179 for (
const auto &foldedDim :
1181 if (reductionDims[foldedDim.index() + startDim.index()] !=
1182 foldedDim.value()) {
1183 isContiguous =
false;
1194 if (llvm::any_of(genericOp.getIndexingMapsArray(),
1196 return !isDimSequencePreserved(indexingMap,
1197 foldedIterationSpaceDims);
1201 processedIterationDims.insert(foldedIterationSpaceDims.begin(),
1202 foldedIterationSpaceDims.end());
1203 iterationSpaceReassociation.emplace_back(
1204 std::move(foldedIterationSpaceDims));
1207 return iterationSpaceReassociation;
1212 class CollapsingInfo {
1216 llvm::SmallDenseSet<int64_t, 4> processedDims;
1219 if (foldedIterationDim.empty())
1223 for (
auto dim : foldedIterationDim) {
1224 if (dim >= origNumLoops)
1226 if (processedDims.count(dim))
1228 processedDims.insert(dim);
1230 collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1231 foldedIterationDim.end());
1233 if (processedDims.size() > origNumLoops)
1238 for (
auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1239 if (processedDims.count(dim))
1244 llvm::sort(collapsedOpToOrigOpIterationDim,
1246 return lhs[0] < rhs[0];
1248 origOpToCollapsedOpIterationDim.resize(origNumLoops);
1249 for (
const auto &foldedDims :
1251 for (
const auto &dim :
enumerate(foldedDims.value()))
1252 origOpToCollapsedOpIterationDim[dim.value()] =
1253 std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1260 return collapsedOpToOrigOpIterationDim;
1284 return origOpToCollapsedOpIterationDim;
1288 unsigned getCollapsedOpIterationRank()
const {
1289 return collapsedOpToOrigOpIterationDim.size();
1307 const CollapsingInfo &collapsingInfo) {
1310 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1311 assert(!foldedIterDims.empty() &&
1312 "reassociation indices expected to have non-empty sets");
1316 collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1318 return collapsedIteratorTypes;
1325 const CollapsingInfo &collapsingInfo) {
1328 "expected indexing map to be projected permutation");
1330 auto origOpToCollapsedOpMapping =
1331 collapsingInfo.getOrigOpToCollapsedOpMapping();
1335 if (origOpToCollapsedOpMapping[dim].second != 0)
1339 resultExprs.push_back(
1342 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1343 resultExprs, context);
1350 const CollapsingInfo &collapsingInfo) {
1351 unsigned counter = 0;
1353 auto origOpToCollapsedOpMapping =
1354 collapsingInfo.getOrigOpToCollapsedOpMapping();
1355 auto collapsedOpToOrigOpMapping =
1356 collapsingInfo.getCollapsedOpToOrigOpMapping();
1363 unsigned numFoldedDims =
1364 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1366 if (origOpToCollapsedOpMapping[dim].second == 0) {
1367 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1368 operandReassociation.emplace_back(range.begin(), range.end());
1370 counter += numFoldedDims;
1372 return operandReassociation;
1378 const CollapsingInfo &collapsingInfo,
1380 AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
1387 if (operandReassociation.size() == indexingMap.
getNumResults())
1391 auto reshapeOp = builder.
create<tensor::CollapseShapeOp>(
1392 loc, operand, operandReassociation);
1399 const CollapsingInfo &collapsingInfo,
1406 auto indexOps = llvm::to_vector(block->
getOps<linalg::IndexOp>());
1416 for (
auto foldedDims :
1417 enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1420 rewriter.
create<linalg::IndexOp>(loc, foldedDims.index());
1421 for (
auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1422 indexReplacementVals[dim] =
1423 rewriter.
create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
1425 rewriter.
create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
1427 indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1430 for (
auto indexOp : indexOps) {
1431 auto dim = indexOp.getDim();
1432 rewriter.
replaceOp(indexOp, indexReplacementVals[dim]);
1441 if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1443 return foldedDims.size() <= 1;
1447 CollapsingInfo collapsingInfo;
1448 if (
failed(collapsingInfo.initialize(genericOp.getNumLoops(),
1449 foldedIterationDims))) {
1451 genericOp,
"illegal to collapse specified dimensions");
1456 cast<LinalgOp>(genericOp.getOperation())
1457 .createLoopRanges(rewriter, genericOp.getLoc());
1458 auto opFoldIsConstantValue = [](
OpFoldResult ofr, int64_t value) {
1459 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1460 return cast<IntegerAttr>(attr).getInt() == value;
1463 actual.getSExtValue() == value;
1465 if (!llvm::all_of(loopRanges, [&](
Range range) {
1466 return opFoldIsConstantValue(range.
offset, 0) &&
1467 opFoldIsConstantValue(range.
stride, 1);
1471 "expected all loop ranges to have zero start and unit stride");
1476 genericOp.getIteratorTypesArray(), collapsingInfo);
1479 auto indexingMaps = llvm::to_vector(
1480 llvm::map_range(genericOp.getIndexingMapsArray(), [&](
AffineMap map) {
1481 return getCollapsedOpIndexingMap(map, collapsingInfo);
1484 Location loc = genericOp->getLoc();
1487 auto inputOperands = llvm::to_vector(llvm::map_range(
1488 genericOp.getDpsInputOperands(), [&](
OpOperand *opOperand) {
1489 return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
1496 resultTypes.reserve(genericOp.getNumDpsInits());
1497 outputOperands.reserve(genericOp.getNumDpsInits());
1498 for (
OpOperand &output : genericOp.getDpsInitsMutable()) {
1500 collapsingInfo, rewriter);
1501 outputOperands.push_back(newOutput);
1502 resultTypes.push_back(newOutput.
getType());
1506 auto collapsedGenericOp = rewriter.
create<linalg::GenericOp>(
1507 loc, resultTypes, inputOperands, outputOperands, indexingMaps,
1509 Block *origOpBlock = &genericOp->getRegion(0).
front();
1510 Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).
front();
1511 rewriter.
mergeBlocks(origOpBlock, collapsedOpBlock,
1514 if (collapsedGenericOp.hasIndexSemantics()) {
1519 llvm::to_vector(llvm::map_range(loopRanges, [&](
Range range) {
1523 &collapsedGenericOp->getRegion(0).front(),
1524 collapsingInfo, loopBound, rewriter);
1530 for (
const auto &originalResult :
llvm::enumerate(genericOp->getResults())) {
1531 Value collapsedOpResult =
1532 collapsedGenericOp->getResult(originalResult.index());
1533 auto originalResultType =
1534 cast<ShapedType>(originalResult.value().getType());
1535 auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.
getType());
1536 if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1538 genericOp.getIndexingMapMatchingResult(originalResult.value());
1541 Value result = rewriter.
create<tensor::ExpandShapeOp>(
1542 loc, originalResultType, collapsedOpResult, reassociation);
1543 results.push_back(result);
1545 results.push_back(collapsedOpResult);
1555 class FoldWithProducerReshapeOpByCollapsing
1558 FoldWithProducerReshapeOpByCollapsing(
MLIRContext *context,
1562 controlFoldingReshapes(std::move(foldReshapes)) {}
1566 for (
OpOperand &opOperand : genericOp->getOpOperands()) {
1567 tensor::ExpandShapeOp reshapeOp =
1568 opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1574 reshapeOp.getReassociationIndices());
1575 if (collapsableIterationDims.empty() ||
1576 !controlFoldingReshapes(&opOperand)) {
1580 std::optional<SmallVector<Value>> replacements =
1583 if (!replacements) {
1585 genericOp,
"failed to do the fusion by collapsing transformation");
1588 rewriter.
replaceOp(genericOp, *replacements);
1605 controlCollapseDimension(std::move(collapseDimensions)) {}
1610 controlCollapseDimension(genericOp);
1611 if (collapsableIterationDims.empty())
1616 collapsableIterationDims)) {
1618 genericOp,
"specified dimensions cannot be collapsed");
1621 std::optional<SmallVector<Value>> replacements =
1624 if (!replacements) {
1626 "failed to collapse dimensions");
1628 rewriter.
replaceOp(genericOp, *replacements);
1652 if (!genericOp.hasTensorSemantics())
1654 for (
OpOperand *opOperand : genericOp.getDpsInputOperands()) {
1655 Operation *def = opOperand->get().getDefiningOp();
1656 TypedAttr constantAttr;
1657 auto isScalarOrSplatConstantOp = [&constantAttr](
Operation *def) ->
bool {
1660 if (
matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1662 splatAttr.
getType().getElementType().isIntOrFloat()) {
1668 IntegerAttr intAttr;
1669 if (
matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1670 constantAttr = intAttr;
1675 FloatAttr floatAttr;
1676 if (
matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1677 constantAttr = floatAttr;
1684 auto resultValue = dyn_cast<OpResult>(opOperand->get());
1685 if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
1694 fusedIndexMaps.reserve(genericOp->getNumOperands());
1695 fusedOperands.reserve(genericOp.getNumDpsInputs());
1696 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
1697 for (
OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
1698 if (inputOperand == opOperand)
1700 Value inputValue = inputOperand->get();
1701 fusedIndexMaps.push_back(
1702 genericOp.getMatchingIndexingMap(inputOperand));
1703 fusedOperands.push_back(inputValue);
1704 fusedLocs.push_back(inputValue.
getLoc());
1706 for (
OpOperand &outputOperand : genericOp.getDpsInitsMutable())
1707 fusedIndexMaps.push_back(
1708 genericOp.getMatchingIndexingMap(&outputOperand));
1713 genericOp,
"fused op loop bound computation failed");
1717 Value scalarConstant =
1718 rewriter.
create<arith::ConstantOp>(def->
getLoc(), constantAttr);
1721 auto fusedOp = rewriter.
create<GenericOp>(
1722 rewriter.
getFusedLoc(fusedLocs), genericOp->getResultTypes(),
1726 genericOp.getIteratorTypes(),
1732 Region ®ion = genericOp->getRegion(0);
1735 mapping.
map(entryBlock.
getArgument(opOperand->getOperandNumber()),
1737 Region &fusedRegion = fusedOp->getRegion(0);
1740 rewriter.
replaceOp(genericOp, fusedOp->getResults());
1764 bool modifiedOutput =
false;
1766 for (
OpOperand &opOperand : op.getDpsInitsMutable()) {
1767 if (!op.payloadUsesValueFromOperand(&opOperand)) {
1768 Value operandVal = opOperand.get();
1769 auto operandType = dyn_cast<RankedTensorType>(operandVal.
getType());
1778 auto definingOp = operandVal.
getDefiningOp<tensor::EmptyOp>();
1781 modifiedOutput =
true;
1784 Value emptyTensor = rewriter.
create<tensor::EmptyOp>(
1785 loc, mixedSizes, operandType.getElementType());
1786 op->
setOperand(opOperand.getOperandNumber(), emptyTensor);
1789 if (!modifiedOutput) {
1804 if (!genericOp.hasTensorSemantics())
1806 bool fillFound =
false;
1807 Block &payload = genericOp.getRegion().
front();
1808 for (
OpOperand *opOperand : genericOp.getDpsInputOperands()) {
1809 if (!genericOp.payloadUsesValueFromOperand(opOperand))
1811 FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
1815 Value fillVal = fillOp.value();
1817 cast<RankedTensorType>(fillOp.result().getType()).getElementType();
1818 Value convertedVal =
1822 payload.
getArgument(opOperand->getOperandNumber()), convertedVal);
1832 patterns.
add<FoldReshapeWithGenericOpByExpansion>(patterns.
getContext(),
1833 controlFoldingReshapes);
1834 patterns.
add<FoldWithProducerReshapeOpByExpansion>(patterns.
getContext(),
1835 controlFoldingReshapes);
1841 patterns.
add<FoldWithProducerReshapeOpByCollapsing>(patterns.
getContext(),
1842 controlFoldingReshapes);
1849 patterns.
add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
1850 patterns.
add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
1851 RemoveOutsDependency>(context);
1859 patterns.
add<CollapseLinalgDimensions>(patterns.
getContext(),
1860 controlCollapseDimensions);
1875 struct LinalgElementwiseOpFusionPass
1876 :
public impl::LinalgElementwiseOpFusionBase<
1877 LinalgElementwiseOpFusionPass> {
1878 void runOnOperation()
override {
1885 Operation *producer = fusedOperand->get().getDefiningOp();
1886 return producer && producer->
hasOneUse();
1894 affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
1895 GenericOp::getCanonicalizationPatterns(patterns, context);
1896 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
1897 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
1914 return std::make_unique<LinalgElementwiseOpFusionPass>();
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(OpOperand *producerOpOperand, AffineMap producerResultIndexMap, AffineMap fusedConsumerArgIndexMap)
Append to fusedOpIndexingMapAttrs the indexing maps for the operands of the producer to use in the fused op.
static SmallVector< ReassociationIndices > getOperandReassociation(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Return the reassociation indices to use to collapse the operand when the iteration space of a generic op is collapsed.
static LogicalResult isGenericOpExpandable(GenericOp genericOp, const ExpansionInfo &expansionInfo, PatternRewriter &rewriter)
Expanding the body of a linalg operation requires adaptations of the accessed loop indices.
static SmallVector< utils::IteratorType > getCollapsedOpIteratorTypes(ArrayRef< utils::IteratorType > iteratorTypes, const CollapsingInfo &collapsingInfo)
Get the iterator types for the collapsed operation given the original iterator types and collapsed di...
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, Location loc, Region &fusedRegion, const ExpansionInfo &expansionInfo)
Update the body of an expanded linalg operation having index semantics.
static Value getCollapsedOpOperand(Location loc, GenericOp genericOp, OpOperand *opOperand, const CollapsingInfo &collapsingInfo, OpBuilder &builder)
Get the new value to use for a given OpOperand in the collapsed operation.
void generateCollapsedIndexingRegion(Location loc, Block *block, const CollapsingInfo &collapsingInfo, ValueRange loopRange, RewriterBase &rewriter)
Modify the linalg.index operations in the original generic op, to its value in the collapsed operatio...
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp, OpOperand *fusableOpOperand)
Conditions for folding a generic operation with a reshape op by expanding the iteration space dimensi...
static ReassociationIndices getDomainReassociation(AffineMap indexingMap, ReassociationIndicesRef rangeReassociation)
For a given list of indices in the range of the indexingMap that are folded, return the indices of th...
static SmallVector< ReassociationIndices > getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, ArrayRef< ReassociationIndices > reassociation)
static void generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand, unsigned nloops, llvm::SmallDenseSet< int > &preservedProducerResults)
Generate the region of the fused tensor operation.
static SmallVector< ReassociationIndices > getReassociationForExpansion(AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Returns the reassociation maps to use in the tensor.expand_shape operation to convert the operands of...
static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Compute the indexing map in the collapsed op that corresponds to the given indexingMap of the origina...
static RankedTensorType getExpandedType(RankedTensorType originalType, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the type of the operand/result to use in the expanded op given the type in the original op.
static std::optional< SmallVector< Value > > fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, OpOperand *fusableOpOperand, PatternRewriter &rewriter)
Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op and a generic op as expl...
static AffineMap getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the indexing map to use in the expanded op for a given the indexingMap of the original operati...
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class provides support for representing a failure result, or a valid value of type T.
This class allows control over how the GreedyPatternRewriteDriver works.
bool useTopDownTraversal
This specifies the order of initial traversal that populates the rewriters worklist.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
virtual void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
virtual void cancelRootUpdate(Operation *op)
This method cancels a pending root update.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor)
Find uses of from and replace them with to if the functor returns true.
virtual void startRootUpdate(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool areDimSequencesPreserved(ArrayRef< AffineMap > maps, ArrayRef< ReassociationIndices > dimSequences)
Return true if all sequences of dimensions specified in dimSequences are contiguous in all the ranges...
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence)
Return true if a given sequence of dimensions are contiguous in the range of the specified indexing m...
void populateFoldReshapeOpsByCollapsingPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding tensor.expand_shape operation with its producer generic operation by co...
FailureOr< ElementwiseOpFusionResult > fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand)
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
FailureOr< SmallVector< Value > > collapseGenericOpIterationDims(GenericOp genericOp, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic operation.
void populateCollapseDimensions(RewritePatternSet &patterns, const GetCollapsableDimensionsFn &controlCollapseDimensions)
Pattern to collapse dimensions in a linalg.generic op.
bool areElementwiseOpsFusable(OpOperand *fusedOperand)
Return true if two linalg.generic operations with producer/consumer relationship through fusedOperand...
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns)
Pattern to remove dead operands and results of linalg.generic operations.
void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding (collapsing) tensor_reshape operation with its producer (consumer) gene...
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
std::function< SmallVector< ReassociationIndices >(linalg::GenericOp)> GetCollapsableDimensionsFn
Function type to control generic op dimension collapsing.
void populateElementwiseOpsFusionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion)
Patterns for fusing linalg operation on tensors.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasAnySparseResult(Operation *op)
Returns true iff MLIR operand has any sparse result.
bool hasAnySparseOperand(Operation *op)
Returns true iff MLIR operand has any sparse operand.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
This header declares functions that assist transformations in the MemRef dialect.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify that shapes of the reshaped types using following rules 1) if a dimension in the collapsed typ...
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
std::unique_ptr< Pass > createLinalgElementwiseOpFusionPass()
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Fuse two linalg.generic operations that have a producer-consumer relationship captured through fusedO...
llvm::DenseMap< Value, Value > replacements