10 #include "mlir/Config/mlir-config.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/SaveAndRestore.h"
24 #include "llvm/Support/ScopedPrinter.h"
30 #define DEBUG_TYPE "dialect-conversion"
33 template <
typename... Args>
34 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
37 os.startLine() <<
"} -> SUCCESS";
39 os.getOStream() <<
" : "
40 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
41 os.getOStream() <<
"\n";
46 template <
typename... Args>
47 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
50 os.startLine() <<
"} -> FAILURE : "
51 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
63 struct ConversionValueMapping {
68 Value lookupOrDefault(
Value from,
Type desiredType =
nullptr)
const;
78 for (
Value it = newVal; it; it = mapping.lookupOrNull(it))
79 assert(it != oldVal &&
"inserting cyclic mapping");
81 mapping.map(oldVal, newVal);
90 void erase(
Value value) { mapping.erase(value); }
95 for (
auto &it : mapping.getValueMap())
96 inverse[it.second].push_back(it.first);
106 Value ConversionValueMapping::lookupOrDefault(
Value from,
107 Type desiredType)
const {
112 while (
auto mappedValue = mapping.lookupOrNull(from))
120 if (from.
getType() == desiredType)
123 Value mappedValue = mapping.lookupOrNull(from);
130 return desiredValue ? desiredValue : from;
133 Value ConversionValueMapping::lookupOrNull(
Value from,
Type desiredType)
const {
134 Value result = lookupOrDefault(from, desiredType);
135 if (result == from || (desiredType && result.
getType() != desiredType))
140 bool ConversionValueMapping::tryMap(
Value oldVal,
Value newVal) {
141 for (
Value it = newVal; it; it = mapping.lookupOrNull(it))
154 struct RewriterState {
155 RewriterState(
unsigned numRewrites,
unsigned numIgnoredOperations,
156 unsigned numReplacedOps)
157 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
158 numReplacedOps(numReplacedOps) {}
161 unsigned numRewrites;
164 unsigned numIgnoredOperations;
167 unsigned numReplacedOps;
199 UnresolvedMaterialization
202 virtual ~IRRewrite() =
default;
205 virtual void rollback() = 0;
224 Kind getKind()
const {
return kind; }
226 static bool classof(
const IRRewrite *
rewrite) {
return true; }
230 : kind(kind), rewriterImpl(rewriterImpl) {}
239 class BlockRewrite :
public IRRewrite {
242 Block *getBlock()
const {
return block; }
244 static bool classof(
const IRRewrite *
rewrite) {
245 return rewrite->getKind() >= Kind::CreateBlock &&
246 rewrite->getKind() <= Kind::ReplaceBlockArg;
252 : IRRewrite(kind, rewriterImpl), block(block) {}
261 class CreateBlockRewrite :
public BlockRewrite {
264 : BlockRewrite(
Kind::CreateBlock, rewriterImpl, block) {}
266 static bool classof(
const IRRewrite *
rewrite) {
267 return rewrite->getKind() == Kind::CreateBlock;
273 listener->notifyBlockInserted(block, {}, {});
276 void rollback()
override {
279 auto &blockOps = block->getOperations();
280 while (!blockOps.empty())
281 blockOps.remove(blockOps.begin());
282 block->dropAllUses();
283 if (block->getParent())
294 class EraseBlockRewrite :
public BlockRewrite {
298 : BlockRewrite(
Kind::EraseBlock, rewriterImpl, block), region(region),
299 insertBeforeBlock(insertBeforeBlock) {}
301 static bool classof(
const IRRewrite *
rewrite) {
302 return rewrite->getKind() == Kind::EraseBlock;
305 ~EraseBlockRewrite()
override {
307 "rewrite was neither rolled back nor committed/cleaned up");
310 void rollback()
override {
313 assert(block &&
"expected block");
314 auto &blockList = region->getBlocks();
318 blockList.insert(before, block);
324 assert(block &&
"expected block");
325 assert(block->empty() &&
"expected empty block");
329 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
330 listener->notifyBlockErased(block);
335 block->dropAllDefinedValueUses();
346 Block *insertBeforeBlock;
352 class InlineBlockRewrite :
public BlockRewrite {
356 : BlockRewrite(
Kind::InlineBlock, rewriterImpl, block),
357 sourceBlock(sourceBlock),
358 firstInlinedInst(sourceBlock->empty() ? nullptr
359 : &sourceBlock->front()),
360 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
366 assert(!getConfig().listener &&
367 "InlineBlockRewrite not supported if listener is attached");
370 static bool classof(
const IRRewrite *
rewrite) {
371 return rewrite->getKind() == Kind::InlineBlock;
374 void rollback()
override {
377 if (firstInlinedInst) {
378 assert(lastInlinedInst &&
"expected operation");
398 class MoveBlockRewrite :
public BlockRewrite {
402 : BlockRewrite(
Kind::MoveBlock, rewriterImpl, block), region(region),
403 insertBeforeBlock(insertBeforeBlock) {}
405 static bool classof(
const IRRewrite *
rewrite) {
406 return rewrite->getKind() == Kind::MoveBlock;
414 listener->notifyBlockInserted(block, region,
419 void rollback()
override {
432 Block *insertBeforeBlock;
437 struct ConvertedArgInfo {
438 ConvertedArgInfo(
unsigned newArgIdx,
unsigned newArgSize,
439 Value castValue =
nullptr)
440 : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
455 class BlockTypeConversionRewrite :
public BlockRewrite {
457 BlockTypeConversionRewrite(
461 : BlockRewrite(
Kind::BlockTypeConversion, rewriterImpl, block),
462 origBlock(origBlock), argInfo(argInfo), converter(converter) {}
464 static bool classof(
const IRRewrite *
rewrite) {
465 return rewrite->getKind() == Kind::BlockTypeConversion;
476 void rollback()
override;
493 class ReplaceBlockArgRewrite :
public BlockRewrite {
497 : BlockRewrite(
Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
499 static bool classof(
const IRRewrite *
rewrite) {
500 return rewrite->getKind() == Kind::ReplaceBlockArg;
505 void rollback()
override;
512 class OperationRewrite :
public IRRewrite {
515 Operation *getOperation()
const {
return op; }
517 static bool classof(
const IRRewrite *
rewrite) {
518 return rewrite->getKind() >= Kind::MoveOperation &&
519 rewrite->getKind() <= Kind::UnresolvedMaterialization;
525 : IRRewrite(kind, rewriterImpl), op(op) {}
532 class MoveOperationRewrite :
public OperationRewrite {
536 : OperationRewrite(
Kind::MoveOperation, rewriterImpl, op), block(block),
537 insertBeforeOp(insertBeforeOp) {}
539 static bool classof(
const IRRewrite *
rewrite) {
540 return rewrite->getKind() == Kind::MoveOperation;
548 listener->notifyOperationInserted(
554 void rollback()
override {
572 class ModifyOperationRewrite :
public OperationRewrite {
576 : OperationRewrite(
Kind::ModifyOperation, rewriterImpl, op),
577 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
578 operands(op->operand_begin(), op->operand_end()),
579 successors(op->successor_begin(), op->successor_end()) {
584 name.initOpProperties(propCopy, prop);
588 static bool classof(
const IRRewrite *
rewrite) {
589 return rewrite->getKind() == Kind::ModifyOperation;
592 ~ModifyOperationRewrite()
override {
593 assert(!propertiesStorage &&
594 "rewrite was neither committed nor rolled back");
600 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
601 listener->notifyOperationModified(op);
603 if (propertiesStorage) {
607 name.destroyOpProperties(propCopy);
608 operator delete(propertiesStorage);
609 propertiesStorage =
nullptr;
613 void rollback()
override {
619 if (propertiesStorage) {
622 name.destroyOpProperties(propCopy);
623 operator delete(propertiesStorage);
624 propertiesStorage =
nullptr;
631 DictionaryAttr attrs;
634 void *propertiesStorage =
nullptr;
641 class ReplaceOperationRewrite :
public OperationRewrite {
646 : OperationRewrite(
Kind::ReplaceOperation, rewriterImpl, op),
647 converter(converter), changedResults(changedResults) {}
649 static bool classof(
const IRRewrite *
rewrite) {
650 return rewrite->getKind() == Kind::ReplaceOperation;
655 void rollback()
override;
659 const TypeConverter *getConverter()
const {
return converter; }
661 bool hasChangedResults()
const {
return changedResults; }
672 class CreateOperationRewrite :
public OperationRewrite {
676 : OperationRewrite(
Kind::CreateOperation, rewriterImpl, op) {}
678 static bool classof(
const IRRewrite *
rewrite) {
679 return rewrite->getKind() == Kind::CreateOperation;
685 listener->notifyOperationInserted(op, {});
688 void rollback()
override;
692 enum MaterializationKind {
705 class UnresolvedMaterializationRewrite :
public OperationRewrite {
707 UnresolvedMaterializationRewrite(
709 UnrealizedConversionCastOp op,
const TypeConverter *converter =
nullptr,
710 MaterializationKind kind = MaterializationKind::Target,
711 Type origOutputType =
nullptr)
712 : OperationRewrite(
Kind::UnresolvedMaterialization, rewriterImpl, op),
713 converterAndKind(converter, kind), origOutputType(origOutputType) {}
715 static bool classof(
const IRRewrite *
rewrite) {
716 return rewrite->getKind() == Kind::UnresolvedMaterialization;
719 UnrealizedConversionCastOp getOperation()
const {
720 return cast<UnrealizedConversionCastOp>(op);
723 void rollback()
override;
729 return converterAndKind.getPointer();
733 MaterializationKind getMaterializationKind()
const {
734 return converterAndKind.getInt();
738 void setMaterializationKind(MaterializationKind kind) {
739 converterAndKind.setInt(kind);
743 Type getOrigOutputType()
const {
return origOutputType; }
748 llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
758 template <
typename RewriteTy,
typename R>
760 return any_of(std::move(rewrites), [&](
auto &
rewrite) {
761 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
762 return rewriteTy && rewriteTy->getOperation() == op;
769 template <
typename RewriteTy,
typename R>
771 RewriteTy *result =
nullptr;
772 for (
auto &
rewrite : rewrites) {
773 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
774 if (rewriteTy && rewriteTy->getBlock() == block) {
776 assert(!result &&
"expected single matching rewrite");
794 : context(ctx), config(config) {}
801 RewriterState getCurrentState();
805 void applyRewrites();
808 void resetState(RewriterState state);
812 template <
typename RewriteTy,
typename... Args>
815 std::make_unique<RewriteTy>(*
this, std::forward<Args>(args)...));
820 void undoRewrites(
unsigned numRewritesToKeep = 0);
827 std::optional<Location> inputLoc,
875 Block *applySignatureConversion(
885 Value buildUnresolvedMaterialization(MaterializationKind kind,
907 void notifyOperationInserted(
Operation *op,
914 void notifyBlockIsBeingErased(
Block *block);
917 void notifyBlockInserted(
Block *block,
Region *previous,
921 void notifyBlockBeingInlined(
Block *block,
Block *srcBlock,
943 if (erased.contains(op))
951 if (erased.contains(block))
953 assert(block->
empty() &&
"expected empty block");
1009 llvm::ScopedPrinter logger{llvm::dbgs()};
1016 return rewriterImpl.
config;
1019 void BlockTypeConversionRewrite::commit(
RewriterBase &rewriter) {
1023 if (
auto *listener =
1024 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
1026 listener->notifyOperationModified(op);
1029 for (
auto [origArg, info] :
1030 llvm::zip_equal(origBlock->getArguments(), argInfo)) {
1034 rewriterImpl.
mapping.lookupOrNull(origArg, origArg.getType()))
1040 Value castValue = info->castValue;
1041 assert(info->newArgSize >= 1 && castValue &&
"expected 1->1+ mapping");
1044 if (!origArg.use_empty()) {
1046 castValue, origArg.getType()));
1051 void BlockTypeConversionRewrite::rollback() {
1055 LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
1061 OpBuilder builder(it.value().getContext(), &rewriterImpl);
1062 builder.setInsertionPointToStart(block);
1066 if (rewriterImpl.
mapping.lookupOrNull(origArg, origArg.
getType()))
1068 Operation *liveUser = findLiveUser(origArg);
1072 Value replacementValue = rewriterImpl.
mapping.lookupOrDefault(origArg);
1073 bool isDroppedArg = replacementValue == origArg;
1075 builder.setInsertionPointAfterValue(replacementValue);
1078 newArg = converter->materializeSourceConversion(
1082 "materialization hook did not provide a value of the expected "
1088 <<
"failed to materialize conversion for block argument #"
1089 << it.index() <<
" that remained live after conversion, type was "
1092 diag <<
", with target type " << replacementValue.
getType();
1094 <<
"see existing live user here: " << *liveUser;
1097 rewriterImpl.
mapping.map(origArg, newArg);
1102 void ReplaceBlockArgRewrite::commit(
RewriterBase &rewriter) {
1103 Value repl = rewriterImpl.
mapping.lookupOrNull(arg, arg.getType());
1107 if (isa<BlockArgument>(repl)) {
1115 Operation *replOp = cast<OpResult>(repl).getOwner();
1123 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.
mapping.erase(arg); }
1125 void ReplaceOperationRewrite::commit(
RewriterBase &rewriter) {
1127 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener());
1132 return rewriterImpl.mapping.lookupOrNull(result, result.getType());
1137 listener->notifyOperationReplaced(op, replacements);
1140 for (
auto [result, newValue] :
1141 llvm::zip_equal(op->
getResults(), replacements))
1147 if (getConfig().unlegalizedOps)
1148 getConfig().unlegalizedOps->erase(op);
1154 [&](
Operation *op) { listener->notifyOperationErased(op); });
1162 void ReplaceOperationRewrite::rollback() {
1164 rewriterImpl.
mapping.erase(result);
1167 void ReplaceOperationRewrite::cleanup(
RewriterBase &rewriter) {
1171 void CreateOperationRewrite::rollback() {
1173 while (!region.getBlocks().empty())
1174 region.getBlocks().remove(region.getBlocks().begin());
1180 void UnresolvedMaterializationRewrite::rollback() {
1181 if (getMaterializationKind() == MaterializationKind::Target) {
1183 rewriterImpl.
mapping.erase(input);
1188 void UnresolvedMaterializationRewrite::cleanup(
RewriterBase &rewriter) {
1201 rewrite->cleanup(eraseRewriter);
1216 while (
ignoredOps.size() != state.numIgnoredOperations)
1219 while (
replacedOps.size() != state.numReplacedOps)
1225 llvm::reverse(llvm::drop_begin(
rewrites, numRewritesToKeep)))
1227 rewrites.resize(numRewritesToKeep);
1231 StringRef valueDiagTag, std::optional<Location> inputLoc,
1234 remapped.reserve(llvm::size(values));
1238 Value operand = it.value();
1250 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1251 << it.index() <<
", type was " << origType;
1257 if (legalTypes.size() == 1)
1258 desiredType = legalTypes.front();
1267 Value newOperand =
mapping.lookupOrDefault(operand, desiredType);
1277 newOperand = castValue;
1279 remapped.push_back(newOperand);
1319 if (!region->
empty())
1330 if (region->
empty())
1337 rewriter, ®ion->
front(), &converter, entryConversion);
1346 if (region->
empty())
1351 assert((blockConversions.empty() ||
1352 blockConversions.size() == region->
getBlocks().size() - 1) &&
1353 "expected either to provide no SignatureConversions at all or to "
1354 "provide a SignatureConversion for each non-entry block");
1357 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1359 blockConversions.empty()
1362 &blockConversions[blockIdx++]);
1386 for (
unsigned i = 0; i < origArgCount; ++i) {
1388 if (!inputMap || inputMap->replacementValue)
1391 for (
unsigned j = 0;
j < inputMap->size; ++
j)
1392 newLocs[inputMap->inputNo +
j] = origLoc;
1399 convertedTypes, newLocs);
1409 appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->
end());
1412 while (!block->
empty())
1422 argInfo.resize(origArgCount);
1424 for (
unsigned i = 0; i != origArgCount; ++i) {
1432 if (inputMap->replacementValue) {
1433 assert(inputMap->size == 0 &&
1434 "invalid to provide a replacement value when the argument isn't "
1436 mapping.map(origArg, inputMap->replacementValue);
1437 appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1443 newBlock->
getArguments().slice(inputMap->inputNo, inputMap->size);
1455 if (replArgs.size() == 1 &&
1456 (!converter || replArgs[0].getType() == origArg.
getType())) {
1457 newArg = replArgs.front();
1462 Type outputType = origOutputType;
1464 outputType = legalOutputType;
1467 newBlock, origArg.
getLoc(), replArgs, origOutputType, outputType,
1472 appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1473 argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
1476 appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
1497 if (inputs.size() == 1 && inputs.front().
getType() == outputType)
1498 return inputs.front();
1502 OpBuilder builder(insertBlock, insertPt);
1504 builder.
create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1505 appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1507 return convertOp.getResult(0);
1513 block->
begin(), loc, inputs, outputType,
1514 origOutputType, converter);
1521 if (
OpResult inputRes = dyn_cast<OpResult>(input))
1522 insertPt = ++inputRes.getOwner()->getIterator();
1525 insertBlock, insertPt, loc, input,
1526 outputType, outputType, converter);
1535 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"'(" << op
1539 "attempting to insert into a block within a replaced/erased op");
1541 if (!previous.
isSet()) {
1543 appendRewrite<CreateOperationRewrite>(op);
1549 appendRewrite<MoveOperationRewrite>(op, previous.
getBlock(), prevOp);
1555 assert(!
ignoredOps.contains(op) &&
"operation was already replaced");
1558 bool resultChanged =
false;
1561 for (
auto [newValue, result] : llvm::zip(newValues, op->
getResults())) {
1563 resultChanged =
true;
1567 mapping.map(result, newValue);
1568 resultChanged |= (newValue.getType() != result.
getType());
1580 Block *origNextBlock = block->getNextNode();
1581 appendRewrite<EraseBlockRewrite>(block, region, origNextBlock);
1587 "attempting to insert into a region within a replaced/erased op");
1592 logger.startLine() <<
"** Insert Block into : '" << parent->
getName()
1593 <<
"'(" << parent <<
")\n";
1596 <<
"** Insert Block into detached Region (nullptr parent op)'";
1602 appendRewrite<CreateBlockRewrite>(block);
1605 Block *prevBlock = previousIt == previous->
end() ? nullptr : &*previousIt;
1606 appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1611 appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
1618 reasonCallback(
diag);
1619 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
1629 ConversionPatternRewriter::ConversionPatternRewriter(
1633 setListener(
impl.get());
1639 assert(op && newOp &&
"expected non-null op");
1645 "incorrect # of replacement values");
1647 impl->logger.startLine()
1648 <<
"** Replace : '" << op->
getName() <<
"'(" << op <<
")\n";
1650 impl->notifyOpReplaced(op, newValues);
1655 impl->logger.startLine()
1656 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
1659 impl->notifyOpReplaced(op, nullRepls);
1664 "attempting to erase a block within a replaced/erased op");
1674 impl->notifyBlockIsBeingErased(block);
1682 "attempting to apply a signature conversion to a block within a "
1683 "replaced/erased op");
1684 return impl->applySignatureConversion(*
this, region, conversion, converter);
1691 "attempting to apply a signature conversion to a block within a "
1692 "replaced/erased op");
1693 return impl->convertRegionTypes(*
this, region, converter, entryConversion);
1700 "attempting to apply a signature conversion to a block within a "
1701 "replaced/erased op");
1702 return impl->convertNonEntryRegionTypes(*
this, region, converter,
1710 impl->logger.startLine() <<
"** Replace Argument : '" << from
1711 <<
"'(in region of '" << parentOp->
getName()
1714 impl->appendRewrite<ReplaceBlockArgRewrite>(from.
getOwner(), from);
1715 impl->mapping.map(
impl->mapping.lookupOrDefault(from), to);
1720 if (
failed(
impl->remapValues(
"value", std::nullopt, *
this, key,
1723 return remappedValues.front();
1731 return impl->remapValues(
"value", std::nullopt, *
this, keys,
1740 "incorrect # of argument replacement values");
1742 "attempting to inline a block from a replaced/erased op");
1744 "attempting to inline a block into a replaced/erased op");
1745 auto opIgnored = [&](
Operation *op) {
return impl->isOpIgnored(op); };
1748 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
1749 "expected 'source' to have no predecessors");
1758 bool fastPath = !
impl->config.listener;
1761 impl->notifyBlockBeingInlined(dest, source, before);
1764 for (
auto it : llvm::zip(source->
getArguments(), argValues))
1765 replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1772 while (!source->
empty())
1773 moveOpBefore(&source->
front(), dest, before);
1781 assert(!
impl->wasOpReplaced(op) &&
1782 "attempting to modify a replaced/erased op");
1784 impl->pendingRootUpdates.insert(op);
1786 impl->appendRewrite<ModifyOperationRewrite>(op);
1790 assert(!
impl->wasOpReplaced(op) &&
1791 "attempting to modify a replaced/erased op");
1796 assert(
impl->pendingRootUpdates.erase(op) &&
1797 "operation did not have a pending in-place update");
1803 assert(
impl->pendingRootUpdates.erase(op) &&
1804 "operation did not have a pending in-place update");
1807 auto it = llvm::find_if(
1808 llvm::reverse(
impl->rewrites), [&](std::unique_ptr<IRRewrite> &
rewrite) {
1809 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1810 return modifyRewrite && modifyRewrite->getOperation() == op;
1812 assert(it !=
impl->rewrites.rend() &&
"no root update started on op");
1814 int updateIdx = std::prev(
impl->rewrites.rend()) - it;
1815 impl->rewrites.erase(
impl->rewrites.begin() + updateIdx);
1830 auto &rewriterImpl = dialectRewriter.getImpl();
1834 getTypeConverter());
1842 return matchAndRewrite(op, operands, dialectRewriter);
1854 class OperationLegalizer {
1890 RewriterState &curState);
1894 legalizePatternBlockRewrites(
Operation *op,
1897 RewriterState &state, RewriterState &newState);
1900 RewriterState &state, RewriterState &newState);
1903 RewriterState &state,
1904 RewriterState &newState);
1914 void buildLegalizationGraph(
1915 LegalizationPatterns &anyOpLegalizerPatterns,
1926 void computeLegalizationGraphBenefit(
1927 LegalizationPatterns &anyOpLegalizerPatterns,
1932 unsigned computeOpLegalizationDepth(
1939 unsigned applyCostModelToPatterns(
1940 LegalizationPatterns &patterns,
1961 : target(targetInfo), applicator(patterns), config(config) {
1965 LegalizationPatterns anyOpLegalizerPatterns;
1967 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1968 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1971 bool OperationLegalizer::isIllegal(
Operation *op)
const {
1972 return target.isIllegal(op);
1976 OperationLegalizer::legalize(
Operation *op,
1979 const char *logLineComment =
1980 "//===-------------------------------------------===//\n";
1985 logger.getOStream() <<
"\n";
1986 logger.startLine() << logLineComment;
1987 logger.startLine() <<
"Legalizing operation : '" << op->
getName() <<
"'("
1993 op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1994 logger.getOStream() <<
"\n\n";
1999 if (
auto legalityInfo = target.isLegal(op)) {
2002 logger,
"operation marked legal by the target{0}",
2003 legalityInfo->isRecursivelyLegal
2004 ?
"; NOTE: operation is recursively legal; skipping internals"
2006 logger.startLine() << logLineComment;
2011 if (legalityInfo->isRecursivelyLegal) {
2024 logSuccess(logger,
"operation marked 'ignored' during conversion");
2025 logger.startLine() << logLineComment;
2033 if (
succeeded(legalizeWithFold(op, rewriter))) {
2036 logger.startLine() << logLineComment;
2042 if (
succeeded(legalizeWithPattern(op, rewriter))) {
2045 logger.startLine() << logLineComment;
2051 logFailure(logger,
"no matched legalization pattern");
2052 logger.startLine() << logLineComment;
2058 OperationLegalizer::legalizeWithFold(
Operation *op,
2060 auto &rewriterImpl = rewriter.
getImpl();
2064 rewriterImpl.
logger.startLine() <<
"* Fold {\n";
2065 rewriterImpl.
logger.indent();
2077 rewriter.
replaceOp(op, replacementValues);
2080 for (
unsigned i = curState.numRewrites, e = rewriterImpl.
rewrites.size();
2083 dyn_cast<CreateOperationRewrite>(rewriterImpl.
rewrites[i].get());
2086 if (
failed(legalize(createOp->getOperation(), rewriter))) {
2088 "failed to legalize generated constant '{0}'",
2089 createOp->getOperation()->getName()));
2100 OperationLegalizer::legalizeWithPattern(
Operation *op,
2102 auto &rewriterImpl = rewriter.
getImpl();
2105 auto canApply = [&](
const Pattern &pattern) {
2106 bool canApply = canApplyPattern(op, pattern, rewriter);
2107 if (canApply && config.listener)
2108 config.listener->notifyPatternBegin(pattern, op);
2114 auto onFailure = [&](
const Pattern &pattern) {
2120 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
2126 if (config.listener)
2127 config.listener->notifyPatternEnd(pattern,
failure());
2129 appliedPatterns.erase(&pattern);
2134 auto onSuccess = [&](
const Pattern &pattern) {
2136 auto result = legalizePatternResult(op, pattern, rewriter, curState);
2137 appliedPatterns.erase(&pattern);
2140 if (config.listener)
2141 config.listener->notifyPatternEnd(pattern, result);
2146 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2150 bool OperationLegalizer::canApplyPattern(
Operation *op,
const Pattern &pattern,
2153 auto &os = rewriter.
getImpl().logger;
2154 os.getOStream() <<
"\n";
2155 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
2157 os.getOStream() <<
")' {\n";
2164 !appliedPatterns.insert(&pattern).second) {
2173 OperationLegalizer::legalizePatternResult(
Operation *op,
const Pattern &pattern,
2175 RewriterState &curState) {
2179 assert(
impl.pendingRootUpdates.empty() &&
"dangling root updates");
2181 auto newRewrites = llvm::drop_begin(
impl.rewrites, curState.numRewrites);
2182 auto replacedRoot = [&] {
2183 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2185 auto updatedRootInPlace = [&] {
2186 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2188 assert((replacedRoot() || updatedRootInPlace()) &&
2189 "expected pattern to replace the root operation");
2193 RewriterState newState =
impl.getCurrentState();
2194 if (
failed(legalizePatternBlockRewrites(op, rewriter,
impl, curState,
2196 failed(legalizePatternRootUpdates(rewriter,
impl, curState, newState)) ||
2197 failed(legalizePatternCreatedOperations(rewriter,
impl, curState,
2202 LLVM_DEBUG(
logSuccess(
impl.logger,
"pattern applied successfully"));
2206 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2209 RewriterState &newState) {
2214 for (
int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2215 BlockRewrite *
rewrite = dyn_cast<BlockRewrite>(
impl.rewrites[i].get());
2219 if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2220 ReplaceBlockArgRewrite>(
rewrite))
2229 if (
auto *converter =
impl.regionToConverter.lookup(block->
getParent())) {
2230 if (
failed(
impl.convertBlockSignature(rewriter, block, converter))) {
2231 LLVM_DEBUG(
logFailure(
impl.logger,
"failed to convert types of moved "
2242 if (operationsToIgnore.empty()) {
2243 for (
unsigned i = state.numRewrites, e =
impl.rewrites.size(); i != e;
2246 dyn_cast<CreateOperationRewrite>(
impl.rewrites[i].get());
2249 operationsToIgnore.insert(createOp->getOperation());
2254 if (operationsToIgnore.insert(parentOp).second &&
2255 failed(legalize(parentOp, rewriter))) {
2257 "operation '{0}'({1}) became illegal after rewrite",
2258 parentOp->
getName(), parentOp));
2265 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2267 RewriterState &state, RewriterState &newState) {
2268 for (
int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2269 auto *createOp = dyn_cast<CreateOperationRewrite>(
impl.rewrites[i].get());
2272 Operation *op = createOp->getOperation();
2273 if (
failed(legalize(op, rewriter))) {
2275 "failed to legalize generated operation '{0}'({1})",
2283 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2285 RewriterState &state, RewriterState &newState) {
2286 for (
int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2287 auto *
rewrite = dyn_cast<ModifyOperationRewrite>(
impl.rewrites[i].get());
2291 if (
failed(legalize(op, rewriter))) {
2293 impl.logger,
"failed to legalize operation updated in-place '{0}'",
2304 void OperationLegalizer::buildLegalizationGraph(
2305 LegalizationPatterns &anyOpLegalizerPatterns,
2316 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2317 std::optional<OperationName> root = pattern.
getRootKind();
2323 anyOpLegalizerPatterns.push_back(&pattern);
2328 if (target.getOpAction(*root) == LegalizationAction::Legal)
2333 invalidPatterns[*root].insert(&pattern);
2335 parentOps[op].insert(*root);
2338 patternWorklist.insert(&pattern);
2346 if (!anyOpLegalizerPatterns.empty()) {
2347 for (
const Pattern *pattern : patternWorklist)
2348 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2352 while (!patternWorklist.empty()) {
2353 auto *pattern = patternWorklist.pop_back_val();
2357 std::optional<LegalizationAction> action = target.getOpAction(op);
2358 return !legalizerPatterns.count(op) &&
2359 (!action || action == LegalizationAction::Illegal);
2365 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2366 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
2370 for (
auto op : parentOps[*pattern->
getRootKind()])
2371 patternWorklist.set_union(invalidPatterns[op]);
2375 void OperationLegalizer::computeLegalizationGraphBenefit(
2376 LegalizationPatterns &anyOpLegalizerPatterns,
2382 for (
auto &opIt : legalizerPatterns)
2383 if (!minOpPatternDepth.count(opIt.first))
2384 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2390 if (!anyOpLegalizerPatterns.empty())
2391 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2397 applicator.applyCostModel([&](
const Pattern &pattern) {
2399 if (std::optional<OperationName> rootName = pattern.
getRootKind())
2400 orderedPatternList = legalizerPatterns[*rootName];
2402 orderedPatternList = anyOpLegalizerPatterns;
2405 auto *it = llvm::find(orderedPatternList, &pattern);
2406 if (it == orderedPatternList.end())
2410 return PatternBenefit(std::distance(it, orderedPatternList.end()));
2414 unsigned OperationLegalizer::computeOpLegalizationDepth(
2418 auto depthIt = minOpPatternDepth.find(op);
2419 if (depthIt != minOpPatternDepth.end())
2420 return depthIt->second;
2424 auto opPatternsIt = legalizerPatterns.find(op);
2425 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2434 unsigned minDepth = applyCostModelToPatterns(
2435 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2436 minOpPatternDepth[op] = minDepth;
2440 unsigned OperationLegalizer::applyCostModelToPatterns(
2441 LegalizationPatterns &patterns,
2448 patternsByDepth.reserve(patterns.size());
2449 for (
const Pattern *pattern : patterns) {
2452 unsigned generatedOpDepth = computeOpLegalizationDepth(
2453 generatedOp, minOpPatternDepth, legalizerPatterns);
2454 depth =
std::max(depth, generatedOpDepth + 1);
2456 patternsByDepth.emplace_back(pattern, depth);
2459 minDepth =
std::min(minDepth, depth);
2464 if (patternsByDepth.size() == 1)
2468 std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2469 [](
const std::pair<const Pattern *, unsigned> &lhs,
2470 const std::pair<const Pattern *, unsigned> &rhs) {
2473 if (lhs.second != rhs.second)
2474 return lhs.second < rhs.second;
2477 auto lhsBenefit = lhs.first->getBenefit();
2478 auto rhsBenefit = rhs.first->getBenefit();
2479 return lhsBenefit > rhsBenefit;
2484 for (
auto &patternIt : patternsByDepth)
2485 patterns.push_back(patternIt.first);
2493 enum OpConversionMode {
2516 OpConversionMode mode)
2517 : config(config), opLegalizer(target, patterns, this->config),
2559 OperationLegalizer opLegalizer;
2562 OpConversionMode mode;
2569 if (
failed(opLegalizer.legalize(op, rewriter))) {
2572 if (mode == OpConversionMode::Full)
2574 <<
"failed to legalize operation '" << op->
getName() <<
"'";
2578 if (mode == OpConversionMode::Partial) {
2579 if (opLegalizer.isIllegal(op))
2581 <<
"failed to legalize operation '" << op->
getName()
2582 <<
"' that was explicitly marked illegal";
2586 }
else if (mode == OpConversionMode::Analysis) {
2603 for (
auto *op : ops) {
2606 toConvert.push_back(op);
2609 auto legalityInfo = target.
isLegal(op);
2610 if (legalityInfo && legalityInfo->isRecursivelyLegal)
2620 for (
auto *op : toConvert)
2621 if (
failed(convert(rewriter, op)))
2627 if (
failed(finalize(rewriter)))
2632 if (mode == OpConversionMode::Analysis) {
2642 std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
2644 if (
failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2646 failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2650 for (
unsigned i = 0; i < rewriterImpl.
rewrites.size(); ++i) {
2651 auto *opReplacement =
2652 dyn_cast<ReplaceOperationRewrite>(rewriterImpl.
rewrites[i].get());
2653 if (!opReplacement || !opReplacement->hasChangedResults())
2655 Operation *op = opReplacement->getOperation();
2657 Value newValue = rewriterImpl.
mapping.lookupOrNull(result);
2662 if (
failed(legalizeErasedResult(op, result, rewriterImpl)))
2672 if (!inverseMapping)
2673 inverseMapping = rewriterImpl.
mapping.getInverse();
2677 if (
failed(legalizeChangedResultType(
2678 op, result, newValue, opReplacement->getConverter(), rewriter,
2679 rewriterImpl, *inverseMapping)))
2686 LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2691 auto findLiveUser = [&](
Value val) {
2692 auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](
Operation *user) {
2693 return rewriterImpl.isOpIgnored(user);
2695 return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2698 for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.
rewrites.size());
2701 if (
auto *blockTypeConversionRewrite =
2702 dyn_cast<BlockTypeConversionRewrite>(
rewrite.get()))
2703 if (
failed(blockTypeConversionRewrite->materializeLiveConversions(
2719 for (
auto [matResult, newValue] : llvm::zip(matResults, values)) {
2720 auto inverseMapIt = inverseMapping.find(matResult);
2721 if (inverseMapIt == inverseMapping.end())
2730 for (
Value inverseMapVal : inverseMapIt->second)
2731 if (!rewriterImpl.
mapping.tryMap(inverseMapVal, newValue))
2732 rewriterImpl.
mapping.erase(inverseMapVal);
2740 &materializationOps,
2745 auto isLive = [&](
Value value) {
2747 auto matIt = materializationOps.find(user);
2748 if (matIt != materializationOps.end())
2749 return !necessaryMaterializations.count(matIt->second);
2753 for (
Value inv : inverseMapping.lookup(value))
2754 if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
2764 Value remappedValue = rewriterImpl.
mapping.lookupOrDefault(value, type);
2765 if (remappedValue.
getType() == type && remappedValue != invalidRoot)
2766 return remappedValue;
2771 auto inputCastOp = value.
getDefiningOp<UnrealizedConversionCastOp>();
2772 if (inputCastOp && inputCastOp->getNumOperands() == 1)
2773 return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
2781 auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(
rewrite.get());
2784 materializationOps.try_emplace(mat->getOperation(), mat);
2785 worklist.insert(mat);
2787 while (!worklist.empty()) {
2788 UnresolvedMaterializationRewrite *mat = worklist.pop_back_val();
2789 UnrealizedConversionCastOp op = mat->getOperation();
2792 assert(op->
getNumResults() == 1 &&
"unexpected materialization type");
2800 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
2803 if (castOp->getResultTypes() == inputOperands.
getTypes()) {
2806 necessaryMaterializations.remove(materializationOps.lookup(user));
2812 if (inputOperands.size() == 1) {
2815 Value remappedValue =
2816 lookupRemappedValue(opResult, inputOperands[0], outputType);
2817 if (remappedValue && remappedValue != opResult) {
2820 necessaryMaterializations.remove(mat);
2828 auto isBlockArg = [](
Value v) {
return isa<BlockArgument>(v); };
2829 if (llvm::any_of(op->
getOperands(), isBlockArg) ||
2830 llvm::any_of(inverseMapping[op->
getResult(0)], isBlockArg)) {
2840 bool isMaterializationLive = isLive(opResult);
2842 isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive);
2843 if (!isMaterializationLive)
2845 if (!necessaryMaterializations.insert(mat))
2849 for (
Value input : inputOperands) {
2850 if (
auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
2851 if (
auto *mat = materializationOps.lookup(parentOp))
2852 worklist.insert(mat);
2861 UnresolvedMaterializationRewrite &mat,
2863 &materializationOps,
2867 auto findLiveUser = [&](
auto &&users) {
2868 auto liveUserIt = llvm::find_if_not(
2870 return liveUserIt == users.end() ? nullptr : *liveUserIt;
2877 Value remappedValue = rewriterImpl.
mapping.lookupOrDefault(value, type);
2878 if (remappedValue.
getType() == type)
2879 return remappedValue;
2883 UnrealizedConversionCastOp op = mat.getOperation();
2895 auto valueCast = value.
getDefiningOp<UnrealizedConversionCastOp>();
2899 auto matIt = materializationOps.find(valueCast);
2900 if (matIt != materializationOps.end())
2902 *matIt->second, materializationOps, rewriter, rewriterImpl,
2910 if (inputOperands.size() == 1) {
2913 Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
2914 if (remappedValue && remappedValue != opResult) {
2927 if (inputOperands.size() == 1)
2932 Value newMaterialization;
2933 switch (mat.getMaterializationKind()) {
2943 newMaterialization = converter->materializeArgumentConversion(
2944 rewriter, op->
getLoc(), mat.getOrigOutputType(), inputOperands);
2945 if (newMaterialization)
2951 case MaterializationKind::Target:
2952 newMaterialization = converter->materializeTargetConversion(
2953 rewriter, op->
getLoc(), outputType, inputOperands);
2956 if (newMaterialization) {
2964 <<
"failed to legalize unresolved materialization "
2966 << inputOperands.
getTypes() <<
" to " << outputType
2967 <<
" that remained live after conversion";
2970 <<
"see existing live user here: " << *liveUser;
2975 LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
2979 inverseMapping = rewriterImpl.
mapping.getInverse();
2986 *inverseMapping, necessaryMaterializations);
2989 for (
auto *mat : necessaryMaterializations) {
2991 *mat, materializationOps, rewriter, rewriterImpl, *inverseMapping)))
3003 return rewriterImpl.isOpIgnored(user);
3005 if (liveUserIt != result.
user_end()) {
3007 << op->
getName() <<
"' marked as erased";
3008 diag.attachNote(liveUserIt->getLoc())
3022 while (!worklist.empty()) {
3023 Value value = worklist.pop_back_val();
3028 return rewriterImpl.isOpIgnored(user);
3030 if (liveUserIt != value.
user_end())
3032 auto mapIt = inverseMapping.find(value);
3033 if (mapIt != inverseMapping.end())
3034 worklist.append(mapIt->second);
3039 LogicalResult OperationConverter::legalizeChangedResultType(
3050 auto emitConversionError = [&] {
3052 <<
"failed to materialize conversion for result #"
3055 <<
"' that remained live after conversion";
3057 <<
"see existing live user here: " << *liveUser;
3064 return emitConversionError();
3069 rewriter, op->
getLoc(), resultType, newValue);
3070 if (!convertedValue)
3071 return emitConversionError();
3073 rewriterImpl.
mapping.map(result, convertedValue);
3083 assert(!types.empty() &&
"expected valid types");
3084 remapInput(origInputNo, argTypes.size(), types.size());
3089 assert(!types.empty() &&
3090 "1->0 type remappings don't need to be added explicitly");
3091 argTypes.append(types.begin(), types.end());
3095 unsigned newInputNo,
3096 unsigned newInputCount) {
3097 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3098 assert(newInputCount != 0 &&
"expected valid input count");
3099 remappedInputs[origInputNo] =
3100 InputMapping{newInputNo, newInputCount,
nullptr};
3104 Value replacementValue) {
3105 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3106 remappedInputs[origInputNo] =
3113 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3116 cacheReadLock.lock();
3117 auto existingIt = cachedDirectConversions.find(t);
3118 if (existingIt != cachedDirectConversions.end()) {
3119 if (existingIt->second)
3120 results.push_back(existingIt->second);
3121 return success(existingIt->second !=
nullptr);
3123 auto multiIt = cachedMultiConversions.find(t);
3124 if (multiIt != cachedMultiConversions.end()) {
3125 results.append(multiIt->second.begin(), multiIt->second.end());
3131 size_t currentCount = results.size();
3133 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3136 for (
const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
3137 if (std::optional<LogicalResult> result = converter(t, results)) {
3139 cacheWriteLock.lock();
3141 cachedDirectConversions.try_emplace(t,
nullptr);
3144 auto newTypes =
ArrayRef<Type>(results).drop_front(currentCount);
3145 if (newTypes.size() == 1)
3146 cachedDirectConversions.try_emplace(t, newTypes.front());
3148 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3162 return results.size() == 1 ? results.front() :
nullptr;
3168 for (
Type type : types)
3182 return llvm::all_of(*region, [
this](
Block &block) {
3188 return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
3200 if (convertedTypes.empty())
3204 result.
addInputs(inputNo, convertedTypes);
3210 unsigned origInputOffset)
const {
3211 for (
unsigned i = 0, e = types.size(); i != e; ++i)
3217 Value TypeConverter::materializeConversion(
3220 for (
const MaterializationCallbackFn &fn : llvm::reverse(materializations))
3221 if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
3226 std::optional<TypeConverter::SignatureConversion>
3230 return std::nullopt;
3253 return impl.getInt() == resultTag;
3257 return impl.getInt() == naTag;
3261 return impl.getInt() == abortTag;
3265 assert(hasResult() &&
"Cannot get result from N/A or abort");
3266 return impl.getPointer();
3269 std::optional<Attribute>
3271 for (
const TypeAttributeConversionCallbackFn &fn :
3272 llvm::reverse(typeAttributeConversions)) {
3277 return std::nullopt;
3279 return std::nullopt;
3289 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3299 typeConverter, &result)))
3316 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3324 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3329 struct AnyFunctionOpInterfaceSignatureConversion
3345 assert(op &&
"Invalid op");
3359 return rewriter.
create(newOp);
3365 patterns.
add<FunctionOpInterfaceSignatureConversion>(
3366 functionLikeOpName, patterns.
getContext(), converter);
3371 patterns.
add<AnyFunctionOpInterfaceSignatureConversion>(
3381 legalOperations[op].action = action;
3386 for (StringRef dialect : dialectNames)
3387 legalDialects[dialect] = action;
3391 -> std::optional<LegalizationAction> {
3392 std::optional<LegalizationInfo> info = getOpInfo(op);
3393 return info ? info->action : std::optional<LegalizationAction>();
3397 -> std::optional<LegalOpDetails> {
3398 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3400 return std::nullopt;
3403 auto isOpLegal = [&] {
3405 if (info->action == LegalizationAction::Dynamic) {
3406 std::optional<bool> result = info->legalityFn(op);
3412 return info->action == LegalizationAction::Legal;
3415 return std::nullopt;
3419 if (info->isRecursivelyLegal) {
3420 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
3421 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3423 legalityFnIt->second(op).value_or(
true);
3428 return legalityDetails;
3432 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3436 if (info->action == LegalizationAction::Dynamic) {
3437 std::optional<bool> result = info->legalityFn(op);
3444 return info->action == LegalizationAction::Illegal;
3453 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3455 if (std::optional<bool> result = newCl(op))
3463 void ConversionTarget::setLegalityCallback(
3464 OperationName name,
const DynamicLegalityCallbackFn &callback) {
3465 assert(callback &&
"expected valid legality callback");
3466 auto *infoIt = legalOperations.find(name);
3467 assert(infoIt != legalOperations.end() &&
3468 infoIt->second.action == LegalizationAction::Dynamic &&
3469 "expected operation to already be marked as dynamically legal");
3470 infoIt->second.legalityFn =
3476 auto *infoIt = legalOperations.find(name);
3477 assert(infoIt != legalOperations.end() &&
3478 infoIt->second.action != LegalizationAction::Illegal &&
3479 "expected operation to already be marked as legal");
3480 infoIt->second.isRecursivelyLegal =
true;
3483 std::move(opRecursiveLegalityFns[name]), callback);
3485 opRecursiveLegalityFns.erase(name);
3488 void ConversionTarget::setLegalityCallback(
3490 assert(callback &&
"expected valid legality callback");
3491 for (StringRef dialect : dialects)
3493 std::move(dialectLegalityFns[dialect]), callback);
3496 void ConversionTarget::setLegalityCallback(
3497 const DynamicLegalityCallbackFn &callback) {
3498 assert(callback &&
"expected valid legality callback");
3503 -> std::optional<LegalizationInfo> {
3505 const auto *it = legalOperations.find(op);
3506 if (it != legalOperations.end())
3509 auto dialectIt = legalDialects.find(op.getDialectNamespace());
3510 if (dialectIt != legalDialects.end()) {
3511 DynamicLegalityCallbackFn callback;
3512 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3513 if (dialectFn != dialectLegalityFns.end())
3514 callback = dialectFn->second;
3515 return LegalizationInfo{dialectIt->second,
false,
3519 if (unknownLegalityFn)
3520 return LegalizationInfo{LegalizationAction::Dynamic,
3521 false, unknownLegalityFn};
3522 return std::nullopt;
3525 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3531 auto &rewriterImpl =
3537 auto &rewriterImpl =
3549 return std::move(mappedValues);
3560 return results->front();
3570 auto &rewriterImpl =
3574 if (
Type newType = converter->convertType(type))
3584 auto &rewriterImpl =
3593 return std::move(remappedTypes);
3609 OpConversionMode::Partial);
3627 OpConversionMode::Full);
3644 OpConversionMode::Analysis);
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult legalizeUnresolvedMaterialization(UnresolvedMaterializationRewrite &mat, DenseMap< Operation *, UnresolvedMaterializationRewrite * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Legalize the given unresolved materialization.
static RewriteTy * findSingleRewrite(R &&rewrites, Block *block)
Find the single rewrite object of the specified type and block among the given rewrites.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static Operation * findLiveUserOfReplaced(Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, const DenseMap< Value, SmallVector< Value >> &inverseMapping)
Finds a user of the given value, or of any other value that the given value replaced,...
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static void replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl, ResultRange matResults, ValueRange values, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Replace the results of a materialization operation with the given values.
static void computeNecessaryMaterializations(DenseMap< Operation *, UnresolvedMaterializationRewrite * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping, SetVector< UnresolvedMaterializationRewrite * > &necessaryMaterializations)
Compute all of the unresolved materializations that will persist beyond the conversion process,...
static bool hasRewrite(R &&rewrites, Operation *op)
Return "true" if there is an operation rewrite that matches the specified rewrite type and operation ...
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
OpListType & getOperations()
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to the entry block of the given region.
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
LogicalResult convertNonEntryRegionTypes(Region *region, const TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions)
Convert the types of block arguments within the given region except for the entry region.
~ConversionPatternRewriter() override
Base class for the conversion patterns.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true is operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class provides support for representing a failure result, or a valid value of type T.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Block::iterator getPoint() const
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within 'results'.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
type_range getTypes() const
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
OpResult getOpResult(unsigned idx)
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setSuccessor(Block *block, unsigned index)
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
user_range getUsers()
Returns a range of all users.
result_range getResults()
int getPropertiesStorageSize() const
Returns the properties storage size.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
This class implements the result iterators for the Operation class.
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
MLIRContext * getContext() const
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
The general result of a type attribute conversion callback, allowing for early termination.
Attribute getResult() const
static AttributeConversionResult abort()
static AttributeConversionResult na()
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_iterator user_end() const
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
Detect if any of the given parameter types has a sub-element handler.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
This class represents an efficient way to signal success or failure.
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
DenseSet< void * > erased
Pointers to all erased operations and blocks.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
ConversionValueMapping mapping
Value buildUnresolvedArgumentMaterialization(Block *block, Location loc, ValueRange inputs, Type origOutputType, Type outputType, const TypeConverter *converter)
void applyRewrites()
Apply all requested operation rewrites.
void undoRewrites(unsigned numRewritesToKeep=0)
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
RewriterState getCurrentState()
Return the current state of the rewriter.
void notifyOpReplaced(Operation *op, ValueRange newValues)
Notifies that an op is about to be replaced with the given values.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
LogicalResult convertNonEntryRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions={})
Convert the types of non-entry block arguments within the given region.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter)
Apply a signature conversion on the given region, using converter for materializations if not null.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
void appendRewrite(Args &&...args)
Append a rewrite.
Value buildUnresolvedMaterialization(MaterializationKind kind, Block *insertBlock, Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType, Type origOutputType, const TypeConverter *converter)
Build an unresolved materialization operation given an output type and set of input operands.
FailureOr< Block * > convertBlockSignature(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion *conversion=nullptr)
Attempt to convert the signature of the given block, if successful a new block is returned containing...
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVectorImpl< Value > &remapped)
Remap the given values to those with potentially different types.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
MLIRContext * context
MLIR context.
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
Value buildUnresolvedTargetMaterialization(Location loc, Value input, Type outputType, const TypeConverter *converter)
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.