17 #include "llvm/ADT/ScopeExit.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/SmallPtrSet.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/FormatVariadic.h"
22 #include "llvm/Support/SaveAndRestore.h"
23 #include "llvm/Support/ScopedPrinter.h"
29 #define DEBUG_TYPE "dialect-conversion"
32 template <
typename... Args>
33 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
36 os.startLine() <<
"} -> SUCCESS";
38 os.getOStream() <<
" : "
39 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
40 os.getOStream() <<
"\n";
45 template <
typename... Args>
46 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
49 os.startLine() <<
"} -> FAILURE : "
50 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
62 struct ConversionValueMapping {
67 Value lookupOrDefault(
Value from,
Type desiredType =
nullptr)
const;
77 for (
Value it = newVal; it; it = mapping.lookupOrNull(it))
78 assert(it != oldVal &&
"inserting cyclic mapping");
80 mapping.map(oldVal, newVal);
89 void erase(
Value value) { mapping.erase(value); }
94 for (
auto &it : mapping.getValueMap())
95 inverse[it.second].push_back(it.first);
105 Value ConversionValueMapping::lookupOrDefault(
Value from,
106 Type desiredType)
const {
111 while (
auto mappedValue = mapping.lookupOrNull(from))
119 if (from.
getType() == desiredType)
122 Value mappedValue = mapping.lookupOrNull(from);
129 return desiredValue ? desiredValue : from;
132 Value ConversionValueMapping::lookupOrNull(
Value from,
Type desiredType)
const {
133 Value result = lookupOrDefault(from, desiredType);
134 if (result == from || (desiredType && result.
getType() != desiredType))
139 bool ConversionValueMapping::tryMap(
Value oldVal,
Value newVal) {
140 for (
Value it = newVal; it; it = mapping.lookupOrNull(it))
153 struct RewriterState {
154 RewriterState(
unsigned numCreatedOps,
unsigned numUnresolvedMaterializations,
155 unsigned numReplacements,
unsigned numArgReplacements,
156 unsigned numBlockActions,
unsigned numIgnoredOperations,
157 unsigned numRootUpdates)
158 : numCreatedOps(numCreatedOps),
159 numUnresolvedMaterializations(numUnresolvedMaterializations),
160 numReplacements(numReplacements),
161 numArgReplacements(numArgReplacements),
162 numBlockActions(numBlockActions),
163 numIgnoredOperations(numIgnoredOperations),
164 numRootUpdates(numRootUpdates) {}
167 unsigned numCreatedOps;
170 unsigned numUnresolvedMaterializations;
173 unsigned numReplacements;
176 unsigned numArgReplacements;
179 unsigned numBlockActions;
182 unsigned numIgnoredOperations;
185 unsigned numRootUpdates;
194 class OperationTransactionState {
196 OperationTransactionState() =
default;
198 : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()),
199 operands(op->operand_begin(), op->operand_end()),
200 successors(op->successor_begin(), op->successor_end()) {}
204 void resetOperation()
const {
213 Operation *getOperation()
const {
return op; }
218 DictionaryAttr attrs;
228 struct OpReplacement {
230 : converter(converter) {}
242 enum class BlockActionKind {
253 struct BlockPosition {
255 Block *insertAfterBlock;
271 static BlockAction getCreate(
Block *block) {
272 return {BlockActionKind::Create, block, {}};
274 static BlockAction getErase(
Block *block, BlockPosition originalPosition) {
275 return {BlockActionKind::Erase, block, {originalPosition}};
277 static BlockAction getInline(
Block *block,
Block *srcBlock,
279 BlockAction action{BlockActionKind::Inline, block, {}};
280 action.inlineInfo = {srcBlock,
281 srcBlock->
empty() ? nullptr : &srcBlock->
front(),
282 srcBlock->
empty() ? nullptr : &srcBlock->
back()};
285 static BlockAction getMove(
Block *block, BlockPosition originalPosition) {
286 return {BlockActionKind::Move, block, {originalPosition}};
288 static BlockAction getSplit(
Block *block,
Block *originalBlock) {
289 BlockAction action{BlockActionKind::Split, block, {}};
290 action.originalBlock = originalBlock;
293 static BlockAction getTypeConversion(
Block *block) {
294 return BlockAction{BlockActionKind::TypeConversion, block, {}};
298 BlockActionKind kind;
307 BlockPosition originalPosition;
310 Block *originalBlock;
313 InlineInfo inlineInfo;
323 class UnresolvedMaterialization {
336 UnresolvedMaterialization(UnrealizedConversionCastOp op =
nullptr,
338 Kind kind = Target,
Type origOutputType =
nullptr)
339 : op(op), converterAndKind(converter, kind),
340 origOutputType(origOutputType) {}
344 UnrealizedConversionCastOp getOp()
const {
return op; }
348 return converterAndKind.getPointer();
352 Kind getKind()
const {
return converterAndKind.getInt(); }
355 void setKind(
Kind kind) { converterAndKind.setInt(kind); }
358 Type getOrigOutputType()
const {
return origOutputType; }
362 UnrealizedConversionCastOp op;
366 llvm::PointerIntPair<const TypeConverter *, 1, Kind> converterAndKind;
376 UnresolvedMaterialization::Kind kind,
Block *insertBlock,
381 if (inputs.size() == 1 && inputs.front().
getType() == outputType)
382 return inputs.front();
386 OpBuilder builder(insertBlock, insertPt);
388 builder.
create<UnrealizedConversionCastOp>(loc, outputType, inputs);
389 unresolvedMaterializations.emplace_back(convertOp, converter, kind,
391 return convertOp.getResult(0);
400 converter, unresolvedMaterializations);
407 if (
OpResult inputRes = dyn_cast<OpResult>(input))
408 insertPt = ++inputRes.getOwner()->getIterator();
411 UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input,
412 outputType, outputType, converter, unresolvedMaterializations);
423 struct ArgConverter {
427 : rewriter(rewriter),
428 unresolvedMaterializations(unresolvedMaterializations) {}
432 struct ConvertedArgInfo {
433 ConvertedArgInfo(
unsigned newArgIdx,
unsigned newArgSize,
434 Value castValue =
nullptr)
435 : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
451 struct ConvertedBlockInfo {
453 : origBlock(origBlock), converter(converter) {}
467 bool hasBeenConverted(
Block *block)
const {
468 return conversionInfo.count(block) || convertedBlocks.count(block);
473 assert(typeConverter &&
"expected valid type converter");
474 regionToConverter[region] = typeConverter;
480 return regionToConverter.lookup(region);
495 void discardRewrites(
Block *block);
498 void applyRewrites(ConversionValueMapping &mapping);
504 materializeLiveConversions(ConversionValueMapping &mapping,
517 ConversionValueMapping &mapping,
526 Block *applySignatureConversion(
529 ConversionValueMapping &mapping,
533 void insertConversion(
Block *newBlock, ConvertedBlockInfo &&info);
537 llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
561 void ArgConverter::notifyOpRemoved(
Operation *op) {
562 if (conversionInfo.empty())
566 for (
Block &block : region) {
569 if (nestedOp.getNumRegions())
570 notifyOpRemoved(&nestedOp);
573 auto it = conversionInfo.find(&block);
574 if (it == conversionInfo.end())
578 Block *origBlock = it->second.origBlock;
581 conversionInfo.erase(it);
586 void ArgConverter::discardRewrites(
Block *block) {
587 auto it = conversionInfo.find(block);
588 if (it == conversionInfo.end())
590 Block *origBlock = it->second.origBlock;
602 convertedBlocks.erase(origBlock);
603 conversionInfo.erase(it);
606 void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
607 for (
auto &info : conversionInfo) {
608 ConvertedBlockInfo &blockInfo = info.second;
609 Block *origBlock = blockInfo.origBlock;
613 std::optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
618 if (
Value newArg = mapping.lookupOrNull(origArg, origArg.
getType()))
624 Value castValue = argInfo->castValue;
625 assert(argInfo->newArgSize >= 1 && castValue &&
"expected 1->1+ mapping");
630 mapping.lookupOrDefault(castValue, origArg.
getType()));
637 ConversionValueMapping &mapping,
OpBuilder &builder,
639 for (
auto &info : conversionInfo) {
640 Block *newBlock = info.first;
641 ConvertedBlockInfo &blockInfo = info.second;
642 Block *origBlock = blockInfo.origBlock;
649 if (mapping.lookupOrNull(origArg, origArg.
getType()))
651 Operation *liveUser = findLiveUser(origArg);
655 Value replacementValue = mapping.lookupOrDefault(origArg);
656 bool isDroppedArg = replacementValue == origArg;
658 rewriter.setInsertionPointToStart(newBlock);
660 rewriter.setInsertionPointAfterValue(replacementValue);
662 if (blockInfo.converter) {
663 newArg = blockInfo.converter->materializeSourceConversion(
667 "materialization hook did not provide a value of the expected "
673 <<
"failed to materialize conversion for block argument #" << i
674 <<
" that remained live after conversion, type was "
677 diag <<
", with target type " << replacementValue.
getType();
679 <<
"see existing live user here: " << *liveUser;
682 mapping.map(origArg, newArg);
693 ConversionValueMapping &mapping,
697 if (hasBeenConverted(block) || !block->
getParent())
706 return applySignatureConversion(block, converter, *conversion, mapping,
711 Block *ArgConverter::applySignatureConversion(
714 ConversionValueMapping &mapping,
719 if (origArgCount == 0 && convertedTypes.empty())
729 rewriter.getUnknownLoc());
730 for (
unsigned i = 0; i < origArgCount; ++i) {
732 if (!inputMap || inputMap->replacementValue)
735 for (
unsigned j = 0;
j < inputMap->size; ++
j)
736 newLocs[inputMap->inputNo +
j] = origLoc;
745 ConvertedBlockInfo info(block, converter);
746 info.argInfo.resize(origArgCount);
749 rewriter.setInsertionPointToStart(newBlock);
750 for (
unsigned i = 0; i != origArgCount; ++i) {
758 if (inputMap->replacementValue) {
759 assert(inputMap->size == 0 &&
760 "invalid to provide a replacement value when the argument isn't "
762 mapping.map(origArg, inputMap->replacementValue);
763 argReplacements.push_back(origArg);
768 auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
780 if (replArgs.size() == 1 &&
781 (!converter || replArgs[0].getType() == origArg.
getType())) {
782 newArg = replArgs.front();
787 Type outputType = origOutputType;
789 outputType = legalOutputType;
792 rewriter, origArg.
getLoc(), replArgs, origOutputType, outputType,
793 converter, unresolvedMaterializations);
796 mapping.map(origArg, newArg);
797 argReplacements.push_back(origArg);
799 ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
803 insertConversion(newBlock, std::move(info));
807 void ArgConverter::insertConversion(
Block *newBlock,
808 ConvertedBlockInfo &&info) {
811 std::unique_ptr<Region> &mappedRegion = regionMapping[region];
813 mappedRegion = std::make_unique<Region>(region->
getParentOp());
816 mappedRegion->getBlocks().splice(mappedRegion->end(), region->
getBlocks(),
817 info.origBlock->getIterator());
818 convertedBlocks.insert(info.origBlock);
819 conversionInfo.insert({newBlock, std::move(info)});
829 : argConverter(rewriter, unresolvedMaterializations),
830 notifyCallback(nullptr) {}
834 void discardRewrites();
838 void applyRewrites();
845 RewriterState getCurrentState();
848 void resetState(RewriterState state);
852 void eraseDanglingBlocks();
856 void undoBlockActions(
unsigned numActionsToKeep = 0);
863 std::optional<Location> inputLoc,
873 void markNestedOpsIgnored(
Operation *op);
887 applySignatureConversion(
Region *region,
909 void notifyBlockIsBeingErased(
Block *block);
912 void notifyCreatedBlock(
Block *block);
915 void notifySplitBlock(
Block *block,
Block *continuation);
918 void notifyBlockBeingInlined(
Block *block,
Block *srcBlock,
922 void notifyRegionIsBeingInlinedBefore(
Region ®ion,
Region &parent,
990 llvm::ScopedPrinter logger{llvm::dbgs()};
1019 state.resetOperation();
1033 for (
OpResult result : repl.first->getResults())
1039 if (repl.first->getNumRegions())
1049 if (isa<BlockArgument>(repl)) {
1050 arg.replaceAllUsesWith(repl);
1057 Operation *replOp = cast<OpResult>(repl).getOwner();
1059 arg.replaceUsesWithIf(repl, [&](
OpOperand &operand) {
1069 mat.getOp()->erase();
1078 repl.first->dropAllUses();
1079 repl.first->erase();
1100 for (
unsigned i = state.numRootUpdates, e =
rootUpdates.size(); i != e; ++i)
1114 for (
auto &repl : llvm::drop_begin(
replacements, state.numReplacements))
1115 for (
auto result : repl.first->getResults())
1122 state.numUnresolvedMaterializations) {
1124 UnrealizedConversionCastOp op = mat.getOp();
1127 if (mat.getKind() == UnresolvedMaterialization::Target) {
1135 while (
createdOps.size() != state.numCreatedOps) {
1141 while (
ignoredOps.size() != state.numIgnoredOperations)
1152 if (action.kind == BlockActionKind::Erase)
1153 delete action.block;
1157 unsigned numActionsToKeep) {
1159 llvm::reverse(llvm::drop_begin(
blockActions, numActionsToKeep))) {
1160 switch (action.kind) {
1162 case BlockActionKind::Create: {
1165 auto &blockOps = action.block->getOperations();
1166 while (!blockOps.empty())
1167 blockOps.remove(blockOps.begin());
1168 action.block->dropAllDefinedValueUses();
1169 action.block->erase();
1173 case BlockActionKind::Erase: {
1174 auto &blockList = action.originalPosition.region->getBlocks();
1175 Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
1176 blockList.insert((insertAfterBlock
1178 : blockList.
begin()),
1184 case BlockActionKind::Inline: {
1185 Block *sourceBlock = action.inlineInfo.sourceBlock;
1186 if (action.inlineInfo.firstInlinedInst) {
1187 assert(action.inlineInfo.lastInlinedInst &&
"expected operation");
1189 sourceBlock->
begin(), action.block->getOperations(),
1196 case BlockActionKind::Move: {
1197 Region *originalRegion = action.originalPosition.region;
1198 Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
1201 : originalRegion->
end()),
1202 action.block->getParent()->getBlocks(), action.block);
1206 case BlockActionKind::Split: {
1207 action.originalBlock->getOperations().splice(
1208 action.originalBlock->end(), action.block->getOperations());
1209 action.block->dropAllDefinedValueUses();
1210 action.block->erase();
1214 case BlockActionKind::TypeConversion: {
1224 StringRef valueDiagTag, std::optional<Location> inputLoc,
1227 remapped.reserve(llvm::size(values));
1231 Value operand = it.value();
1243 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1244 << it.index() <<
", type was " << origType;
1249 if (legalTypes.size() == 1)
1250 desiredType = legalTypes.front();
1259 Value newOperand =
mapping.lookupOrDefault(operand, desiredType);
1270 newOperand = castValue;
1272 remapped.push_back(newOperand);
1290 [](
Region ®ion) { return !region.empty(); }))
1308 if (
Block *newBlock = *result) {
1309 if (newBlock != block)
1310 blockActions.push_back(BlockAction::getTypeConversion(newBlock));
1318 if (!region->
empty())
1327 if (region->
empty())
1342 if (region->
empty())
1347 assert((blockConversions.empty() ||
1348 blockConversions.size() == region->
getBlocks().size() - 1) &&
1349 "expected either to provide no SignatureConversions at all or to "
1350 "provide a SignatureConversion for each non-entry block");
1353 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1355 blockConversions.empty()
1358 &blockConversions[blockIdx++]);
1372 assert(!
replacements.count(op) &&
"operation was already replaced");
1375 bool resultChanged =
false;
1378 for (
auto [newValue, result] : llvm::zip(newValues, op->
getResults())) {
1380 resultChanged =
true;
1384 mapping.map(result, newValue);
1385 resultChanged |= (newValue.getType() != result.
getType());
1400 Block *origPrevBlock = block->getPrevNode();
1401 blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
1409 Block *continuation) {
1410 blockActions.push_back(BlockAction::getSplit(continuation, block));
1415 blockActions.push_back(BlockAction::getInline(block, srcBlock, before));
1423 for (
auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) {
1425 BlockAction::getMove(laterBlock, {®ion, &earlierBlock}));
1426 laterBlock = &earlierBlock;
1428 blockActions.push_back(BlockAction::getMove(laterBlock, {®ion,
nullptr}));
1435 reasonCallback(
diag);
1436 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
1457 llvm::unique_function<
bool(
OpOperand &)
const> functor) {
1465 "replaceOpWithIf is currently not supported by DialectConversion");
1469 assert(op && newOp &&
"expected non-null op");
1475 "incorrect # of replacement values");
1477 impl->logger.startLine()
1478 <<
"** Replace : '" << op->
getName() <<
"'(" << op <<
")\n";
1480 impl->notifyOpReplaced(op, newValues);
1485 impl->logger.startLine()
1486 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
1489 impl->notifyOpReplaced(op, nullRepls);
1493 impl->notifyBlockIsBeingErased(block);
1509 return impl->applySignatureConversion(region, conversion, converter);
1515 return impl->convertRegionTypes(region, converter, entryConversion);
1521 return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
1528 impl->logger.startLine() <<
"** Replace Argument : '" << from
1529 <<
"'(in region of '" << parentOp->
getName()
1532 impl->argReplacements.push_back(from);
1533 impl->mapping.map(
impl->mapping.lookupOrDefault(from), to);
1538 if (
failed(
impl->remapValues(
"value", std::nullopt, *
this, key,
1541 return remappedValues.front();
1549 return impl->remapValues(
"value", std::nullopt, *
this, keys,
1554 impl->notifyCreatedBlock(block);
1560 impl->notifySplitBlock(block, continuation);
1561 return continuation;
1568 "incorrect # of argument replacement values");
1570 auto opIgnored = [&](
Operation *op) {
return impl->isOpIgnored(op); };
1574 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
1575 "expected 'source' to have no predecessors");
1577 impl->notifyBlockBeingInlined(dest, source, before);
1578 for (
auto it : llvm::zip(source->
getArguments(), argValues))
1587 impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
1602 impl->notifyCreatedBlock(cloned);
1610 impl->logger.startLine()
1611 <<
"** Insert : '" << op->
getName() <<
"'(" << op <<
")\n";
1613 impl->createdOps.push_back(op);
1618 impl->pendingRootUpdates.insert(op);
1620 impl->rootUpdates.emplace_back(op);
1628 assert(
impl->pendingRootUpdates.erase(op) &&
1629 "operation did not have a pending in-place update");
1635 assert(
impl->pendingRootUpdates.erase(op) &&
1636 "operation did not have a pending in-place update");
1639 auto stateHasOp = [op](
const auto &it) {
return it.getOperation() == op; };
1640 auto &rootUpdates =
impl->rootUpdates;
1641 auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
1642 assert(it != rootUpdates.rend() &&
"no root update started on op");
1643 (*it).resetOperation();
1644 int updateIdx = std::prev(rootUpdates.rend()) - it;
1645 rootUpdates.erase(rootUpdates.begin() + updateIdx);
1650 return impl->notifyMatchFailure(loc, reasonCallback);
1665 auto &rewriterImpl = dialectRewriter.
getImpl();
1668 llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1673 if (
failed(rewriterImpl.remapValues(
"operand", op->
getLoc(), rewriter,
1689 class OperationLegalizer {
1724 RewriterState &curState);
1730 RewriterState &state,
1731 RewriterState &newState);
1734 RewriterState &state, RewriterState &newState);
1737 RewriterState &state,
1738 RewriterState &newState);
1748 void buildLegalizationGraph(
1749 LegalizationPatterns &anyOpLegalizerPatterns,
1760 void computeLegalizationGraphBenefit(
1761 LegalizationPatterns &anyOpLegalizerPatterns,
1766 unsigned computeOpLegalizationDepth(
1773 unsigned applyCostModelToPatterns(
1774 LegalizationPatterns &patterns,
1791 : target(targetInfo), applicator(patterns) {
1795 LegalizationPatterns anyOpLegalizerPatterns;
1797 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1798 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1801 bool OperationLegalizer::isIllegal(
Operation *op)
const {
1802 return target.isIllegal(op);
1806 OperationLegalizer::legalize(
Operation *op,
1809 const char *logLineComment =
1810 "//===-------------------------------------------===//\n";
1815 logger.getOStream() <<
"\n";
1816 logger.startLine() << logLineComment;
1817 logger.startLine() <<
"Legalizing operation : '" << op->
getName() <<
"'("
1823 op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1824 logger.getOStream() <<
"\n\n";
1829 if (
auto legalityInfo = target.isLegal(op)) {
1832 logger,
"operation marked legal by the target{0}",
1833 legalityInfo->isRecursivelyLegal
1834 ?
"; NOTE: operation is recursively legal; skipping internals"
1836 logger.startLine() << logLineComment;
1841 if (legalityInfo->isRecursivelyLegal)
1849 logSuccess(logger,
"operation marked 'ignored' during conversion");
1850 logger.startLine() << logLineComment;
1858 if (
succeeded(legalizeWithFold(op, rewriter))) {
1861 logger.startLine() << logLineComment;
1867 if (
succeeded(legalizeWithPattern(op, rewriter))) {
1870 logger.startLine() << logLineComment;
1876 logFailure(logger,
"no matched legalization pattern");
1877 logger.startLine() << logLineComment;
1883 OperationLegalizer::legalizeWithFold(
Operation *op,
1885 auto &rewriterImpl = rewriter.
getImpl();
1889 rewriterImpl.logger.startLine() <<
"* Fold {\n";
1890 rewriterImpl.logger.indent();
1897 LLVM_DEBUG(
logFailure(rewriterImpl.logger,
"unable to fold"));
1902 rewriter.
replaceOp(op, replacementValues);
1905 for (
unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
1907 Operation *cstOp = rewriterImpl.createdOps[i];
1908 if (
failed(legalize(cstOp, rewriter))) {
1910 "failed to legalize generated constant '{0}'",
1912 rewriterImpl.resetState(curState);
1917 LLVM_DEBUG(
logSuccess(rewriterImpl.logger,
""));
1922 OperationLegalizer::legalizeWithPattern(
Operation *op,
1924 auto &rewriterImpl = rewriter.
getImpl();
1927 auto canApply = [&](
const Pattern &pattern) {
1928 return canApplyPattern(op, pattern, rewriter);
1932 RewriterState curState = rewriterImpl.getCurrentState();
1933 auto onFailure = [&](
const Pattern &pattern) {
1935 logFailure(rewriterImpl.logger,
"pattern failed to match");
1936 if (rewriterImpl.notifyCallback) {
1938 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
1941 rewriterImpl.notifyCallback(
diag);
1944 rewriterImpl.resetState(curState);
1945 appliedPatterns.erase(&pattern);
1950 auto onSuccess = [&](
const Pattern &pattern) {
1951 auto result = legalizePatternResult(op, pattern, rewriter, curState);
1952 appliedPatterns.erase(&pattern);
1954 rewriterImpl.resetState(curState);
1959 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
1963 bool OperationLegalizer::canApplyPattern(
Operation *op,
const Pattern &pattern,
1966 auto &os = rewriter.
getImpl().logger;
1967 os.getOStream() <<
"\n";
1968 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
1970 os.getOStream() <<
")' {\n";
1977 !appliedPatterns.insert(&pattern).second) {
1986 OperationLegalizer::legalizePatternResult(
Operation *op,
const Pattern &pattern,
1988 RewriterState &curState) {
1992 assert(
impl.pendingRootUpdates.empty() &&
"dangling root updates");
1996 auto replacedRoot = [&] {
1997 return llvm::any_of(
1998 llvm::drop_begin(
impl.replacements, curState.numReplacements),
1999 [op](
auto &it) { return it.first == op; });
2001 auto updatedRootInPlace = [&] {
2002 return llvm::any_of(
2003 llvm::drop_begin(
impl.rootUpdates, curState.numRootUpdates),
2004 [op](
auto &state) { return state.getOperation() == op; });
2007 (void)updatedRootInPlace;
2008 assert((replacedRoot() || updatedRootInPlace()) &&
2009 "expected pattern to replace the root operation");
2012 RewriterState newState =
impl.getCurrentState();
2013 if (
failed(legalizePatternBlockActions(op, rewriter,
impl, curState,
2015 failed(legalizePatternRootUpdates(rewriter,
impl, curState, newState)) ||
2016 failed(legalizePatternCreatedOperations(rewriter,
impl, curState,
2021 LLVM_DEBUG(
logSuccess(
impl.logger,
"pattern applied successfully"));
2025 LogicalResult OperationLegalizer::legalizePatternBlockActions(
2028 RewriterState &newState) {
2033 for (
int i = state.numBlockActions, e = newState.numBlockActions; i != e;
2035 auto &action =
impl.blockActions[i];
2036 if (action.kind == BlockActionKind::TypeConversion ||
2037 action.kind == BlockActionKind::Erase)
2041 if (!parentOp || parentOp == op || action.block->getNumArguments() == 0)
2046 if (
auto *converter =
2047 impl.argConverter.getConverter(action.block->getParent())) {
2048 if (
failed(
impl.convertBlockSignature(action.block, converter))) {
2049 LLVM_DEBUG(
logFailure(
impl.logger,
"failed to convert types of moved "
2060 if (operationsToIgnore.empty()) {
2062 .drop_front(state.numCreatedOps);
2063 operationsToIgnore.insert(createdOps.begin(), createdOps.end());
2067 if (operationsToIgnore.insert(parentOp).second &&
2068 failed(legalize(parentOp, rewriter))) {
2070 impl.logger,
"operation '{0}'({1}) became illegal after block action",
2071 parentOp->
getName(), parentOp));
2078 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2080 RewriterState &state, RewriterState &newState) {
2081 for (
int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
2083 if (
failed(legalize(op, rewriter))) {
2085 "failed to legalize generated operation '{0}'({1})",
2093 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2095 RewriterState &state, RewriterState &newState) {
2096 for (
int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
2098 if (
failed(legalize(op, rewriter))) {
2100 impl.logger,
"failed to legalize operation updated in-place '{0}'",
2111 void OperationLegalizer::buildLegalizationGraph(
2112 LegalizationPatterns &anyOpLegalizerPatterns,
2123 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2124 std::optional<OperationName> root = pattern.
getRootKind();
2130 anyOpLegalizerPatterns.push_back(&pattern);
2135 if (target.getOpAction(*root) == LegalizationAction::Legal)
2140 invalidPatterns[*root].insert(&pattern);
2142 parentOps[op].insert(*root);
2145 patternWorklist.insert(&pattern);
2153 if (!anyOpLegalizerPatterns.empty()) {
2154 for (
const Pattern *pattern : patternWorklist)
2155 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2159 while (!patternWorklist.empty()) {
2160 auto *pattern = patternWorklist.pop_back_val();
2164 std::optional<LegalizationAction> action = target.getOpAction(op);
2165 return !legalizerPatterns.count(op) &&
2166 (!action || action == LegalizationAction::Illegal);
2172 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2173 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
2177 for (
auto op : parentOps[*pattern->
getRootKind()])
2178 patternWorklist.set_union(invalidPatterns[op]);
2182 void OperationLegalizer::computeLegalizationGraphBenefit(
2183 LegalizationPatterns &anyOpLegalizerPatterns,
2189 for (
auto &opIt : legalizerPatterns)
2190 if (!minOpPatternDepth.count(opIt.first))
2191 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2197 if (!anyOpLegalizerPatterns.empty())
2198 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2204 applicator.applyCostModel([&](
const Pattern &pattern) {
2206 if (std::optional<OperationName> rootName = pattern.
getRootKind())
2207 orderedPatternList = legalizerPatterns[*rootName];
2209 orderedPatternList = anyOpLegalizerPatterns;
2212 auto *it = llvm::find(orderedPatternList, &pattern);
2213 if (it == orderedPatternList.end())
2217 return PatternBenefit(std::distance(it, orderedPatternList.end()));
2221 unsigned OperationLegalizer::computeOpLegalizationDepth(
2225 auto depthIt = minOpPatternDepth.find(op);
2226 if (depthIt != minOpPatternDepth.end())
2227 return depthIt->second;
2231 auto opPatternsIt = legalizerPatterns.find(op);
2232 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2241 unsigned minDepth = applyCostModelToPatterns(
2242 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2243 minOpPatternDepth[op] = minDepth;
2247 unsigned OperationLegalizer::applyCostModelToPatterns(
2248 LegalizationPatterns &patterns,
2255 patternsByDepth.reserve(patterns.size());
2256 for (
const Pattern *pattern : patterns) {
2259 unsigned generatedOpDepth = computeOpLegalizationDepth(
2260 generatedOp, minOpPatternDepth, legalizerPatterns);
2261 depth =
std::max(depth, generatedOpDepth + 1);
2263 patternsByDepth.emplace_back(pattern, depth);
2266 minDepth =
std::min(minDepth, depth);
2271 if (patternsByDepth.size() == 1)
2275 std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2276 [](
const std::pair<const Pattern *, unsigned> &lhs,
2277 const std::pair<const Pattern *, unsigned> &rhs) {
2280 if (lhs.second != rhs.second)
2281 return lhs.second < rhs.second;
2284 auto lhsBenefit = lhs.first->getBenefit();
2285 auto rhsBenefit = rhs.first->getBenefit();
2286 return lhsBenefit > rhsBenefit;
2291 for (
auto &patternIt : patternsByDepth)
2292 patterns.push_back(patternIt.first);
2300 enum OpConversionMode {
2317 struct OperationConverter {
2320 OpConversionMode mode,
2322 : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
2362 OperationLegalizer opLegalizer;
2365 OpConversionMode mode;
2378 if (
failed(opLegalizer.legalize(op, rewriter))) {
2381 if (mode == OpConversionMode::Full)
2383 <<
"failed to legalize operation '" << op->
getName() <<
"'";
2387 if (mode == OpConversionMode::Partial) {
2388 if (opLegalizer.isIllegal(op))
2390 <<
"failed to legalize operation '" << op->
getName()
2391 <<
"' that was explicitly marked illegal";
2393 trackedOps->insert(op);
2395 }
else if (mode == OpConversionMode::Analysis) {
2399 trackedOps->insert(op);
2413 for (
auto *op : ops) {
2416 toConvert.push_back(op);
2419 auto legalityInfo = target.
isLegal(op);
2420 if (legalityInfo && legalityInfo->isRecursivelyLegal)
2431 for (
auto *op : toConvert)
2432 if (
failed(convert(rewriter, op)))
2438 if (
failed(finalize(rewriter)))
2443 if (mode == OpConversionMode::Analysis) {
2453 trackedOps->erase(repl.first);
2460 std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
2462 if (
failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2464 failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2474 auto &repl = *(rewriterImpl.
replacements.begin() + replIdx);
2475 for (
OpResult result : repl.first->getResults()) {
2476 Value newValue = rewriterImpl.
mapping.lookupOrNull(result);
2481 if (
failed(legalizeErasedResult(repl.first, result, rewriterImpl)))
2491 if (!inverseMapping)
2492 inverseMapping = rewriterImpl.
mapping.getInverse();
2496 if (
failed(legalizeChangedResultType(repl.first, result, newValue,
2497 repl.second.converter, rewriter,
2498 rewriterImpl, *inverseMapping)))
2509 LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2514 auto findLiveUser = [&](
Value val) {
2515 auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](
Operation *user) {
2516 return rewriterImpl.isOpIgnored(user);
2518 return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2520 return rewriterImpl.
argConverter.materializeLiveConversions(
2521 rewriterImpl.
mapping, rewriter, findLiveUser);
2533 for (
auto [matResult, newValue] : llvm::zip(matResults, values)) {
2534 auto inverseMapIt = inverseMapping.find(matResult);
2535 if (inverseMapIt == inverseMapping.end())
2544 for (
Value inverseMapVal : inverseMapIt->second)
2545 if (!rewriterImpl.
mapping.tryMap(inverseMapVal, newValue))
2546 rewriterImpl.
mapping.erase(inverseMapVal);
2558 auto isLive = [&](
Value value) {
2560 auto matIt = materializationOps.find(user);
2561 if (matIt != materializationOps.end())
2562 return !necessaryMaterializations.count(matIt->second);
2566 for (
Value inv : inverseMapping.lookup(value))
2567 if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
2577 Value remappedValue = rewriterImpl.
mapping.lookupOrDefault(value, type);
2578 if (remappedValue.
getType() == type && remappedValue != invalidRoot)
2579 return remappedValue;
2584 auto inputCastOp = value.
getDefiningOp<UnrealizedConversionCastOp>();
2585 if (inputCastOp && inputCastOp->getNumOperands() == 1)
2586 return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
2594 materializationOps.try_emplace(mat.getOp(), &mat);
2595 worklist.insert(&mat);
2597 while (!worklist.empty()) {
2598 UnresolvedMaterialization *mat = worklist.pop_back_val();
2599 UnrealizedConversionCastOp op = mat->getOp();
2602 assert(op->
getNumResults() == 1 &&
"unexpected materialization type");
2610 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
2613 if (castOp->getResultTypes() == inputOperands.
getTypes()) {
2616 necessaryMaterializations.remove(materializationOps.lookup(user));
2622 if (inputOperands.size() == 1) {
2625 Value remappedValue =
2626 lookupRemappedValue(opResult, inputOperands[0], outputType);
2627 if (remappedValue && remappedValue != opResult) {
2630 necessaryMaterializations.remove(mat);
2638 auto isBlockArg = [](
Value v) {
return isa<BlockArgument>(v); };
2639 if (llvm::any_of(op->
getOperands(), isBlockArg) ||
2640 llvm::any_of(inverseMapping[op->
getResult(0)], isBlockArg)) {
2650 bool isMaterializationLive = isLive(opResult);
2652 isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive);
2653 if (!isMaterializationLive)
2655 if (!necessaryMaterializations.insert(mat))
2659 for (
Value input : inputOperands) {
2660 if (
auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
2661 if (
auto *mat = materializationOps.lookup(parentOp))
2662 worklist.insert(mat);
2671 UnresolvedMaterialization &mat,
2676 auto findLiveUser = [&](
auto &&users) {
2677 auto liveUserIt = llvm::find_if_not(
2679 return liveUserIt == users.end() ? nullptr : *liveUserIt;
2686 Value remappedValue = rewriterImpl.
mapping.lookupOrDefault(value, type);
2687 if (remappedValue.
getType() == type)
2688 return remappedValue;
2692 UnrealizedConversionCastOp op = mat.getOp();
2704 auto valueCast = value.
getDefiningOp<UnrealizedConversionCastOp>();
2708 auto matIt = materializationOps.find(valueCast);
2709 if (matIt != materializationOps.end())
2711 *matIt->second, materializationOps, rewriter, rewriterImpl,
2719 if (inputOperands.size() == 1) {
2722 Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
2723 if (remappedValue && remappedValue != opResult) {
2736 if (inputOperands.size() == 1)
2741 Value newMaterialization;
2742 switch (mat.getKind()) {
2753 rewriter, op->
getLoc(), mat.getOrigOutputType(), inputOperands);
2754 if (newMaterialization)
2760 case UnresolvedMaterialization::Target:
2762 rewriter, op->
getLoc(), outputType, inputOperands);
2765 if (newMaterialization) {
2773 <<
"failed to legalize unresolved materialization "
2775 << inputOperands.
getTypes() <<
" to " << outputType
2776 <<
" that remained live after conversion";
2779 <<
"see existing live user here: " << *liveUser;
2784 LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
2790 inverseMapping = rewriterImpl.
mapping.getInverse();
2797 *inverseMapping, necessaryMaterializations);
2800 for (
auto *mat : necessaryMaterializations) {
2802 *mat, materializationOps, rewriter, rewriterImpl, *inverseMapping)))
2814 return rewriterImpl.isOpIgnored(user);
2816 if (liveUserIt != result.
user_end()) {
2818 << op->
getName() <<
"' marked as erased";
2819 diag.attachNote(liveUserIt->getLoc())
2833 while (!worklist.empty()) {
2834 Value value = worklist.pop_back_val();
2839 return rewriterImpl.isOpIgnored(user);
2841 if (liveUserIt != value.
user_end())
2843 auto mapIt = inverseMapping.find(value);
2844 if (mapIt != inverseMapping.end())
2845 worklist.append(mapIt->second);
2850 LogicalResult OperationConverter::legalizeChangedResultType(
2861 auto emitConversionError = [&] {
2863 <<
"failed to materialize conversion for result #"
2866 <<
"' that remained live after conversion";
2868 <<
"see existing live user here: " << *liveUser;
2875 return emitConversionError();
2880 rewriter, op->
getLoc(), resultType, newValue);
2881 if (!convertedValue)
2882 return emitConversionError();
2884 rewriterImpl.
mapping.map(result, convertedValue);
2894 assert(!types.empty() &&
"expected valid types");
2895 remapInput(origInputNo, argTypes.size(), types.size());
2900 assert(!types.empty() &&
2901 "1->0 type remappings don't need to be added explicitly");
2902 argTypes.append(types.begin(), types.end());
2906 unsigned newInputNo,
2907 unsigned newInputCount) {
2908 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
2909 assert(newInputCount != 0 &&
"expected valid input count");
2910 remappedInputs[origInputNo] =
2911 InputMapping{newInputNo, newInputCount,
nullptr};
2915 Value replacementValue) {
2916 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
2917 remappedInputs[origInputNo] =
2924 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2927 cacheReadLock.lock();
2928 auto existingIt = cachedDirectConversions.find(t);
2929 if (existingIt != cachedDirectConversions.end()) {
2930 if (existingIt->second)
2931 results.push_back(existingIt->second);
2932 return success(existingIt->second !=
nullptr);
2934 auto multiIt = cachedMultiConversions.find(t);
2935 if (multiIt != cachedMultiConversions.end()) {
2936 results.append(multiIt->second.begin(), multiIt->second.end());
2942 size_t currentCount = results.size();
2944 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2947 for (
const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2948 if (std::optional<LogicalResult> result = converter(t, results)) {
2950 cacheWriteLock.lock();
2952 cachedDirectConversions.try_emplace(t,
nullptr);
2955 auto newTypes =
ArrayRef<Type>(results).drop_front(currentCount);
2956 if (newTypes.size() == 1)
2957 cachedDirectConversions.try_emplace(t, newTypes.front());
2959 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2973 return results.size() == 1 ? results.front() :
nullptr;
2979 for (
Type type : types)
2993 return llvm::all_of(*region, [
this](
Block &block) {
2999 return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
3011 if (convertedTypes.empty())
3015 result.
addInputs(inputNo, convertedTypes);
3021 unsigned origInputOffset)
const {
3022 for (
unsigned i = 0, e = types.size(); i != e; ++i)
3028 Value TypeConverter::materializeConversion(
3031 for (
const MaterializationCallbackFn &fn : llvm::reverse(materializations))
3032 if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
3037 std::optional<TypeConverter::SignatureConversion>
3041 return std::nullopt;
3064 return impl.getInt() == resultTag;
3068 return impl.getInt() == naTag;
3072 return impl.getInt() == abortTag;
3076 assert(hasResult() &&
"Cannot get result from N/A or abort");
3077 return impl.getPointer();
3080 std::optional<Attribute>
3082 for (
const TypeAttributeConversionCallbackFn &fn :
3083 llvm::reverse(typeAttributeConversions)) {
3088 return std::nullopt;
3090 return std::nullopt;
3100 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3110 typeConverter, &result)))
3127 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3135 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3140 struct AnyFunctionOpInterfaceSignatureConversion
3155 patterns.
add<FunctionOpInterfaceSignatureConversion>(
3156 functionLikeOpName, patterns.
getContext(), converter);
3161 patterns.
add<AnyFunctionOpInterfaceSignatureConversion>(
3171 legalOperations[op].action = action;
3176 for (StringRef dialect : dialectNames)
3177 legalDialects[dialect] = action;
3181 -> std::optional<LegalizationAction> {
3182 std::optional<LegalizationInfo> info = getOpInfo(op);
3183 return info ? info->action : std::optional<LegalizationAction>();
3187 -> std::optional<LegalOpDetails> {
3188 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3190 return std::nullopt;
3193 auto isOpLegal = [&] {
3195 if (info->action == LegalizationAction::Dynamic) {
3196 std::optional<bool> result = info->legalityFn(op);
3202 return info->action == LegalizationAction::Legal;
3205 return std::nullopt;
3209 if (info->isRecursivelyLegal) {
3210 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
3211 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3213 legalityFnIt->second(op).value_or(
true);
3218 return legalityDetails;
3222 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3226 if (info->action == LegalizationAction::Dynamic) {
3227 std::optional<bool> result = info->legalityFn(op);
3234 return info->action == LegalizationAction::Illegal;
3243 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3245 if (std::optional<bool> result = newCl(op))
3253 void ConversionTarget::setLegalityCallback(
3254 OperationName name,
const DynamicLegalityCallbackFn &callback) {
3255 assert(callback &&
"expected valid legality callback");
3256 auto infoIt = legalOperations.find(name);
3257 assert(infoIt != legalOperations.end() &&
3258 infoIt->second.action == LegalizationAction::Dynamic &&
3259 "expected operation to already be marked as dynamically legal");
3260 infoIt->second.legalityFn =
3266 auto infoIt = legalOperations.find(name);
3267 assert(infoIt != legalOperations.end() &&
3268 infoIt->second.action != LegalizationAction::Illegal &&
3269 "expected operation to already be marked as legal");
3270 infoIt->second.isRecursivelyLegal =
true;
3273 std::move(opRecursiveLegalityFns[name]), callback);
3275 opRecursiveLegalityFns.erase(name);
3278 void ConversionTarget::setLegalityCallback(
3280 assert(callback &&
"expected valid legality callback");
3281 for (StringRef dialect : dialects)
3283 std::move(dialectLegalityFns[dialect]), callback);
3286 void ConversionTarget::setLegalityCallback(
3287 const DynamicLegalityCallbackFn &callback) {
3288 assert(callback &&
"expected valid legality callback");
3293 -> std::optional<LegalizationInfo> {
3295 auto it = legalOperations.find(op);
3296 if (it != legalOperations.end())
3299 auto dialectIt = legalDialects.find(op.getDialectNamespace());
3300 if (dialectIt != legalDialects.end()) {
3301 DynamicLegalityCallbackFn callback;
3302 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3303 if (dialectFn != dialectLegalityFns.end())
3304 callback = dialectFn->second;
3305 return LegalizationInfo{dialectIt->second,
false,
3309 if (unknownLegalityFn)
3310 return LegalizationInfo{LegalizationAction::Dynamic,
3311 false, unknownLegalityFn};
3312 return std::nullopt;
3320 auto &rewriterImpl =
3326 auto &rewriterImpl =
3338 return std::move(mappedValues);
3349 return results->front();
3359 auto &rewriterImpl =
3373 auto &rewriterImpl =
3382 return std::move(remappedTypes);
3398 OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
3400 return opConverter.convertOperations(ops);
3416 OperationConverter opConverter(target, patterns, OpConversionMode::Full);
3417 return opConverter.convertOperations(ops);
3434 OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
3436 return opConverter.convertOperations(ops, notifyCallback);
3444 convertedOps, notifyCallback);
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static Value buildUnresolvedTargetMaterialization(Location loc, Value input, Type outputType, const TypeConverter *converter, SmallVectorImpl< UnresolvedMaterialization > &unresolvedMaterializations)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void detachNestedAndErase(Operation *op)
Detach any operations nested in the given operation from their parent blocks, and erase the given ope...
static Operation * findLiveUserOfReplaced(Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, const DenseMap< Value, SmallVector< Value >> &inverseMapping)
Finds a user of the given value, or of any other value that the given value replaced,...
static Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter, Location loc, ValueRange inputs, Type origOutputType, Type outputType, const TypeConverter *converter, SmallVectorImpl< UnresolvedMaterialization > &unresolvedMaterializations)
static void computeNecessaryMaterializations(DenseMap< Operation *, UnresolvedMaterialization * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping, SetVector< UnresolvedMaterialization * > &necessaryMaterializations)
Compute all of the unresolved materializations that will persist beyond the conversion process,...
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static void replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl, ResultRange matResults, ValueRange values, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Replace the results of a materialization operation with the given values.
static LogicalResult legalizeUnresolvedMaterialization(UnresolvedMaterialization &mat, DenseMap< Operation *, UnresolvedMaterialization * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Legalize the given unresolved materialization.
static Value buildUnresolvedMaterialization(UnresolvedMaterialization::Kind kind, Block *insertBlock, Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType, Type origOutputType, const TypeConverter *converter, SmallVectorImpl< UnresolvedMaterialization > &unresolvedMaterializations)
Build an unresolved materialization operation given an output type and set of input operands.
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
void erase()
Unlink this Block from its parent region and delete it.
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
RetT walk(FnT &&callback)
Walk the operations in this block.
OpListType & getOperations()
BlockArgListType getArguments()
void moveBefore(Block *block)
Unlink this block from its current region and insert it right before the specific block.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void finalizeRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor) override
PatternRewriter hook for replacing an operation when the given functor returns "true".
void notifyBlockCreated(Block *block) override
PatternRewriter hook for creating a new block.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
ConversionPatternRewriter(MLIRContext *ctx)
void notifyOperationInserted(Operation *op) override
PatternRewriter hook for inserting a new operation.
void cancelRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to the entry block of the given region.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
void startRootUpdate(Operation *op) override
PatternRewriter hook for updating the root operation in-place.
LogicalResult convertNonEntryRegionTypes(Region *region, const TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions)
Convert the types of block arguments within the given region except for the entry region.
~ConversionPatternRewriter() override
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping) override
PatternRewriter hook for cloning blocks of one region into another.
Base class for the conversion patterns.
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class provides support for representing a failure result, or a valid value of type T.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
user_range getUsers() const
Returns a range of all users.
void dropAllUses()
Drop all uses of this object from their respective owners.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class represents a diagnostic that is inflight and set to be reported.
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within 'results'.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
This class implements the operand iterators for the Operation class.
type_range getTypes() const
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
OpResult getOpResult(unsigned idx)
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setSuccessor(Block *block, unsigned index)
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
user_range getUsers()
Returns a range of all users.
result_range getResults()
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn)
Register a rewrite function with PDL.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
This class implements the result iterators for the Operation class.
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
MLIRContext * getContext() const
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
virtual void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
The general result of a type attribute conversion callback, allowing for early termination.
Attribute getResult() const
static AttributeConversionResult abort()
static AttributeConversionResult na()
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert the attribute `attr`, found within the type `type`, using the registered conversion functions...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
void dropAllUses() const
Drop all uses of this object from their respective owners.
Type getType() const
Return the type of this value.
user_iterator user_end() const
Block * getParentBlock()
Return the Block in which this Value is defined.
void replaceAllUsesWith(Value newValue) const
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
Detect if any of the given parameter types has a sub-element handler.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns)
Apply a complete conversion on the given operations, and all nested operations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > &convertedOps, function_ref< void(Diagnostic &)> notifyCallback=nullptr)
Apply an analysis conversion on the given operations, and all nested operations.
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
This class represents an efficient way to signal success or failure.
void eraseDanglingBlocks()
Erase any blocks that were unlinked from their regions and stored in block actions.
void undoBlockActions(unsigned numActionsToKeep=0)
Undo the block actions (motions, splits) one by one in reverse order until "numActionsToKeep" actions...
void discardRewrites()
Cleanup and destroy any generated rewrite operations.
ConversionPatternRewriterImpl(PatternRewriter &rewriter)
function_ref< void(Diagnostic &)> notifyCallback
This allows the user to collect the match failure message.
SmallVector< BlockAction, 4 > blockActions
Ordered list of block operations (creations, splits, motions).
llvm::MapVector< Operation *, OpReplacement > replacements
Ordered map of requested operation replacements.
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent, Region::iterator before)
Notifies that the blocks of a region are about to be moved.
void notifySplitBlock(Block *block, Block *continuation)
Notifies that a block was split.
ConversionValueMapping mapping
void applyRewrites()
Apply all requested operation rewrites.
SmallVector< OperationTransactionState, 4 > rootUpdates
A transaction state for each of operations that were updated in-place.
RewriterState getCurrentState()
Return the current state of the rewriter.
void notifyOpReplaced(Operation *op, ValueRange newValues)
PatternRewriter hook for replacing the results of an operation.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
SmallVector< UnresolvedMaterialization > unresolvedMaterializations
Ordered vector of all unresolved type conversion materializations during conversion.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
LogicalResult convertNonEntryRegionTypes(Region *region, const TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions={})
Convert the types of non-entry block arguments within the given region.
ArgConverter argConverter
Utility used to convert block arguments.
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter)
Apply a signature conversion on the given region, using converter for materializations if not null.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVectorImpl< Value > &remapped)
Remap the given values to those with potentially different types.
bool isOpIgnored(Operation *op) const
Returns true if the given operation is ignored, and does not need to be converted.
void markNestedOpsIgnored(Operation *op)
Recursively marks the nested operations under 'op' as ignored.
SmallVector< Operation * > createdOps
Ordered vector of all of the newly created operations during conversion.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notifies that a pattern match failed for the given reason.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization, but were not directly repla...
SmallVector< BlockArgument, 4 > argReplacements
Ordered vector of any requested block argument replacements.
FailureOr< Block * > convertBlockSignature(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion *conversion=nullptr)
Convert the signature of the given block.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
void notifyCreatedBlock(Block *block)
Notifies that a block was created.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
SmallVector< unsigned, 4 > operationsWithChangedResults
A vector of indices into replacements of operations that were replaced with values with different res...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.