51 assert(invProducerResultIndexMap &&
52 "expected producer result indexing map to be invertible");
54 LinalgOp producer = cast<LinalgOp>(producerOpOperand->
getOwner());
56 AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
64 return t1.
compose(fusedConsumerArgIndexMap);
71 if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
76 if (producer.getNumParallelLoops() != producer.getNumLoops())
81 if (!consumer.isInputTensor(consumerOpOperand))
86 AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
87 if (consumerIndexMap.
getNumResults() != producer.getNumLoops())
91 if (producer.getNumOutputs() != 1)
97 producer.getTiedIndexingMap(producer.getOutputOperand(0));
105 if ((consumer.getNumReductionLoops())) {
106 BitVector coveredDims(consumer.getNumLoops(),
false);
108 auto addToCoveredDims = [&](
AffineMap map) {
109 for (
auto result : map.getResults())
111 coveredDims[dimExpr.getPosition()] =
true;
115 llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
116 Value operand = std::get<0>(pair);
117 if (operand == consumerOpOperand->
get())
119 AffineMap operandMap = std::get<1>(pair);
120 addToCoveredDims(operandMap);
123 for (
OpOperand *operand : producer.getInputOperands()) {
126 operand, producerResultIndexMap, consumerIndexMap);
127 addToCoveredDims(newIndexingMap);
129 if (!coveredDims.all())
144 auto consumer = cast<GenericOp>(consumerOpOperand->
getOwner());
146 Block &producerBlock = producer->getRegion(0).
front();
147 Block &consumerBlock = consumer->getRegion(0).
front();
149 fusedOp.getRegion().push_back(fusedBlock);
156 if (producer.hasIndexSemantics()) {
158 unsigned numFusedOpLoops =
159 std::max(producer.getNumLoops(), consumer.getNumLoops());
161 fusedIndices.reserve(numFusedOpLoops);
162 llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
163 std::back_inserter(fusedIndices), [&](uint64_t dim) {
164 return rewriter.
create<IndexOp>(producer.getLoc(), dim);
166 for (IndexOp indexOp :
167 llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
168 Value newIndex = rewriter.
create<mlir::AffineApplyOp>(
170 consumerToProducerLoopsMap.
getSubMap(indexOp.getDim()), fusedIndices);
171 mapper.
map(indexOp.getResult(), newIndex);
175 assert(consumer.isInputTensor(consumerOpOperand) &&
176 "expected producer of input operand");
180 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
187 producerBlock.getArguments().take_front(producer.getNumInputs()))
188 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
192 assert(producer->getNumResults() == 1 &&
"expected single result producer");
193 if (producer.isInitTensor(producer.getOutputOperand(0))) {
195 .drop_front(producer.getNumInputs())
203 .take_front(consumer.getNumInputs())
205 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
208 consumerBlock.
getArguments().take_back(consumer.getNumOutputs()))
209 mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
212 assert(producer->getNumResults() == 1 &&
"expected single result producer");
216 for (
auto &op : producerBlock.without_terminator()) {
217 if (!isa<IndexOp>(op))
218 rewriter.
clone(op, mapper);
222 auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
224 assert(producer->getNumResults() == 1 &&
"expected single result producer");
225 unsigned producerResultNumber = 0;
230 if (replacement == yieldOp.getOperand(producerResultNumber)) {
232 assert(bb.getOwner() != &producerBlock &&
233 "yielded block argument must have been mapped");
236 "yielded value must have been mapped");
242 rewriter.
clone(op, mapper);
246 "Ill-formed GenericOp region");
253 auto consumer = cast<GenericOp>(consumerOpOperand->
getOwner());
255 !controlFn(producer->getResult(0), *consumerOpOperand))
259 assert(consumer.isInputTensor(consumerOpOperand) &&
260 "expected producer of input operand");
265 fusedOperands.reserve(producer->getNumOperands() +
266 consumer->getNumOperands());
267 fusedIndexMaps.reserve(producer->getNumOperands() +
268 consumer->getNumOperands());
273 llvm::find(consumerInputs, consumerOpOperand);
274 assert(it != consumerInputs.end() &&
"expected to find the consumer operand");
275 for (
OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
276 fusedOperands.push_back(opOperand->get());
277 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
280 assert(producer->getNumResults() == 1 &&
"expected single result producer");
282 producer.getTiedIndexingMap(producer.getOutputOperand(0));
283 for (
OpOperand *opOperand : producer.getInputOperands()) {
284 fusedOperands.push_back(opOperand->get());
287 opOperand, producerResultIndexMap,
288 consumer.getTiedIndexingMap(consumerOpOperand));
289 fusedIndexMaps.push_back(map);
293 assert(producer->getNumResults() == 1 &&
"expected single result producer");
294 if (producer.isInitTensor(producer.getOutputOperand(0))) {
295 fusedOperands.push_back(producer.getOutputOperand(0)->get());
298 producer.getOutputOperand(0), producerResultIndexMap,
299 consumer.getTiedIndexingMap(consumerOpOperand));
300 fusedIndexMaps.push_back(map);
305 llvm::make_range(std::next(it), consumerInputs.end())) {
306 fusedOperands.push_back(opOperand->get());
307 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
310 for (
OpOperand *opOperand : consumer.getOutputOperands())
311 fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
314 assert(producer->getNumResults() == 1 &&
"expected single result producer");
318 auto fusedOp = rewriter.
create<GenericOp>(
319 consumer.getLoc(), consumer->getResultTypes(),
323 consumer.getIteratorTypes(),
326 if (!fusedOp.getShapesToLoopsMap()) {
337 consumer.getTiedIndexingMap(consumerOpOperand);
341 assert(invProducerResultIndexMap &&
342 "expected producer result indexig map to be invertible");
345 invProducerResultIndexMap.
compose(consumerResultIndexMap);
348 consumerToProducerLoopsMap,
349 consumerOpOperand, consumer.getNumLoops());
356 if (producer->getNumResults() != 1)
370 controlFn(std::move(fun)) {}
375 for (
OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
377 dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
378 if (!producer || !producer.hasTensorSemantics())
382 if (fusedOpResults) {
383 rewriter.
replaceOp(genericOp, *fusedOpResults);
460 return genericOp.hasTensorSemantics() &&
461 llvm::all_of(genericOp.getIndexingMaps().getValue(),
463 return attr.cast<AffineMapAttr>()
465 .isProjectedPermutation();
467 genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
468 llvm::all_of(genericOp.getIteratorTypes(), [](
Attribute attr) {
469 return attr.cast<StringAttr>().getValue() ==
477 class ExpansionInfo {
488 unsigned getOrigOpNumDims()
const {
return reassociation.size(); }
489 unsigned getExpandedOpNumDims()
const {
return expandedOpNumDims; }
491 return reassociation[i];
494 return expandedShapeMap[i];
507 unsigned expandedOpNumDims;
517 if (reassociationMaps.empty())
519 AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
522 originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
524 reassociation.clear();
525 expandedShapeMap.clear();
529 expandedShapeMap.resize(fusedIndexMap.
getNumDims());
531 unsigned pos = resultExpr.value().cast<
AffineDimExpr>().getPosition();
532 AffineMap foldedDims = reassociationMaps[resultExpr.index()];
535 expandedShape.slice(foldedDims.
getDimPosition(0), numExpandedDims[pos]);
536 expandedShapeMap[pos].assign(shape.begin(), shape.end());
539 for (
unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.
getNumDims()))
540 if (expandedShapeMap[i].empty())
541 expandedShapeMap[i] = {originalLoopExtent[i]};
545 reassociation.reserve(fusedIndexMap.
getNumDims());
547 auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
548 reassociation.emplace_back(seq.begin(), seq.end());
549 sum += numFoldedDim.value();
551 expandedOpNumDims = sum;
564 const ExpansionInfo &expansionInfo,
566 if (!genericOp.hasIndexSemantics())
568 for (
unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
570 if (expandedShape.size() == 1)
572 for (int64_t shape : expandedShape.drop_front()) {
573 if (ShapedType::isDynamic(shape)) {
575 genericOp,
"cannot expand due to index semantics and dynamic dims");
586 const ExpansionInfo &expansionInfo) {
591 llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
594 newExprs.append(expandedExprs.begin(), expandedExprs.end());
605 const ExpansionInfo &expansionInfo) {
609 auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
610 expandedShape.append(dimExpansion.begin(), dimExpansion.end());
612 return RankedTensorType::get(expandedShape, originalType.getElementType());
623 const ExpansionInfo &expansionInfo) {
625 unsigned numReshapeDims = 0;
628 auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
630 llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
631 reassociation.emplace_back(std::move(indices));
632 numReshapeDims += numExpandedDims;
634 return reassociation;
644 const ExpansionInfo &expansionInfo) {
646 for (IndexOp indexOp :
647 llvm::make_early_inc_range(fusedRegion.
front().
getOps<IndexOp>())) {
649 expansionInfo.getExpandedDims(indexOp.getDim());
650 assert(!expandedDims.empty() &&
"expected valid expansion info");
653 if (expandedDims.size() == 1 &&
654 expandedDims.front() == (int64_t)indexOp.getDim())
661 expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
663 expandedIndices.reserve(expandedDims.size() - 1);
665 expandedDims.drop_front(), std::back_inserter(expandedIndices),
666 [&](int64_t dim) {
return rewriter.
create<IndexOp>(loc, dim); });
667 Value newIndex = rewriter.
create<IndexOp>(loc, expandedDims.front());
668 for (
auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
669 assert(!ShapedType::isDynamic(std::get<0>(it)));
672 newIndex = rewriter.
create<AffineApplyOp>(
673 indexOp.getLoc(), idx + acc * std::get<0>(it),
688 "preconditions for fuse operation failed");
690 auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
691 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
692 bool isExpanding = (expandingReshapeOp !=
nullptr);
693 RankedTensorType expandedType = isExpanding
694 ? expandingReshapeOp.getResultType()
695 : collapsingReshapeOp.getSrcType();
696 RankedTensorType collapsedType = isExpanding
697 ? expandingReshapeOp.getSrcType()
698 : collapsingReshapeOp.getResultType();
700 ExpansionInfo expansionInfo;
701 if (
failed(expansionInfo.compute(
702 genericOp, fusableOpOperand,
703 isExpanding ? expandingReshapeOp.getReassociationMaps()
704 : collapsingReshapeOp.getReassociationMaps(),
705 expandedType.getShape(), collapsedType.getShape(), rewriter)))
712 llvm::map_range(genericOp.getIndexingMapsArray(), [&](
AffineMap m) {
717 expandedOpOperands.reserve(genericOp.getNumInputs());
718 for (
OpOperand *opOperand : genericOp.getInputOperands()) {
719 if (opOperand == fusableOpOperand) {
720 expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
721 : collapsingReshapeOp.getSrc());
724 if (genericOp.isInputTensor(opOperand)) {
725 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
726 auto opOperandType = opOperand->
get().getType().cast<RankedTensorType>();
727 RankedTensorType expandedOperandType =
729 if (expandedOperandType != opOperand->get().getType()) {
734 [&](
const Twine &msg) {
737 opOperandType.getShape(), expandedOperandType.getShape(),
741 expandedOpOperands.push_back(rewriter.
create<tensor::ExpandShapeOp>(
742 genericOp.getLoc(), expandedOperandType, opOperand->get(),
747 expandedOpOperands.push_back(opOperand->get());
752 for (
OpOperand *opOperand : genericOp.getOutputOperands()) {
753 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
754 auto opOperandType = opOperand->
get().getType().cast<RankedTensorType>();
755 RankedTensorType expandedOutputType =
757 if (expandedOutputType != opOperand->get().getType()) {
761 [&](
const Twine &msg) {
764 opOperandType.getShape(), expandedOutputType.getShape(),
768 outputs.push_back(rewriter.
create<tensor::ExpandShapeOp>(
769 genericOp.getLoc(), expandedOutputType, opOperand->get(),
780 rewriter.
create<GenericOp>(genericOp.getLoc(), resultTypes,
781 expandedOpOperands, outputs,
782 expandedOpIndexingMaps, iteratorTypes);
783 Region &fusedRegion = fusedOp->getRegion(0);
784 Region &originalRegion = genericOp->getRegion(0);
793 for (
OpResult opResult : genericOp->getOpResults()) {
794 int64_t resultNumber = opResult.getResultNumber();
795 if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
798 genericOp.getTiedIndexingMap(
799 genericOp.getOutputOperand(resultNumber)),
801 resultVals.push_back(rewriter.
create<tensor::CollapseShapeOp>(
802 genericOp.getLoc(), opResult.getType(),
803 fusedOp->getResult(resultNumber), reassociation));
805 resultVals.push_back(fusedOp->getResult(resultNumber));
817 class FoldWithProducerReshapeOpByExpansion
820 FoldWithProducerReshapeOpByExpansion(
MLIRContext *context,
824 controlFoldingReshapes(std::move(foldReshapes)) {}
828 for (
OpOperand *opOperand : genericOp.getInputTensorOperands()) {
829 tensor::CollapseShapeOp reshapeOp =
830 opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
837 (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
842 if (!replacementValues)
844 rewriter.
replaceOp(genericOp, *replacementValues);
856 struct FoldReshapeWithGenericOpByExpansion
859 FoldReshapeWithGenericOpByExpansion(
MLIRContext *context,
863 controlFoldingReshapes(std::move(foldReshapes)) {}
865 LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
868 GenericOp producer = reshapeOp.getSrc().getDefiningOp<GenericOp>();
869 if (!producer || producer.getNumOutputs() != 1 ||
871 producer.getOutputOperand(0)) ||
872 !controlFoldingReshapes(producer->getResult(0),
873 reshapeOp->getOpOperand(0)))
876 producer, reshapeOp, producer.getOutputOperand(0), rewriter);
877 if (!replacementValues)
879 rewriter.
replaceOp(reshapeOp, *replacementValues);
901 "expected projected permutation");
904 llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
911 return domainReassociation;
919 assert(!dimSequence.empty() &&
920 "expected non-empty list for dimension sequence");
922 "expected indexing map to be projected permutation");
924 llvm::SmallDenseSet<unsigned, 4> sequenceElements;
925 sequenceElements.insert(dimSequence.begin(), dimSequence.end());
927 unsigned dimSequenceStart = dimSequence[0];
929 unsigned dimInMapStart = expr.value().cast<
AffineDimExpr>().getPosition();
931 if (dimInMapStart == dimSequenceStart) {
932 if (expr.index() + dimSequence.size() > indexingMap.
getNumResults())
935 for (
const auto &dimInSequence :
enumerate(dimSequence)) {
937 indexingMap.
getResult(expr.index() + dimInSequence.index())
938 .cast<AffineDimExpr>()
940 if (dimInMap != dimInSequence.value())
951 if (sequenceElements.count(dimInMapStart))
1008 if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
1011 if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](
AffineMap map) {
1012 return map.isProjectedPermutation();
1019 for (
const auto &iteratorType :
1022 reductionDims.push_back(iteratorType.index());
1026 llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1027 AffineMap indexingMap = genericOp.getTiedIndexingMap(fusableOperand);
1028 auto iteratorTypes = genericOp.getIteratorTypes().getValue();
1031 assert(!foldedRangeDims.empty() &&
"unexpected empty reassociation");
1034 if (foldedRangeDims.size() == 1)
1042 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1043 return processedIterationDims.count(dim);
1048 Attribute startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]];
1052 if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1053 return iteratorTypes[dim] != startIteratorType;
1062 bool isContiguous =
false;
1065 if (startDim.value() != foldedIterationSpaceDims[0])
1069 if (startDim.index() + foldedIterationSpaceDims.size() >
1070 reductionDims.size())
1073 isContiguous =
true;
1074 for (
const auto &foldedDim :
1076 if (reductionDims[foldedDim.index() + startDim.index()] !=
1077 foldedDim.value()) {
1078 isContiguous =
false;
1089 if (llvm::any_of(genericOp.getIndexingMapsArray(),
1092 foldedIterationSpaceDims);
1096 processedIterationDims.insert(foldedIterationSpaceDims.begin(),
1097 foldedIterationSpaceDims.end());
1098 iterationSpaceReassociation.emplace_back(
1099 std::move(foldedIterationSpaceDims));
1102 return iterationSpaceReassociation;
1107 class CollapsingInfo {
1111 llvm::SmallDenseSet<int64_t, 4> processedDims;
1114 if (foldedIterationDim.empty())
1118 for (
auto dim : foldedIterationDim) {
1119 if (dim >= origNumLoops)
1121 if (processedDims.count(dim))
1123 processedDims.insert(dim);
1125 collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1126 foldedIterationDim.end());
1128 if (processedDims.size() > origNumLoops)
1133 for (
auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1134 if (processedDims.count(dim))
1139 llvm::sort(collapsedOpToOrigOpIterationDim,
1141 return lhs[0] < rhs[0];
1143 origOpToCollapsedOpIterationDim.resize(origNumLoops);
1144 for (
const auto &foldedDims :
1146 for (
const auto &dim :
enumerate(foldedDims.value()))
1147 origOpToCollapsedOpIterationDim[dim.value()] =
1148 std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1155 return collapsedOpToOrigOpIterationDim;
1179 return origOpToCollapsedOpIterationDim;
1183 unsigned getCollapsedOpIterationRank()
const {
1184 return collapsedOpToOrigOpIterationDim.size();
1202 const CollapsingInfo &collapsingInfo) {
1205 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1206 assert(!foldedIterDims.empty() &&
1207 "reassociation indices expected to have non-empty sets");
1211 collapsedIteratorTypes.push_back(
1212 iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue());
1214 return collapsedIteratorTypes;
1221 const CollapsingInfo &collapsingInfo) {
1224 "expected indexing map to be projected permutation");
1226 auto origOpToCollapsedOpMapping =
1227 collapsingInfo.getOrigOpToCollapsedOpMapping();
1231 if (origOpToCollapsedOpMapping[dim].second != 0)
1235 resultExprs.push_back(
1238 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1239 resultExprs, context);
1246 const CollapsingInfo &collapsingInfo) {
1247 unsigned counter = 0;
1249 auto origOpToCollapsedOpMapping =
1250 collapsingInfo.getOrigOpToCollapsedOpMapping();
1251 auto collapsedOpToOrigOpMapping =
1252 collapsingInfo.getCollapsedOpToOrigOpMapping();
1256 if (origOpToCollapsedOpMapping[dim].second == 0) {
1260 unsigned numFoldedDims =
1261 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1263 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1264 operandReassociation.emplace_back(range.begin(), range.end());
1265 counter += numFoldedDims;
1268 return operandReassociation;
1274 const CollapsingInfo &collapsingInfo,
1276 AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
1283 if (operandReassociation.size() == indexingMap.
getNumResults())
1287 auto reshapeOp = builder.
create<tensor::CollapseShapeOp>(
1288 loc, operand, operandReassociation);
1295 const CollapsingInfo &collapsingInfo,
1302 auto indexOps = llvm::to_vector(block->
getOps<linalg::IndexOp>());
1312 for (
auto &foldedDims :
1313 enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1316 rewriter.
create<linalg::IndexOp>(loc, foldedDims.index());
1317 for (
auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1318 indexReplacementVals[dim] =
1319 rewriter.
create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
1321 rewriter.
create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
1323 indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1326 for (
auto indexOp : indexOps) {
1327 auto dim = indexOp.getDim();
1328 rewriter.
replaceOp(indexOp, indexReplacementVals[dim]);
1337 if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1339 return foldedDims.size() <= 1;
1343 CollapsingInfo collapsingInfo;
1344 if (
failed(collapsingInfo.initialize(genericOp.getNumLoops(),
1345 foldedIterationDims))) {
1347 genericOp,
"illegal to collapse specified dimensions");
1352 cast<LinalgOp>(genericOp.getOperation())
1353 .createLoopRanges(rewriter, genericOp.getLoc());
1355 if (
auto attr = ofr.dyn_cast<
Attribute>())
1356 return attr.cast<IntegerAttr>().getInt() ==
value;
1359 actual.getSExtValue() ==
value;
1361 if (!llvm::all_of(loopRanges, [&](
Range range) {
1362 return opFoldIsConstantValue(range.
offset, 0) &&
1363 opFoldIsConstantValue(range.
stride, 1);
1367 "expected all loop ranges to have zero start and unit stride");
1372 genericOp.getIteratorTypes().getValue(), collapsingInfo);
1375 auto indexingMaps = llvm::to_vector(
1376 llvm::map_range(genericOp.getIndexingMapsArray(), [&](
AffineMap map) {
1380 Location loc = genericOp->getLoc();
1383 auto inputOperands = llvm::to_vector(
1384 llvm::map_range(genericOp.getInputOperands(), [&](
OpOperand *opOperand) {
1392 resultTypes.reserve(genericOp.getNumOutputs());
1393 outputOperands.reserve(genericOp.getNumOutputs());
1394 for (
OpOperand *output : genericOp.getOutputOperands()) {
1397 outputOperands.push_back(newOutput);
1398 resultTypes.push_back(newOutput.
getType());
1402 auto collapsedGenericOp = rewriter.
create<linalg::GenericOp>(
1403 loc, resultTypes, inputOperands, outputOperands, indexingMaps,
1405 Block *origOpBlock = &genericOp->getRegion(0).
front();
1406 Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).
front();
1407 rewriter.
mergeBlocks(origOpBlock, collapsedOpBlock,
1410 if (collapsedGenericOp.hasIndexSemantics()) {
1415 llvm::to_vector(llvm::map_range(loopRanges, [&](
Range range) {
1419 &collapsedGenericOp->getRegion(0).front(),
1420 collapsingInfo, loopBound, rewriter);
1426 for (
const auto &originalResult :
llvm::enumerate(genericOp->getResults())) {
1427 Value collapsedOpResult =
1428 collapsedGenericOp->getResult(originalResult.index());
1429 auto originalResultType =
1430 originalResult.value().
getType().
cast<ShapedType>();
1431 auto collapsedOpResultType = collapsedOpResult.
getType().
cast<ShapedType>();
1432 if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1434 genericOp.getTiedIndexingMapForResult(originalResult.value());
1437 Value result = rewriter.
create<tensor::ExpandShapeOp>(
1438 loc, originalResultType, collapsedOpResult, reassociation);
1439 results.push_back(result);
1441 results.push_back(collapsedOpResult);
1451 class FoldWithProducerReshapeOpByCollapsing
1454 FoldWithProducerReshapeOpByCollapsing(
MLIRContext *context,
1458 controlFoldingReshapes(std::move(foldReshapes)) {}
1462 for (
OpOperand *opOperand : genericOp.getInputTensorOperands()) {
1463 tensor::ExpandShapeOp reshapeOp =
1464 opOperand->get().getDefiningOp<tensor::ExpandShapeOp>();
1470 reshapeOp.getReassociationIndices());
1471 if (collapsableIterationDims.empty() ||
1472 !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
1478 opOperand, rewriter);
1479 if (!replacements) {
1481 genericOp,
"failed to do the fusion by collapsing transformation");
1484 rewriter.
replaceOp(genericOp, *replacements);
1509 if (!genericOp.hasTensorSemantics())
1511 for (
OpOperand *opOperand : genericOp.getInputOperands()) {
1512 Operation *def = opOperand->get().getDefiningOp();
1513 TypedAttr constantAttr;
1514 auto isScalarOrSplatConstantOp = [&constantAttr](
Operation *def) ->
bool {
1517 if (
matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1519 splatAttr.
getType().getElementType().isIntOrFloat()) {
1525 IntegerAttr intAttr;
1526 if (
matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1527 constantAttr = intAttr;
1532 FloatAttr floatAttr;
1533 if (
matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1534 constantAttr = floatAttr;
1541 auto resultValue = opOperand->get().dyn_cast<
OpResult>();
1542 if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
1551 fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
1552 fusedOperands.reserve(genericOp.getNumInputs());
1553 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
1554 for (
OpOperand *inputOperand : genericOp.getInputOperands()) {
1555 if (inputOperand == opOperand)
1557 Value inputValue = inputOperand->get();
1558 fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
1559 fusedOperands.push_back(inputValue);
1560 fusedLocs.push_back(inputValue.
getLoc());
1562 for (
OpOperand *outputOperand : genericOp.getOutputOperands())
1563 fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));
1568 genericOp,
"fused op loop bound computation failed");
1572 Value scalarConstant = rewriter.
create<arith::ConstantOp>(
1573 def->
getLoc(), constantAttr, constantAttr.getType());
1576 auto fusedOp = rewriter.
create<GenericOp>(
1577 rewriter.
getFusedLoc(fusedLocs), genericOp->getResultTypes(),
1581 genericOp.getIteratorTypes(),
1587 Region ®ion = genericOp->getRegion(0);
1590 mapping.
map(entryBlock.
getArgument(opOperand->getOperandNumber()),
1592 Region &fusedRegion = fusedOp->getRegion(0);
1595 rewriter.
replaceOp(genericOp, fusedOp->getResults());
1619 bool modifiedOutput =
false;
1621 for (
OpOperand *opOperand : op.getOutputOperands()) {
1622 if (!op.payloadUsesValueFromOperand(opOperand)) {
1623 Value operandVal = opOperand->get();
1636 modifiedOutput =
true;
1639 if (dim.value() != ShapedType::kDynamicSize)
1641 dynamicDims.push_back(rewriter.
createOrFold<tensor::DimOp>(
1642 loc, operandVal, dim.index()));
1645 loc, dynamicDims, operandType.getShape(),
1646 operandType.getElementType());
1647 op->
setOperand(opOperand->getOperandNumber(), initTensor);
1650 if (!modifiedOutput) {
1665 if (!genericOp.hasTensorSemantics())
1667 bool fillFound =
false;
1668 Block &payload = genericOp.getRegion().
front();
1669 for (
OpOperand *opOperand : genericOp.getInputOperands()) {
1670 if (!genericOp.payloadUsesValueFromOperand(opOperand))
1672 FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
1676 payload.
getArgument(opOperand->getOperandNumber())
1677 .replaceAllUsesWith(fillOp.value());
1687 patterns.
add<FoldReshapeWithGenericOpByExpansion>(patterns.
getContext(),
1688 controlFoldingReshapes);
1689 patterns.
add<FoldWithProducerReshapeOpByExpansion>(patterns.
getContext(),
1690 controlFoldingReshapes);
1696 patterns.
add<FoldWithProducerReshapeOpByCollapsing>(patterns.
getContext(),
1697 controlFoldingReshapes);
1704 patterns.
add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
1705 patterns.
add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
1706 RemoveOutsDependency>(context);
1721 struct LinalgElementwiseOpFusionPass
1722 :
public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
1723 void runOnOperation()
override {
1739 AffineApplyOp::getCanonicalizationPatterns(patterns, context);
1740 GenericOp::getCanonicalizationPatterns(patterns, context);
1741 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
1742 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
1760 return std::make_unique<LinalgElementwiseOpFusionPass>();
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
MLIRContext * getContext() const
constexpr StringRef getParallelIteratorTypeName()
Use to encode that a particular iterator type has parallel semantics.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
static Optional< SmallVector< Value > > fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand, const ControlFusionFn &controlFn, PatternRewriter &rewriter)
detail::constant_int_op_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Operation is a basic unit of execution within MLIR.
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
bool isParallelIterator(Attribute attr)
unsigned getNumSymbols() const
unsigned getNumDims() const
This is a value defined by a result of an operation.
Block represents an ordered list of Operations.
void populateFoldReshapeOpsByCollapsingPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding tensor.expand_shape operation with its producer generic operation by co...
static SmallVector< StringRef > getCollapsedOpIteratorTypes(ArrayRef< Attribute > iteratorTypes, const CollapsingInfo &collapsingInfo)
Get the iterator types for the collapsed operation given the original iterator types and collapsed di...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::unique_ptr< Pass > createLinalgElementwiseOpFusionPass()
This class represents a single result from folding an operation.
Operation * clone(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
OpListType & getOperations()
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify that shapes of the reshaped types using following rules 1) if a dimension in the collapsed typ...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes)
Patterns to fold an expanding (collapsing) tensor_reshape operation with its producer (consumer) gene...
static FailureOr< SmallVector< Value > > collapseGenericOpIterationDims(GenericOp genericOp, ArrayRef< ReassociationIndices > foldedIterationDims, OpOperand *fusableOpOperand, PatternRewriter &rewriter)
Implementation of fusion with reshape operation by collapsing dimensions.
This class allows control over how the GreedyPatternRewriteDriver works.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape...
bool useTopDownTraversal
This specifies the order of initial traversal that populates the rewriters worklist.
static ReassociationIndices getDomainReassociation(AffineMap indexingMap, ReassociationIndicesRef rangeReassociation)
For a given list of indices in the range of the indexingMap that are folded, return the indices of th...
virtual void startRootUpdate(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
BlockArgument getArgument(unsigned i)
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Auxiliary range data structure to unpack the offset, size and stride operands into a list of triples...
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, Location loc, Region &fusedRegion, const ExpansionInfo &expansionInfo)
Update the body of an expanded linalg operation having index semantics.
AffineExpr getResult(unsigned idx) const
void map(Block *from, Block *to)
Inserts a new mapping for 'from' to 'to'.
MLIRContext * getContext()
Return the context this operation is associated with.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
static RankedTensorType getExpandedType(RankedTensorType originalType, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the type of the operand/result to use in the expanded op given the type in the original op...
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This class represents an efficient way to signal success or failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
static SmallVector< ReassociationIndices > getOperandReassociation(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Return the reassociation indices to use to collapse the operand when the iteration space of a generic...
static SmallVector< ReassociationIndices > getReassociationForExpansion(AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Returns the reassociation maps to use in the tensor.expand_shape operation to convert the operands of...
This class provides support for representing a failure result, or a valid value of type T...
An attribute that represents a reference to a dense vector or tensor object.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static bool isDimSequencePreserved(AffineMap indexingMap, ReassociationIndicesRef dimSequence)
For a given dimSequence, check if the sequence is conserved in the indexingMap.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumArguments()
bool hasOneUse() const
Returns true if this value has exactly one use.
Attributes are known-constant values of operations.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
static AffineMap getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap, const ExpansionInfo &expansionInfo)
Return the indexing map to use in the expanded op for the given indexingMap of the original operati...
type_range getTypes() const
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps...
Base type for affine expression.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
static Value getCollapsedOpOperand(Location loc, GenericOp genericOp, OpOperand *opOperand, const CollapsingInfo &collapsingInfo, OpBuilder &builder)
Get the new value to use for a given OpOperand in the collapsed operation.
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
This class provides an abstraction over the various different ranges of value types.
unsigned getNumResults() const
Location getLoc()
The source location the operation was defined or derived from.
IRValueT get() const
Return the current value being used by this operand.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(OpOperand *producerOpOperand, AffineMap producerResultIndexMap, AffineMap fusedConsumerArgIndexMap)
Append to fusedOpIndexingMapAttrs the indexing maps for the operands of the producer to use in the fu...
A multi-dimensional affine map. AffineMaps are immutable like Types, and they are uniqued...
BlockArgListType getArguments()
static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer, OpOperand *consumerOpOperand)
Conditions for elementwise fusion of generic operations.
This class represents an argument of a Block.
Value materializeOpFoldResult(ImplicitLocOpBuilder &builder, OpFoldResult opFoldResult)
Turns an OpFoldResult into a value, creating an index-typed constant if necessary.
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
ArrayRef< AffineExpr > getResults() const
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
void setOperand(unsigned idx, Value value)
Location getLoc() const
Return the location of this value.
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool isReductionIterator(Attribute attr)
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
RAII guard to reset the insertion point of the builder when destroyed.
Type getType() const
Return the type of this value.
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp, OpOperand *fusableOpOperand)
Conditions for folding a generic operation with a reshape op by expanding the iteration space dimensi...
void generateCollapsedIndexingRegion(Location loc, Block *block, const CollapsingInfo &collapsingInfo, ValueRange loopRange, PatternRewriter &rewriter)
Modify the linalg.index operations in the original generic op, to its value in the collapsed operatio...
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
A dimensional identifier appearing in an affine expression.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'. ...
virtual void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, BlockAndValueMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent"...
Operation * getOwner() const
Return the owner of this operand.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
MLIRContext is the top-level object for a collection of MLIR operations.
Block * lookupOrDefault(Block *from) const
Lookup a mapped value within the map.
This class represents an operand of an operation.
AffineExpr getAffineDimExpr(unsigned position)
static Optional< SmallVector< Value > > fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand, GenericOp producer, const ControlFusionFn &controlFn)
static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap, const CollapsingInfo &collapsingInfo)
Compute the indexing map in the collapsed op that corresponds to the given indexingMap of the origina...
std::function< bool(const OpResult &producer, OpOperand &consumer)> ControlFusionFn
Function type which is used to control when to stop fusion.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure...
static LogicalResult isGenericOpExpandable(GenericOp genericOp, const ExpansionInfo &expansionInfo, PatternRewriter &rewriter)
Expanding the body of a linalg operation requires adaptations of the accessed loop indices...
void populateElementwiseOpsFusionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion)
Patterns for fusing linalg operation on tensors.
static SmallVector< ReassociationIndices > getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, ArrayRef< ReassociationIndices > reassociation)
static void generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp, AffineMap consumerToProducerLoopsMap, OpOperand *consumerOpOperand, unsigned nloops)
Generate the region of the fused tensor operation.
MLIRContext * getContext() const
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef< Region > regions, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig())
Rewrite the regions of the specified operation, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy work-list driven manner.
virtual void mergeBlocks(Block *source, Block *dest, ValueRange argValues=llvm::None)
Merge the operations of block 'source' into the end of block 'dest'.
This class helps build Operations.
std::enable_if<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T >::type getSplatValue() const
Return the splat value for this attribute.
This class provides an abstraction over the different types of ranges over Values.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
virtual void cancelRootUpdate(Operation *op)
This method cancels a pending root update.
Location getLoc() const
Return the location for this argument.
static Optional< SmallVector< Value > > fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, OpOperand *fusableOpOperand, PatternRewriter &rewriter)
Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op and a generic op as expl...
MLIRContext * getContext() const
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)