17 #include "llvm/ADT/ScopeExit.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/SmallPtrSet.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/FormatVariadic.h"
22 #include "llvm/Support/SaveAndRestore.h"
23 #include "llvm/Support/ScopedPrinter.h"
29 #define DEBUG_TYPE "dialect-conversion"
32 template <
typename... Args>
33 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
36 os.startLine() <<
"} -> SUCCESS";
38 os.getOStream() <<
" : "
39 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
40 os.getOStream() <<
"\n";
45 template <
typename... Args>
46 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
49 os.startLine() <<
"} -> FAILURE : "
50 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
62 struct ConversionValueMapping {
// Look up the value that `from` has been mapped to, following chained
// mappings. The definition falls back to returning `from` itself when no
// suitable mapped value is found.
// NOTE(review): when `desiredType` is non-null the definition appears to
// prefer a value of that type along the mapping chain -- confirm against the
// full definition.
67 Value lookupOrDefault(
Value from,
Type desiredType =
nullptr)
const;
77 for (
Value it = newVal; it; it = mapping.lookupOrNull(it))
78 assert(it != oldVal &&
"inserting cyclic mapping");
80 mapping.map(oldVal, newVal);
// Remove the mapping recorded for `value` from the underlying `mapping`.
89 void erase(
Value value) { mapping.erase(value); }
94 for (
auto &it : mapping.getValueMap())
95 inverse[it.second].push_back(it.first);
105 Value ConversionValueMapping::lookupOrDefault(
Value from,
106 Type desiredType)
const {
111 while (
auto mappedValue = mapping.lookupOrNull(from))
119 if (from.
getType() == desiredType)
122 Value mappedValue = mapping.lookupOrNull(from);
129 return desiredValue ? desiredValue : from;
132 Value ConversionValueMapping::lookupOrNull(
Value from,
Type desiredType)
const {
133 Value result = lookupOrDefault(from, desiredType);
134 if (result == from || (desiredType && result.
getType() != desiredType))
139 bool ConversionValueMapping::tryMap(
Value oldVal,
Value newVal) {
140 for (
Value it = newVal; it; it = mapping.lookupOrNull(it))
153 struct RewriterState {
// Capture a snapshot of the rewriter's bookkeeping counters. Each argument is
// stored into the member of the same name. Rollback code later compares these
// counts against the current sizes of the corresponding lists (e.g.
// `createdOps`, `ignoredOps`, `replacements`) to decide what to undo.
154 RewriterState(
unsigned numCreatedOps,
unsigned numUnresolvedMaterializations,
155 unsigned numReplacements,
unsigned numArgReplacements,
156 unsigned numBlockActions,
unsigned numIgnoredOperations,
157 unsigned numRootUpdates)
158 : numCreatedOps(numCreatedOps),
159 numUnresolvedMaterializations(numUnresolvedMaterializations),
160 numReplacements(numReplacements),
161 numArgReplacements(numArgReplacements),
162 numBlockActions(numBlockActions),
163 numIgnoredOperations(numIgnoredOperations),
164 numRootUpdates(numRootUpdates) {}
167 unsigned numCreatedOps;
170 unsigned numUnresolvedMaterializations;
173 unsigned numReplacements;
176 unsigned numArgReplacements;
179 unsigned numBlockActions;
182 unsigned numIgnoredOperations;
185 unsigned numRootUpdates;
194 class OperationTransactionState {
// Defaulted default constructor. Capturing an operation's location,
// attributes, operands and successors is done by the constructor that takes
// the operation.
196 OperationTransactionState() =
default;
198 : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()),
199 operands(op->operand_begin(), op->operand_end()),
200 successors(op->successor_begin(), op->successor_end()) {}
204 void resetOperation()
const {
207 op->setOperands(operands);
209 op->setSuccessor(it.value(), it.index());
// Return the operation whose state was captured by this transaction state.
213 Operation *getOperation()
const {
return op; }
218 DictionaryAttr attrs;
228 struct OpReplacement {
// Record the type converter associated with this replacement (null by
// default).
229 OpReplacement(
TypeConverter *converter =
nullptr) : converter(converter) {}
241 enum class BlockActionKind {
252 struct BlockPosition {
254 Block *insertAfterBlock;
270 static BlockAction getCreate(
Block *block) {
271 return {BlockActionKind::Create, block, {}};
273 static BlockAction getErase(
Block *block, BlockPosition originalPosition) {
274 return {BlockActionKind::Erase, block, {originalPosition}};
276 static BlockAction getInline(
Block *block,
Block *srcBlock,
278 BlockAction action{BlockActionKind::Inline, block, {}};
279 action.inlineInfo = {srcBlock,
280 srcBlock->
empty() ? nullptr : &srcBlock->
front(),
281 srcBlock->
empty() ? nullptr : &srcBlock->
back()};
284 static BlockAction getMove(
Block *block, BlockPosition originalPosition) {
285 return {BlockActionKind::Move, block, {originalPosition}};
287 static BlockAction getSplit(
Block *block,
Block *originalBlock) {
288 BlockAction action{BlockActionKind::Split, block, {}};
289 action.originalBlock = originalBlock;
292 static BlockAction getTypeConversion(
Block *block) {
293 return BlockAction{BlockActionKind::TypeConversion, block, {}};
297 BlockActionKind kind;
306 BlockPosition originalPosition;
309 Block *originalBlock;
312 InlineInfo inlineInfo;
322 class UnresolvedMaterialization {
335 UnresolvedMaterialization(UnrealizedConversionCastOp op =
nullptr,
337 Kind kind = Target,
Type origOutputType =
nullptr)
338 : op(op), converterAndKind(converter, kind),
339 origOutputType(origOutputType) {}
// Return the unrealized conversion cast op that backs this materialization.
343 UnrealizedConversionCastOp getOp()
const {
return op; }
// Return the type converter, stored in the pointer part of `converterAndKind`.
346 TypeConverter *getConverter()
const {
return converterAndKind.getPointer(); }
// Return the materialization kind, stored in the single integer bit of
// `converterAndKind`.
349 Kind getKind()
const {
return converterAndKind.getInt(); }
// Update the materialization kind in place; the converter pointer packed in
// `converterAndKind` is left unchanged.
352 void setKind(
Kind kind) { converterAndKind.setInt(kind); }
// Return the original output type recorded when this materialization was
// created.
355 Type getOrigOutputType()
const {
return origOutputType; }
359 UnrealizedConversionCastOp op;
363 llvm::PointerIntPair<TypeConverter *, 1, Kind> converterAndKind;
373 UnresolvedMaterialization::Kind kind,
Block *insertBlock,
378 if (inputs.size() == 1 && inputs.front().
getType() == outputType)
379 return inputs.front();
383 OpBuilder builder(insertBlock, insertPt);
385 builder.
create<UnrealizedConversionCastOp>(loc, outputType, inputs);
386 unresolvedMaterializations.emplace_back(convertOp, converter, kind,
388 return convertOp.getResult(0);
397 converter, unresolvedMaterializations);
405 insertPt = ++inputRes.getOwner()->getIterator();
408 UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input,
409 outputType, outputType, converter, unresolvedMaterializations);
420 struct ArgConverter {
424 : rewriter(rewriter),
425 unresolvedMaterializations(unresolvedMaterializations) {}
429 struct ConvertedArgInfo {
// Describe how an original block argument was converted:
// - `newArgIdx`: index of its first replacement argument in the new block.
// - `newArgSize`: number of replacement arguments.
// - `castValue`: an optional value (null by default) that the original
//   argument was mapped to.
430 ConvertedArgInfo(
unsigned newArgIdx,
unsigned newArgSize,
431 Value castValue =
nullptr)
432 : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
448 struct ConvertedBlockInfo {
450 : origBlock(origBlock), converter(converter) {}
464 bool hasBeenConverted(
Block *block)
const {
465 return conversionInfo.count(block) || convertedBlocks.count(block);
470 assert(typeConverter &&
"expected valid type converter");
471 regionToConverter[region] = typeConverter;
477 return regionToConverter.lookup(region);
492 void discardRewrites(
Block *block);
495 void applyRewrites(ConversionValueMapping &mapping);
501 materializeLiveConversions(ConversionValueMapping &mapping,
514 ConversionValueMapping &mapping,
523 Block *applySignatureConversion(
526 ConversionValueMapping &mapping,
530 void insertConversion(
Block *newBlock, ConvertedBlockInfo &&info);
534 llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
558 void ArgConverter::notifyOpRemoved(
Operation *op) {
559 if (conversionInfo.empty())
563 for (
Block &block : region) {
566 if (nestedOp.getNumRegions())
567 notifyOpRemoved(&nestedOp);
570 auto it = conversionInfo.find(&block);
571 if (it == conversionInfo.end())
575 Block *origBlock = it->second.origBlock;
578 conversionInfo.erase(it);
583 void ArgConverter::discardRewrites(
Block *block) {
584 auto it = conversionInfo.find(block);
585 if (it == conversionInfo.end())
587 Block *origBlock = it->second.origBlock;
599 convertedBlocks.erase(origBlock);
600 conversionInfo.erase(it);
603 void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
604 for (
auto &info : conversionInfo) {
605 ConvertedBlockInfo &blockInfo = info.second;
606 Block *origBlock = blockInfo.origBlock;
610 std::optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
615 if (
Value newArg = mapping.lookupOrNull(origArg, origArg.
getType()))
621 Value castValue = argInfo->castValue;
622 assert(argInfo->newArgSize >= 1 && castValue &&
"expected 1->1+ mapping");
627 mapping.lookupOrDefault(castValue, origArg.
getType()));
634 ConversionValueMapping &mapping,
OpBuilder &builder,
636 for (
auto &info : conversionInfo) {
637 Block *newBlock = info.first;
638 ConvertedBlockInfo &blockInfo = info.second;
639 Block *origBlock = blockInfo.origBlock;
646 if (mapping.lookupOrNull(origArg, origArg.
getType()))
648 Operation *liveUser = findLiveUser(origArg);
652 Value replacementValue = mapping.lookupOrDefault(origArg);
653 bool isDroppedArg = replacementValue == origArg;
655 rewriter.setInsertionPointToStart(newBlock);
657 rewriter.setInsertionPointAfterValue(replacementValue);
659 if (blockInfo.converter) {
660 newArg = blockInfo.converter->materializeSourceConversion(
664 "materialization hook did not provide a value of the expected "
670 <<
"failed to materialize conversion for block argument #" << i
671 <<
" that remained live after conversion, type was "
674 diag <<
", with target type " << replacementValue.
getType();
676 <<
"see existing live user here: " << *liveUser;
679 mapping.map(origArg, newArg);
693 if (hasBeenConverted(block) || !block->
getParent())
702 return applySignatureConversion(block, converter, *conversion, mapping,
707 Block *ArgConverter::applySignatureConversion(
710 ConversionValueMapping &mapping,
715 if (origArgCount == 0 && convertedTypes.empty())
725 rewriter.getUnknownLoc());
732 ConvertedBlockInfo info(block, converter);
733 info.argInfo.resize(origArgCount);
736 rewriter.setInsertionPointToStart(newBlock);
737 for (
unsigned i = 0; i != origArgCount; ++i) {
745 if (inputMap->replacementValue) {
746 assert(inputMap->size == 0 &&
747 "invalid to provide a replacement value when the argument isn't "
749 mapping.map(origArg, inputMap->replacementValue);
750 argReplacements.push_back(origArg);
755 auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
767 if (replArgs.size() == 1 &&
768 (!converter || replArgs[0].getType() == origArg.
getType())) {
769 newArg = replArgs.front();
774 Type outputType = origOutputType;
776 outputType = legalOutputType;
779 rewriter, origArg.
getLoc(), replArgs, origOutputType, outputType,
780 converter, unresolvedMaterializations);
783 mapping.map(origArg, newArg);
784 argReplacements.push_back(origArg);
786 ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
790 insertConversion(newBlock, std::move(info));
794 void ArgConverter::insertConversion(
Block *newBlock,
795 ConvertedBlockInfo &&info) {
798 std::unique_ptr<Region> &mappedRegion = regionMapping[region];
800 mappedRegion = std::make_unique<Region>(region->
getParentOp());
803 mappedRegion->getBlocks().splice(mappedRegion->end(), region->
getBlocks(),
804 info.origBlock->getIterator());
805 convertedBlocks.insert(info.origBlock);
806 conversionInfo.insert({newBlock, std::move(info)});
816 : argConverter(rewriter, unresolvedMaterializations),
817 notifyCallback(nullptr) {}
821 void discardRewrites();
825 void applyRewrites();
832 RewriterState getCurrentState();
835 void resetState(RewriterState state);
839 void eraseDanglingBlocks();
843 void undoBlockActions(
unsigned numActionsToKeep = 0);
850 std::optional<Location> inputLoc,
860 void markNestedOpsIgnored(
Operation *op);
874 applySignatureConversion(
Region *region,
896 void notifyBlockIsBeingErased(
Block *block);
899 void notifyCreatedBlock(
Block *block);
902 void notifySplitBlock(
Block *block,
Block *continuation);
905 void notifyBlockBeingInlined(
Block *block,
Block *srcBlock,
909 void notifyRegionIsBeingInlinedBefore(
Region ®ion,
Region &parent,
977 llvm::ScopedPrinter logger{llvm::dbgs()};
1006 state.resetOperation();
1020 for (
OpResult result : repl.first->getResults())
1026 if (repl.first->getNumRegions())
1037 arg.replaceAllUsesWith(repl);
1046 arg.replaceUsesWithIf(repl, [&](
OpOperand &operand) {
1056 mat.getOp()->erase();
1065 repl.first->dropAllUses();
1066 repl.first->erase();
1087 for (
unsigned i = state.numRootUpdates, e =
rootUpdates.size(); i != e; ++i)
1101 for (
auto &repl : llvm::drop_begin(
replacements, state.numReplacements))
1102 for (
auto result : repl.first->getResults())
1109 state.numUnresolvedMaterializations) {
1111 UnrealizedConversionCastOp op = mat.getOp();
1114 if (mat.getKind() == UnresolvedMaterialization::Target) {
1115 for (
Value input : op->getOperands())
1122 while (
createdOps.size() != state.numCreatedOps) {
1128 while (
ignoredOps.size() != state.numIgnoredOperations)
1139 if (action.kind == BlockActionKind::Erase)
1140 delete action.block;
1144 unsigned numActionsToKeep) {
1146 llvm::reverse(llvm::drop_begin(
blockActions, numActionsToKeep))) {
1147 switch (action.kind) {
1149 case BlockActionKind::Create: {
1152 auto &blockOps = action.block->getOperations();
1153 while (!blockOps.empty())
1154 blockOps.remove(blockOps.begin());
1155 action.block->dropAllDefinedValueUses();
1156 action.block->erase();
1160 case BlockActionKind::Erase: {
1161 auto &blockList = action.originalPosition.region->getBlocks();
1162 Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
1163 blockList.insert((insertAfterBlock
1165 : blockList.
begin()),
1171 case BlockActionKind::Inline: {
1172 Block *sourceBlock = action.inlineInfo.sourceBlock;
1173 if (action.inlineInfo.firstInlinedInst) {
1174 assert(action.inlineInfo.lastInlinedInst &&
"expected operation");
1176 sourceBlock->
begin(), action.block->getOperations(),
1183 case BlockActionKind::Move: {
1184 Region *originalRegion = action.originalPosition.region;
1185 Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
1188 : originalRegion->
end()),
1189 action.block->getParent()->getBlocks(), action.block);
1193 case BlockActionKind::Split: {
1194 action.originalBlock->getOperations().splice(
1195 action.originalBlock->end(), action.block->getOperations());
1196 action.block->dropAllDefinedValueUses();
1197 action.block->erase();
1201 case BlockActionKind::TypeConversion: {
1211 StringRef valueDiagTag, std::optional<Location> inputLoc,
1214 remapped.reserve(llvm::size(values));
1218 Value operand = it.value();
1230 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1231 << it.index() <<
", type was " << origType;
1236 if (legalTypes.size() == 1)
1237 desiredType = legalTypes.front();
1246 Value newOperand =
mapping.lookupOrDefault(operand, desiredType);
1257 newOperand = castValue;
1259 remapped.push_back(newOperand);
1277 [](
Region ®ion) { return !region.empty(); }))
1295 if (
Block *newBlock = *result) {
1296 if (newBlock != block)
1297 blockActions.push_back(BlockAction::getTypeConversion(newBlock));
1305 if (!region->
empty())
1314 if (region->
empty())
1329 if (region->
empty())
1334 assert((blockConversions.empty() ||
1335 blockConversions.size() == region->
getBlocks().size() - 1) &&
1336 "expected either to provide no SignatureConversions at all or to "
1337 "provide a SignatureConversion for each non-entry block");
1340 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1342 blockConversions.empty()
1345 &blockConversions[blockIdx++]);
1359 assert(!
replacements.count(op) &&
"operation was already replaced");
1362 bool resultChanged =
false;
1365 for (
auto [newValue, result] : llvm::zip(newValues, op->
getResults())) {
1367 resultChanged =
true;
1371 mapping.map(result, newValue);
1372 resultChanged |= (newValue.getType() != result.
getType());
1387 Block *origPrevBlock = block->getPrevNode();
1388 blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
1396 Block *continuation) {
1397 blockActions.push_back(BlockAction::getSplit(continuation, block));
1402 blockActions.push_back(BlockAction::getInline(block, srcBlock, before));
1410 for (
auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) {
1412 BlockAction::getMove(laterBlock, {®ion, &earlierBlock}));
1413 laterBlock = &earlierBlock;
1415 blockActions.push_back(BlockAction::getMove(laterBlock, {®ion,
nullptr}));
1422 reasonCallback(
diag);
1423 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
1444 llvm::unique_function<
bool(
OpOperand &)
const> functor) {
1452 "replaceOpWithIf is currently not supported by DialectConversion");
1457 impl->logger.startLine()
1458 <<
"** Replace : '" << op->
getName() <<
"'(" << op <<
")\n";
1460 impl->notifyOpReplaced(op, newValues);
1465 impl->logger.startLine()
1466 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
1469 impl->notifyOpReplaced(op, nullRepls);
1473 impl->notifyBlockIsBeingErased(block);
1489 return impl->applySignatureConversion(region, conversion, converter);
1495 return impl->convertRegionTypes(region, converter, entryConversion);
1501 return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
1508 impl->logger.startLine() <<
"** Replace Argument : '" << from
1509 <<
"'(in region of '" << parentOp->
getName()
1512 impl->argReplacements.push_back(from);
1513 impl->mapping.map(
impl->mapping.lookupOrDefault(from), to);
1518 if (
failed(
impl->remapValues(
"value", std::nullopt, *
this, key,
1521 return remappedValues.front();
1529 return impl->remapValues(
"value", std::nullopt, *
this, keys,
1534 impl->notifyCreatedBlock(block);
1540 impl->notifySplitBlock(block, continuation);
1541 return continuation;
1548 "incorrect # of argument replacement values");
1550 auto opIgnored = [&](
Operation *op) {
return impl->isOpIgnored(op); };
1554 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
1555 "expected 'source' to have no predecessors");
1557 impl->notifyBlockBeingInlined(dest, source, before);
1558 for (
auto it : llvm::zip(source->
getArguments(), argValues))
1567 impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
1582 impl->notifyCreatedBlock(cloned);
1590 impl->logger.startLine()
1591 <<
"** Insert : '" << op->
getName() <<
"'(" << op <<
")\n";
1593 impl->createdOps.push_back(op);
1598 impl->pendingRootUpdates.insert(op);
1600 impl->rootUpdates.emplace_back(op);
1608 assert(
impl->pendingRootUpdates.erase(op) &&
1609 "operation did not have a pending in-place update");
1615 assert(
impl->pendingRootUpdates.erase(op) &&
1616 "operation did not have a pending in-place update");
1619 auto stateHasOp = [op](
const auto &it) {
return it.getOperation() == op; };
1620 auto &rootUpdates =
impl->rootUpdates;
1621 auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
1622 assert(it != rootUpdates.rend() &&
"no root update started on op");
1623 (*it).resetOperation();
1624 int updateIdx = std::prev(rootUpdates.rend()) - it;
1625 rootUpdates.erase(rootUpdates.begin() + updateIdx);
1630 return impl->notifyMatchFailure(loc, reasonCallback);
1645 auto &rewriterImpl = dialectRewriter.
getImpl();
1648 llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1653 if (
failed(rewriterImpl.remapValues(
"operand", op->
getLoc(), rewriter,
1669 class OperationLegalizer {
1704 RewriterState &curState);
1710 RewriterState &state,
1711 RewriterState &newState);
1714 RewriterState &state, RewriterState &newState);
1717 RewriterState &state,
1718 RewriterState &newState);
1728 void buildLegalizationGraph(
1729 LegalizationPatterns &anyOpLegalizerPatterns,
1740 void computeLegalizationGraphBenefit(
1741 LegalizationPatterns &anyOpLegalizerPatterns,
1746 unsigned computeOpLegalizationDepth(
1753 unsigned applyCostModelToPatterns(
1754 LegalizationPatterns &patterns,
1771 : target(targetInfo), applicator(patterns) {
1775 LegalizationPatterns anyOpLegalizerPatterns;
1777 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1778 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1781 bool OperationLegalizer::isIllegal(
Operation *op)
const {
1782 return target.isIllegal(op);
1786 OperationLegalizer::legalize(
Operation *op,
1789 const char *logLineComment =
1790 "//===-------------------------------------------===//\n";
1795 logger.getOStream() <<
"\n";
1796 logger.startLine() << logLineComment;
1797 logger.startLine() <<
"Legalizing operation : '" << op->
getName() <<
"'("
1803 op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1804 logger.getOStream() <<
"\n\n";
1809 if (
auto legalityInfo = target.isLegal(op)) {
1812 logger,
"operation marked legal by the target{0}",
1813 legalityInfo->isRecursivelyLegal
1814 ?
"; NOTE: operation is recursively legal; skipping internals"
1816 logger.startLine() << logLineComment;
1821 if (legalityInfo->isRecursivelyLegal)
1829 logSuccess(logger,
"operation marked 'ignored' during conversion");
1830 logger.startLine() << logLineComment;
1838 if (
succeeded(legalizeWithFold(op, rewriter))) {
1841 logger.startLine() << logLineComment;
1847 if (
succeeded(legalizeWithPattern(op, rewriter))) {
1850 logger.startLine() << logLineComment;
1856 logFailure(logger,
"no matched legalization pattern");
1857 logger.startLine() << logLineComment;
1863 OperationLegalizer::legalizeWithFold(
Operation *op,
1865 auto &rewriterImpl = rewriter.
getImpl();
1869 rewriterImpl.logger.startLine() <<
"* Fold {\n";
1870 rewriterImpl.logger.indent();
1877 LLVM_DEBUG(
logFailure(rewriterImpl.logger,
"unable to fold"));
1882 rewriter.
replaceOp(op, replacementValues);
1885 for (
unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
1887 Operation *cstOp = rewriterImpl.createdOps[i];
1888 if (
failed(legalize(cstOp, rewriter))) {
1890 "failed to legalize generated constant '{0}'",
1892 rewriterImpl.resetState(curState);
1897 LLVM_DEBUG(
logSuccess(rewriterImpl.logger,
""));
1902 OperationLegalizer::legalizeWithPattern(
Operation *op,
1904 auto &rewriterImpl = rewriter.
getImpl();
1907 auto canApply = [&](
const Pattern &pattern) {
1908 return canApplyPattern(op, pattern, rewriter);
1912 RewriterState curState = rewriterImpl.getCurrentState();
1913 auto onFailure = [&](
const Pattern &pattern) {
1915 logFailure(rewriterImpl.logger,
"pattern failed to match");
1916 if (rewriterImpl.notifyCallback) {
1918 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
1921 rewriterImpl.notifyCallback(
diag);
1924 rewriterImpl.resetState(curState);
1925 appliedPatterns.erase(&pattern);
1930 auto onSuccess = [&](
const Pattern &pattern) {
1931 auto result = legalizePatternResult(op, pattern, rewriter, curState);
1932 appliedPatterns.erase(&pattern);
1934 rewriterImpl.resetState(curState);
1939 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
1943 bool OperationLegalizer::canApplyPattern(
Operation *op,
const Pattern &pattern,
1946 auto &os = rewriter.
getImpl().logger;
1947 os.getOStream() <<
"\n";
1948 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
1950 os.getOStream() <<
")' {\n";
1957 !appliedPatterns.insert(&pattern).second) {
1966 OperationLegalizer::legalizePatternResult(
Operation *op,
const Pattern &pattern,
1968 RewriterState &curState) {
1972 assert(
impl.pendingRootUpdates.empty() &&
"dangling root updates");
1976 auto replacedRoot = [&] {
1977 return llvm::any_of(
1978 llvm::drop_begin(
impl.replacements, curState.numReplacements),
1979 [op](
auto &it) { return it.first == op; });
1981 auto updatedRootInPlace = [&] {
1982 return llvm::any_of(
1983 llvm::drop_begin(
impl.rootUpdates, curState.numRootUpdates),
1984 [op](
auto &state) { return state.getOperation() == op; });
1987 (void)updatedRootInPlace;
1988 assert((replacedRoot() || updatedRootInPlace()) &&
1989 "expected pattern to replace the root operation");
1992 RewriterState newState =
impl.getCurrentState();
1993 if (
failed(legalizePatternBlockActions(op, rewriter,
impl, curState,
1995 failed(legalizePatternRootUpdates(rewriter,
impl, curState, newState)) ||
1996 failed(legalizePatternCreatedOperations(rewriter,
impl, curState,
2001 LLVM_DEBUG(
logSuccess(
impl.logger,
"pattern applied successfully"));
2005 LogicalResult OperationLegalizer::legalizePatternBlockActions(
2008 RewriterState &newState) {
2013 for (
int i = state.numBlockActions, e = newState.numBlockActions; i != e;
2015 auto &action =
impl.blockActions[i];
2016 if (action.kind == BlockActionKind::TypeConversion ||
2017 action.kind == BlockActionKind::Erase)
2021 if (!parentOp || parentOp == op || action.block->getNumArguments() == 0)
2026 if (
auto *converter =
2027 impl.argConverter.getConverter(action.block->getParent())) {
2028 if (
failed(
impl.convertBlockSignature(action.block, converter))) {
2029 LLVM_DEBUG(
logFailure(
impl.logger,
"failed to convert types of moved "
2040 if (operationsToIgnore.empty()) {
2042 .drop_front(state.numCreatedOps);
2043 operationsToIgnore.insert(createdOps.begin(), createdOps.end());
2047 if (operationsToIgnore.insert(parentOp).second &&
2048 failed(legalize(parentOp, rewriter))) {
2050 impl.logger,
"operation '{0}'({1}) became illegal after block action",
2051 parentOp->
getName(), parentOp));
2058 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2060 RewriterState &state, RewriterState &newState) {
2061 for (
int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
2063 if (
failed(legalize(op, rewriter))) {
2065 "failed to legalize generated operation '{0}'({1})",
2073 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2075 RewriterState &state, RewriterState &newState) {
2076 for (
int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
2078 if (
failed(legalize(op, rewriter))) {
2080 impl.logger,
"failed to legalize operation updated in-place '{0}'",
2091 void OperationLegalizer::buildLegalizationGraph(
2092 LegalizationPatterns &anyOpLegalizerPatterns,
2103 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2104 std::optional<OperationName> root = pattern.
getRootKind();
2110 anyOpLegalizerPatterns.push_back(&pattern);
2115 if (target.getOpAction(*root) == LegalizationAction::Legal)
2120 invalidPatterns[*root].insert(&pattern);
2122 parentOps[op].insert(*root);
2125 patternWorklist.insert(&pattern);
2133 if (!anyOpLegalizerPatterns.empty()) {
2134 for (
const Pattern *pattern : patternWorklist)
2135 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2139 while (!patternWorklist.empty()) {
2140 auto *pattern = patternWorklist.pop_back_val();
2144 std::optional<LegalizationAction> action = target.getOpAction(op);
2145 return !legalizerPatterns.count(op) &&
2146 (!action || action == LegalizationAction::Illegal);
2152 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2153 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
2157 for (
auto op : parentOps[*pattern->
getRootKind()])
2158 patternWorklist.set_union(invalidPatterns[op]);
2162 void OperationLegalizer::computeLegalizationGraphBenefit(
2163 LegalizationPatterns &anyOpLegalizerPatterns,
2169 for (
auto &opIt : legalizerPatterns)
2170 if (!minOpPatternDepth.count(opIt.first))
2171 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2177 if (!anyOpLegalizerPatterns.empty())
2178 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2184 applicator.applyCostModel([&](
const Pattern &pattern) {
2186 if (std::optional<OperationName> rootName = pattern.
getRootKind())
2187 orderedPatternList = legalizerPatterns[*rootName];
2189 orderedPatternList = anyOpLegalizerPatterns;
2192 auto *it = llvm::find(orderedPatternList, &pattern);
2193 if (it == orderedPatternList.end())
2197 return PatternBenefit(std::distance(it, orderedPatternList.end()));
2201 unsigned OperationLegalizer::computeOpLegalizationDepth(
2205 auto depthIt = minOpPatternDepth.find(op);
2206 if (depthIt != minOpPatternDepth.end())
2207 return depthIt->second;
2211 auto opPatternsIt = legalizerPatterns.find(op);
2212 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2221 unsigned minDepth = applyCostModelToPatterns(
2222 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2223 minOpPatternDepth[op] = minDepth;
2227 unsigned OperationLegalizer::applyCostModelToPatterns(
2228 LegalizationPatterns &patterns,
2235 patternsByDepth.reserve(patterns.size());
2236 for (
const Pattern *pattern : patterns) {
2239 unsigned generatedOpDepth = computeOpLegalizationDepth(
2240 generatedOp, minOpPatternDepth, legalizerPatterns);
2241 depth =
std::max(depth, generatedOpDepth + 1);
2243 patternsByDepth.emplace_back(pattern, depth);
2246 minDepth =
std::min(minDepth, depth);
2251 if (patternsByDepth.size() == 1)
2255 std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2256 [](
const std::pair<const Pattern *, unsigned> &lhs,
2257 const std::pair<const Pattern *, unsigned> &rhs) {
2260 if (lhs.second != rhs.second)
2261 return lhs.second < rhs.second;
2264 auto lhsBenefit = lhs.first->getBenefit();
2265 auto rhsBenefit = rhs.first->getBenefit();
2266 return lhsBenefit > rhsBenefit;
2271 for (
auto &patternIt : patternsByDepth)
2272 patterns.push_back(patternIt.first);
2280 enum OpConversionMode {
2297 struct OperationConverter {
2300 OpConversionMode mode,
2302 : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
2342 OperationLegalizer opLegalizer;
2345 OpConversionMode mode;
2358 if (
failed(opLegalizer.legalize(op, rewriter))) {
2361 if (mode == OpConversionMode::Full)
2363 <<
"failed to legalize operation '" << op->
getName() <<
"'";
2367 if (mode == OpConversionMode::Partial) {
2368 if (opLegalizer.isIllegal(op))
2370 <<
"failed to legalize operation '" << op->
getName()
2371 <<
"' that was explicitly marked illegal";
2373 trackedOps->insert(op);
2375 }
else if (mode == OpConversionMode::Analysis) {
2379 trackedOps->insert(op);
2393 for (
auto *op : ops) {
2396 toConvert.push_back(op);
2399 auto legalityInfo = target.
isLegal(op);
2400 if (legalityInfo && legalityInfo->isRecursivelyLegal)
2411 for (
auto *op : toConvert)
2412 if (
failed(convert(rewriter, op)))
2418 if (
failed(finalize(rewriter)))
2423 if (mode == OpConversionMode::Analysis) {
2433 trackedOps->erase(repl.first);
2440 std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
2442 if (
failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2444 failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2454 auto &repl = *(rewriterImpl.
replacements.begin() + replIdx);
2455 for (
OpResult result : repl.first->getResults()) {
2456 Value newValue = rewriterImpl.
mapping.lookupOrNull(result);
2461 if (
failed(legalizeErasedResult(repl.first, result, rewriterImpl)))
2471 if (!inverseMapping)
2472 inverseMapping = rewriterImpl.
mapping.getInverse();
2476 if (
failed(legalizeChangedResultType(repl.first, result, newValue,
2477 repl.second.converter, rewriter,
2478 rewriterImpl, *inverseMapping)))
2489 LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2494 auto findLiveUser = [&](
Value val) {
2495 auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](
Operation *user) {
2496 return rewriterImpl.isOpIgnored(user);
2498 return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2500 return rewriterImpl.
argConverter.materializeLiveConversions(
2501 rewriterImpl.
mapping, rewriter, findLiveUser);
2513 for (
auto [matResult, newValue] : llvm::zip(matResults, values)) {
2514 auto inverseMapIt = inverseMapping.find(matResult);
2515 if (inverseMapIt == inverseMapping.end())
2524 for (
Value inverseMapVal : inverseMapIt->second)
2525 if (!rewriterImpl.
mapping.tryMap(inverseMapVal, newValue))
2526 rewriterImpl.
mapping.erase(inverseMapVal);
2538 auto isLive = [&](
Value value) {
2540 auto matIt = materializationOps.find(user);
2541 if (matIt != materializationOps.end())
2542 return !necessaryMaterializations.count(matIt->second);
2546 for (
Value inv : inverseMapping.lookup(value))
2547 if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
2557 Value remappedValue = rewriterImpl.
mapping.lookupOrDefault(value, type);
2558 if (remappedValue.
getType() == type && remappedValue != invalidRoot)
2559 return remappedValue;
2564 auto inputCastOp = value.
getDefiningOp<UnrealizedConversionCastOp>();
2565 if (inputCastOp && inputCastOp->getNumOperands() == 1)
2566 return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
2574 materializationOps.try_emplace(mat.getOp(), &mat);
2575 worklist.insert(&mat);
2577 while (!worklist.empty()) {
2578 UnresolvedMaterialization *mat = worklist.pop_back_val();
2579 UnrealizedConversionCastOp op = mat->getOp();
2582 assert(op->getNumResults() == 1 &&
"unexpected materialization type");
2583 OpResult opResult = op->getOpResult(0);
2590 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
2593 if (castOp->getResultTypes() == inputOperands.
getTypes()) {
2596 necessaryMaterializations.remove(materializationOps.lookup(user));
2602 if (inputOperands.size() == 1) {
2605 Value remappedValue =
2606 lookupRemappedValue(opResult, inputOperands[0], outputType);
2607 if (remappedValue && remappedValue != opResult) {
2610 necessaryMaterializations.remove(mat);
2619 if (llvm::any_of(op->getOperands(), isBlockArg) ||
2620 llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) {
2630 bool isMaterializationLive = isLive(opResult);
2632 isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive);
2633 if (!isMaterializationLive)
2635 if (!necessaryMaterializations.insert(mat))
2639 for (
Value input : inputOperands) {
2640 if (
auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
2641 if (
auto *mat = materializationOps.lookup(parentOp))
2642 worklist.insert(mat);
2651 UnresolvedMaterialization &mat,
2656 auto findLiveUser = [&](
auto &&users) {
2657 auto liveUserIt = llvm::find_if_not(
2659 return liveUserIt == users.end() ? nullptr : *liveUserIt;
2666 Value remappedValue = rewriterImpl.
mapping.lookupOrDefault(value, type);
2667 if (remappedValue.
getType() == type)
2668 return remappedValue;
2672 UnrealizedConversionCastOp op = mat.getOp();
2677 OpResult opResult = op->getOpResult(0);
2683 for (
Value value : op->getOperands()) {
2684 auto valueCast = value.
getDefiningOp<UnrealizedConversionCastOp>();
2688 auto matIt = materializationOps.find(valueCast);
2689 if (matIt != materializationOps.end())
2691 *matIt->second, materializationOps, rewriter, rewriterImpl,
2699 if (inputOperands.size() == 1) {
2702 Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
2703 if (remappedValue && remappedValue != opResult) {
2716 if (inputOperands.size() == 1)
2721 Value newMaterialization;
2722 switch (mat.getKind()) {
2733 rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
2734 if (newMaterialization)
2740 case UnresolvedMaterialization::Target:
2742 rewriter, op->getLoc(), outputType, inputOperands);
2745 if (newMaterialization) {
2753 <<
"failed to legalize unresolved materialization "
2755 << inputOperands.
getTypes() <<
" to " << outputType
2756 <<
" that remained live after conversion";
2757 if (
Operation *liveUser = findLiveUser(op->getUsers())) {
2759 <<
"see existing live user here: " << *liveUser;
2764 LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
2770 inverseMapping = rewriterImpl.
mapping.getInverse();
2777 *inverseMapping, necessaryMaterializations);
2780 for (
auto *mat : necessaryMaterializations) {
2782 *mat, materializationOps, rewriter, rewriterImpl, *inverseMapping)))
2794 return rewriterImpl.isOpIgnored(user);
2796 if (liveUserIt != result.
user_end()) {
2798 << op->
getName() <<
"' marked as erased";
2799 diag.attachNote(liveUserIt->getLoc())
2813 while (!worklist.empty()) {
2814 Value value = worklist.pop_back_val();
2819 return rewriterImpl.isOpIgnored(user);
2821 if (liveUserIt != value.
user_end())
2823 auto mapIt = inverseMapping.find(value);
2824 if (mapIt != inverseMapping.end())
2825 worklist.append(mapIt->second);
2830 LogicalResult OperationConverter::legalizeChangedResultType(
2841 auto emitConversionError = [&] {
2843 <<
"failed to materialize conversion for result #"
2846 <<
"' that remained live after conversion";
2848 <<
"see existing live user here: " << *liveUser;
2855 return emitConversionError();
2860 rewriter, op->
getLoc(), resultType, newValue);
2861 if (!convertedValue)
2862 return emitConversionError();
2864 rewriterImpl.
mapping.map(result, convertedValue);
2874 assert(!types.empty() &&
"expected valid types");
2875 remapInput(origInputNo, argTypes.size(), types.size());
2880 assert(!types.empty() &&
2881 "1->0 type remappings don't need to be added explicitly");
2882 argTypes.append(types.begin(), types.end());
2886 unsigned newInputNo,
2887 unsigned newInputCount) {
2888 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
2889 assert(newInputCount != 0 &&
"expected valid input count");
2890 remappedInputs[origInputNo] =
2891 InputMapping{newInputNo, newInputCount,
nullptr};
2895 Value replacementValue) {
2896 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
2897 remappedInputs[origInputNo] =
2903 auto existingIt = cachedDirectConversions.find(t);
2904 if (existingIt != cachedDirectConversions.end()) {
2905 if (existingIt->second)
2906 results.push_back(existingIt->second);
2907 return success(existingIt->second !=
nullptr);
2909 auto multiIt = cachedMultiConversions.find(t);
2910 if (multiIt != cachedMultiConversions.end()) {
2911 results.append(multiIt->second.begin(), multiIt->second.end());
2917 size_t currentCount = results.size();
2918 conversionCallStack.push_back(t);
2919 auto popConversionCallStack =
2920 llvm::make_scope_exit([
this]() { conversionCallStack.pop_back(); });
2921 for (ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2922 if (std::optional<LogicalResult> result =
2923 converter(t, results, conversionCallStack)) {
2925 cachedDirectConversions.try_emplace(t,
nullptr);
2928 auto newTypes =
ArrayRef<Type>(results).drop_front(currentCount);
2929 if (newTypes.size() == 1)
2930 cachedDirectConversions.try_emplace(t, newTypes.front());
2932 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2946 return results.size() == 1 ? results.front() :
nullptr;
2951 for (
Type type : types)
2963 return llvm::all_of(*region, [
this](
Block &block) {
2969 return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2980 if (convertedTypes.empty())
2984 result.
addInputs(inputNo, convertedTypes);
2989 unsigned origInputOffset) {
2990 for (
unsigned i = 0, e = types.size(); i != e; ++i)
2996 Value TypeConverter::materializeConversion(
2999 for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
3000 if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
3006 -> std::optional<SignatureConversion> {
3009 return std::nullopt;
3032 return impl.getInt() == resultTag;
3036 return impl.getInt() == naTag;
3040 return impl.getInt() == abortTag;
3044 assert(hasResult() &&
"Cannot get result from N/A or abort");
3045 return impl.getPointer();
3050 for (TypeAttributeConversionCallbackFn &fn :
3051 llvm::reverse(typeAttributeConversions)) {
3056 return std::nullopt;
3058 return std::nullopt;
3068 FunctionType type = funcOp.getFunctionType().cast<FunctionType>();
3076 typeConverter, &result)))
3080 auto newType = FunctionType::get(rewriter.
getContext(),
3081 result.getConvertedTypes(), newResults);
3093 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3101 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3106 struct AnyFunctionOpInterfaceSignatureConversion
3121 patterns.
add<FunctionOpInterfaceSignatureConversion>(
3122 functionLikeOpName, patterns.
getContext(), converter);
3127 patterns.
add<AnyFunctionOpInterfaceSignatureConversion>(
3137 legalOperations[op].action = action;
3142 for (StringRef dialect : dialectNames)
3143 legalDialects[dialect] = action;
3147 -> std::optional<LegalizationAction> {
3148 std::optional<LegalizationInfo> info = getOpInfo(op);
3149 return info ? info->action : std::optional<LegalizationAction>();
3153 -> std::optional<LegalOpDetails> {
3154 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3156 return std::nullopt;
3159 auto isOpLegal = [&] {
3161 if (info->action == LegalizationAction::Dynamic) {
3162 std::optional<bool> result = info->legalityFn(op);
3168 return info->action == LegalizationAction::Legal;
3171 return std::nullopt;
3175 if (info->isRecursivelyLegal) {
3176 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
3177 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3179 legalityFnIt->second(op).value_or(
true);
3184 return legalityDetails;
3188 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3192 if (info->action == LegalizationAction::Dynamic) {
3193 std::optional<bool> result = info->legalityFn(op);
3200 return info->action == LegalizationAction::Illegal;
3209 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3211 if (std::optional<bool> result = newCl(op))
3219 void ConversionTarget::setLegalityCallback(
3220 OperationName name,
const DynamicLegalityCallbackFn &callback) {
3221 assert(callback &&
"expected valid legality callback");
3222 auto infoIt = legalOperations.find(name);
3223 assert(infoIt != legalOperations.end() &&
3224 infoIt->second.action == LegalizationAction::Dynamic &&
3225 "expected operation to already be marked as dynamically legal");
3226 infoIt->second.legalityFn =
3232 auto infoIt = legalOperations.find(name);
3233 assert(infoIt != legalOperations.end() &&
3234 infoIt->second.action != LegalizationAction::Illegal &&
3235 "expected operation to already be marked as legal");
3236 infoIt->second.isRecursivelyLegal =
true;
3239 std::move(opRecursiveLegalityFns[name]), callback);
3241 opRecursiveLegalityFns.erase(name);
3244 void ConversionTarget::setLegalityCallback(
3246 assert(callback &&
"expected valid legality callback");
3247 for (StringRef dialect : dialects)
3249 std::move(dialectLegalityFns[dialect]), callback);
3252 void ConversionTarget::setLegalityCallback(
3253 const DynamicLegalityCallbackFn &callback) {
3254 assert(callback &&
"expected valid legality callback");
3259 -> std::optional<LegalizationInfo> {
3261 auto it = legalOperations.find(op);
3262 if (it != legalOperations.end())
3265 auto dialectIt = legalDialects.find(op.getDialectNamespace());
3266 if (dialectIt != legalDialects.end()) {
3267 DynamicLegalityCallbackFn callback;
3268 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3269 if (dialectFn != dialectLegalityFns.end())
3270 callback = dialectFn->second;
3271 return LegalizationInfo{dialectIt->second,
false,
3275 if (unknownLegalityFn)
3276 return LegalizationInfo{LegalizationAction::Dynamic,
3277 false, unknownLegalityFn};
3278 return std::nullopt;
3286 auto &rewriterImpl =
3292 auto &rewriterImpl =
3304 return std::move(mappedValues);
3315 return results->front();
3325 auto &rewriterImpl =
3338 auto &rewriterImpl =
3347 return std::move(remappedTypes);
3363 OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
3365 return opConverter.convertOperations(ops);
3381 OperationConverter opConverter(target, patterns, OpConversionMode::Full);
3382 return opConverter.convertOperations(ops);
3399 OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
3401 return opConverter.convertOperations(ops, notifyCallback);
3409 convertedOps, notifyCallback);
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void detachNestedAndErase(Operation *op)
Detach any operations nested in the given operation from their parent blocks, and erase the given ope...
static Operation * findLiveUserOfReplaced(Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, const DenseMap< Value, SmallVector< Value >> &inverseMapping)
Finds a user of the given value, or of any other value that the given value replaced,...
static Value buildUnresolvedTargetMaterialization(Location loc, Value input, Type outputType, TypeConverter *converter, SmallVectorImpl< UnresolvedMaterialization > &unresolvedMaterializations)
static void computeNecessaryMaterializations(DenseMap< Operation *, UnresolvedMaterialization * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping, SetVector< UnresolvedMaterialization * > &necessaryMaterializations)
Compute all of the unresolved materializations that will persist beyond the conversion process,...
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static void replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl, ResultRange matResults, ValueRange values, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Replace the results of a materialization operation with the given values.
static Value buildUnresolvedMaterialization(UnresolvedMaterialization::Kind kind, Block *insertBlock, Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType, Type origOutputType, TypeConverter *converter, SmallVectorImpl< UnresolvedMaterialization > &unresolvedMaterializations)
Build an unresolved materialization operation given an output type and set of input operands.
static LogicalResult legalizeUnresolvedMaterialization(UnresolvedMaterialization &mat, DenseMap< Operation *, UnresolvedMaterialization * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Legalize the given unresolved materialization.
static Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter, Location loc, ValueRange inputs, Type origOutputType, Type outputType, TypeConverter *converter, SmallVectorImpl< UnresolvedMaterialization > &unresolvedMaterializations)
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
void erase()
Unlink this Block from its parent region and delete it.
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
RetT walk(FnT &&callback)
Walk the operations in this block.
OpListType & getOperations()
BlockArgListType getArguments()
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specific block.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, TypeConverter *converter=nullptr)
Apply a signature conversion to the entry block of the given region.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult convertNonEntryRegionTypes(Region *region, TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions)
Convert the types of block arguments within the given region except for the entry region.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void finalizeRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor) override
PatternRewriter hook for replacing the results of an operation when the given functor returns true.
void notifyBlockCreated(Block *block) override
PatternRewriter hook creating a new block.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
ConversionPatternRewriter(MLIRContext *ctx)
void notifyOperationInserted(Operation *op) override
PatternRewriter hook for inserting a new operation.
void cancelRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
void startRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
~ConversionPatternRewriter() override
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping) override
PatternRewriter hook for cloning blocks of one region into another.
Base class for the conversion patterns.
TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class provides support for representing a failure result, or a valid value of type T.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
user_range getUsers() const
Returns a range of all users.
void dropAllUses()
Drop all uses of this object from their respective owners.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class represents a diagnostic that is inflight and set to be reported.
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within 'results'.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class implements the operand iterators for the Operation class.
type_range getTypes() const
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
void dropAllUses()
Drop all uses of results of this operation.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn)
Register a rewrite function with PDL.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
This class implements the result iterators for the Operation class.
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
MLIRContext * getContext() const
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
The general result of a type attribute conversion callback, allowing for early termination.
Attribute getResult() const
static AttributeConversionResult abort()
static AttributeConversionResult na()
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
bool isSignatureLegal(FunctionType ty)
Return true if the inputs and outputs of the given function type are legal.
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result)
This method allows for converting a specific argument of a signature.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results)
Convert the given set of types, filling 'results' as necessary.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0)
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr)
Convert an attribute present attr from within the type type using the registered conversion functions...
bool isLegal(Type type)
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block)
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs)
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
void dropAllUses() const
Drop all uses of this object from their respective owners.
Type getType() const
Return the type of this value.
user_iterator user_end() const
Block * getParentBlock()
Return the Block in which this Value is defined.
void replaceAllUsesWith(Value newValue) const
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
Detect if any of the given parameter types has a sub-element handler.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
llvm::PointerUnion< NamedAttribute *, NamedTypeConstraint * > Argument
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter)
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > &convertedOps, function_ref< void(Diagnostic &)> notifyCallback=nullptr)
Apply an analysis conversion on the given operations, and all nested operations.
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
This class represents an efficient way to signal success or failure.
void eraseDanglingBlocks()
Erase any blocks that were unlinked from their regions and stored in block actions.
TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
void undoBlockActions(unsigned numActionsToKeep=0)
Undo the block actions (motions, splits) one by one in reverse order until "numActionsToKeep" actions...
void discardRewrites()
Cleanup and destroy any generated rewrite operations.
ConversionPatternRewriterImpl(PatternRewriter &rewriter)
function_ref< void(Diagnostic &)> notifyCallback
This allows the user to collect the match failure message.
SmallVector< BlockAction, 4 > blockActions
Ordered list of block operations (creations, splits, motions).
llvm::MapVector< Operation *, OpReplacement > replacements
Ordered map of requested operation replacements.
LogicalResult convertNonEntryRegionTypes(Region *region, TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions={})
Convert the types of non-entry block arguments within the given region.
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent, Region::iterator before)
Notifies that the blocks of a region are about to be moved.
void notifySplitBlock(Block *block, Block *continuation)
Notifies that a block was split.
ConversionValueMapping mapping
void applyRewrites()
Apply all requested operation rewrites.
SmallVector< OperationTransactionState, 4 > rootUpdates
A transaction state for each of operations that were updated in-place.
RewriterState getCurrentState()
Return the current state of the rewriter.
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, TypeConverter *converter)
Apply a signature conversion on the given region, using converter for materializations if not null.
void notifyOpReplaced(Operation *op, ValueRange newValues)
PatternRewriter hook for replacing the results of an operation.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
SmallVector< UnresolvedMaterialization > unresolvedMaterializations
Ordered vector of all unresolved type conversion materializations during conversion.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
ArgConverter argConverter
Utility used to convert block arguments.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVectorImpl< Value > &remapped)
Remap the given values to those with potentially different types.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
bool isOpIgnored(Operation *op) const
Returns true if the given operation is ignored, and does not need to be converted.
FailureOr< Block * > convertBlockSignature(Block *block, TypeConverter *converter, TypeConverter::SignatureConversion *conversion=nullptr)
Convert the signature of the given block.
void markNestedOpsIgnored(Operation *op)
Recursively marks the nested operations under 'op' as ignored.
SmallVector< Operation * > createdOps
Ordered vector of all of the newly created operations during conversion.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notifies that a pattern match failed for the given reason.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization, but were not directly repla...
SmallVector< BlockArgument, 4 > argReplacements
Ordered vector of any requested block argument replacements.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
void notifyCreatedBlock(Block *block)
Notifies that a block was created.
SmallVector< unsigned, 4 > operationsWithChangedResults
A vector of indices into replacements of operations that were replaced with values with different res...