10 #include "mlir/Config/mlir-config.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/SaveAndRestore.h"
24 #include "llvm/Support/ScopedPrinter.h"
30 #define DEBUG_TYPE "dialect-conversion"
33 template <
typename... Args>
34 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
37 os.startLine() <<
"} -> SUCCESS";
39 os.getOStream() <<
" : "
40 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
41 os.getOStream() <<
"\n";
46 template <
typename... Args>
47 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
50 os.startLine() <<
"} -> FAILURE : "
51 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
63 struct ConversionValueMapping {
68 Value lookupOrDefault(
Value from,
Type desiredType =
nullptr)
const;
78 for (
Value it = newVal; it; it = mapping.lookupOrNull(it))
79 assert(it != oldVal &&
"inserting cyclic mapping");
81 mapping.map(oldVal, newVal);
90 void erase(
Value value) { mapping.erase(value); }
95 for (
auto &it : mapping.getValueMap())
96 inverse[it.second].push_back(it.first);
106 Value ConversionValueMapping::lookupOrDefault(
Value from,
107 Type desiredType)
const {
112 while (
auto mappedValue = mapping.lookupOrNull(from))
120 if (from.
getType() == desiredType)
123 Value mappedValue = mapping.lookupOrNull(from);
130 return desiredValue ? desiredValue : from;
133 Value ConversionValueMapping::lookupOrNull(
Value from,
Type desiredType)
const {
134 Value result = lookupOrDefault(from, desiredType);
135 if (result == from || (desiredType && result.
getType() != desiredType))
140 bool ConversionValueMapping::tryMap(
Value oldVal,
Value newVal) {
141 for (
Value it = newVal; it; it = mapping.lookupOrNull(it))
154 struct RewriterState {
155 RewriterState(
unsigned numRewrites,
unsigned numIgnoredOperations,
156 unsigned numReplacedOps)
157 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
158 numReplacedOps(numReplacedOps) {}
161 unsigned numRewrites;
164 unsigned numIgnoredOperations;
167 unsigned numReplacedOps;
199 UnresolvedMaterialization
202 virtual ~IRRewrite() =
default;
205 virtual void rollback() = 0;
224 Kind getKind()
const {
return kind; }
226 static bool classof(
const IRRewrite *
rewrite) {
return true; }
230 : kind(kind), rewriterImpl(rewriterImpl) {}
239 class BlockRewrite :
public IRRewrite {
242 Block *getBlock()
const {
return block; }
244 static bool classof(
const IRRewrite *
rewrite) {
245 return rewrite->getKind() >= Kind::CreateBlock &&
246 rewrite->getKind() <= Kind::ReplaceBlockArg;
252 : IRRewrite(kind, rewriterImpl), block(block) {}
261 class CreateBlockRewrite :
public BlockRewrite {
264 : BlockRewrite(
Kind::CreateBlock, rewriterImpl, block) {}
266 static bool classof(
const IRRewrite *
rewrite) {
267 return rewrite->getKind() == Kind::CreateBlock;
273 listener->notifyBlockInserted(block, {}, {});
276 void rollback()
override {
279 auto &blockOps = block->getOperations();
280 while (!blockOps.empty())
281 blockOps.remove(blockOps.begin());
282 block->dropAllUses();
283 if (block->getParent())
294 class EraseBlockRewrite :
public BlockRewrite {
298 : BlockRewrite(
Kind::EraseBlock, rewriterImpl, block), region(region),
299 insertBeforeBlock(insertBeforeBlock) {}
301 static bool classof(
const IRRewrite *
rewrite) {
302 return rewrite->getKind() == Kind::EraseBlock;
305 ~EraseBlockRewrite()
override {
307 "rewrite was neither rolled back nor committed/cleaned up");
310 void rollback()
override {
313 assert(block &&
"expected block");
314 auto &blockList = region->getBlocks();
318 blockList.insert(before, block);
324 assert(block &&
"expected block");
325 assert(block->empty() &&
"expected empty block");
329 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
330 listener->notifyBlockErased(block);
335 block->dropAllDefinedValueUses();
346 Block *insertBeforeBlock;
352 class InlineBlockRewrite :
public BlockRewrite {
356 : BlockRewrite(
Kind::InlineBlock, rewriterImpl, block),
357 sourceBlock(sourceBlock),
358 firstInlinedInst(sourceBlock->empty() ? nullptr
359 : &sourceBlock->front()),
360 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
366 assert(!getConfig().listener &&
367 "InlineBlockRewrite not supported if listener is attached");
370 static bool classof(
const IRRewrite *
rewrite) {
371 return rewrite->getKind() == Kind::InlineBlock;
374 void rollback()
override {
377 if (firstInlinedInst) {
378 assert(lastInlinedInst &&
"expected operation");
398 class MoveBlockRewrite :
public BlockRewrite {
402 : BlockRewrite(
Kind::MoveBlock, rewriterImpl, block), region(region),
403 insertBeforeBlock(insertBeforeBlock) {}
405 static bool classof(
const IRRewrite *
rewrite) {
406 return rewrite->getKind() == Kind::MoveBlock;
414 listener->notifyBlockInserted(block, region,
419 void rollback()
override {
432 Block *insertBeforeBlock;
437 struct ConvertedArgInfo {
438 ConvertedArgInfo(
unsigned newArgIdx,
unsigned newArgSize,
439 Value castValue =
nullptr)
440 : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
455 class BlockTypeConversionRewrite :
public BlockRewrite {
457 BlockTypeConversionRewrite(
461 : BlockRewrite(
Kind::BlockTypeConversion, rewriterImpl, block),
462 origBlock(origBlock), argInfo(argInfo), converter(converter) {}
464 static bool classof(
const IRRewrite *
rewrite) {
465 return rewrite->getKind() == Kind::BlockTypeConversion;
476 void rollback()
override;
493 class ReplaceBlockArgRewrite :
public BlockRewrite {
497 : BlockRewrite(
Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
499 static bool classof(
const IRRewrite *
rewrite) {
500 return rewrite->getKind() == Kind::ReplaceBlockArg;
505 void rollback()
override;
512 class OperationRewrite :
public IRRewrite {
515 Operation *getOperation()
const {
return op; }
517 static bool classof(
const IRRewrite *
rewrite) {
518 return rewrite->getKind() >= Kind::MoveOperation &&
519 rewrite->getKind() <= Kind::UnresolvedMaterialization;
525 : IRRewrite(kind, rewriterImpl), op(op) {}
532 class MoveOperationRewrite :
public OperationRewrite {
536 : OperationRewrite(
Kind::MoveOperation, rewriterImpl, op), block(block),
537 insertBeforeOp(insertBeforeOp) {}
539 static bool classof(
const IRRewrite *
rewrite) {
540 return rewrite->getKind() == Kind::MoveOperation;
548 listener->notifyOperationInserted(
554 void rollback()
override {
572 class ModifyOperationRewrite :
public OperationRewrite {
576 : OperationRewrite(
Kind::ModifyOperation, rewriterImpl, op),
577 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
578 operands(op->operand_begin(), op->operand_end()),
579 successors(op->successor_begin(), op->successor_end()) {
584 name.initOpProperties(propCopy, prop);
588 static bool classof(
const IRRewrite *
rewrite) {
589 return rewrite->getKind() == Kind::ModifyOperation;
592 ~ModifyOperationRewrite()
override {
593 assert(!propertiesStorage &&
594 "rewrite was neither committed nor rolled back");
600 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
601 listener->notifyOperationModified(op);
603 if (propertiesStorage) {
607 name.destroyOpProperties(propCopy);
608 operator delete(propertiesStorage);
609 propertiesStorage =
nullptr;
613 void rollback()
override {
619 if (propertiesStorage) {
622 name.destroyOpProperties(propCopy);
623 operator delete(propertiesStorage);
624 propertiesStorage =
nullptr;
631 DictionaryAttr attrs;
634 void *propertiesStorage =
nullptr;
641 class ReplaceOperationRewrite :
public OperationRewrite {
646 : OperationRewrite(
Kind::ReplaceOperation, rewriterImpl, op),
647 converter(converter), changedResults(changedResults) {}
649 static bool classof(
const IRRewrite *
rewrite) {
650 return rewrite->getKind() == Kind::ReplaceOperation;
655 void rollback()
override;
659 const TypeConverter *getConverter()
const {
return converter; }
661 bool hasChangedResults()
const {
return changedResults; }
672 class CreateOperationRewrite :
public OperationRewrite {
676 : OperationRewrite(
Kind::CreateOperation, rewriterImpl, op) {}
678 static bool classof(
const IRRewrite *
rewrite) {
679 return rewrite->getKind() == Kind::CreateOperation;
685 listener->notifyOperationInserted(op, {});
688 void rollback()
override;
692 enum MaterializationKind {
705 class UnresolvedMaterializationRewrite :
public OperationRewrite {
707 UnresolvedMaterializationRewrite(
709 UnrealizedConversionCastOp op,
const TypeConverter *converter =
nullptr,
710 MaterializationKind kind = MaterializationKind::Target,
711 Type origOutputType =
nullptr)
712 : OperationRewrite(
Kind::UnresolvedMaterialization, rewriterImpl, op),
713 converterAndKind(converter, kind), origOutputType(origOutputType) {}
715 static bool classof(
const IRRewrite *
rewrite) {
716 return rewrite->getKind() == Kind::UnresolvedMaterialization;
719 UnrealizedConversionCastOp getOperation()
const {
720 return cast<UnrealizedConversionCastOp>(op);
723 void rollback()
override;
729 return converterAndKind.getPointer();
733 MaterializationKind getMaterializationKind()
const {
734 return converterAndKind.getInt();
738 void setMaterializationKind(MaterializationKind kind) {
739 converterAndKind.setInt(kind);
743 Type getOrigOutputType()
const {
return origOutputType; }
748 llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
758 template <
typename RewriteTy,
typename R>
760 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
761 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
762 return rewriteTy && rewriteTy->getOperation() == op;
769 template <
typename RewriteTy,
typename R>
771 RewriteTy *result =
nullptr;
772 for (
auto &
rewrite : rewrites) {
773 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
774 if (rewriteTy && rewriteTy->getBlock() == block) {
776 assert(!result &&
"expected single matching rewrite");
794 : context(ctx), config(config) {}
801 RewriterState getCurrentState();
805 void applyRewrites();
808 void resetState(RewriterState state);
812 template <
typename RewriteTy,
typename... Args>
815 std::make_unique<RewriteTy>(*
this, std::forward<Args>(args)...));
820 void undoRewrites(
unsigned numRewritesToKeep = 0);
827 std::optional<Location> inputLoc,
875 Block *applySignatureConversion(
885 Value buildUnresolvedMaterialization(MaterializationKind kind,
907 void notifyOperationInserted(
Operation *op,
914 void notifyBlockIsBeingErased(
Block *block);
917 void notifyBlockInserted(
Block *block,
Region *previous,
921 void notifyBlockBeingInlined(
Block *block,
Block *srcBlock,
943 if (erased.contains(op))
951 if (erased.contains(block))
953 assert(block->
empty() &&
"expected empty block");
1009 llvm::ScopedPrinter logger{llvm::dbgs()};
1016 return rewriterImpl.
config;
1019 void BlockTypeConversionRewrite::commit(
RewriterBase &rewriter) {
1023 if (
auto *listener =
1024 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
1026 listener->notifyOperationModified(op);
1029 for (
auto [origArg, info] :
1030 llvm::zip_equal(origBlock->getArguments(), argInfo)) {
1034 rewriterImpl.
mapping.lookupOrNull(origArg, origArg.getType()))
1040 Value castValue = info->castValue;
1041 assert(info->newArgSize >= 1 && castValue &&
"expected 1->1+ mapping");
1044 if (!origArg.use_empty()) {
1046 castValue, origArg.getType()));
1051 void BlockTypeConversionRewrite::rollback() {
1055 LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
1061 OpBuilder builder(it.value().getContext(), &rewriterImpl);
1062 builder.setInsertionPointToStart(block);
1066 if (rewriterImpl.
mapping.lookupOrNull(origArg, origArg.
getType()))
1068 Operation *liveUser = findLiveUser(origArg);
1072 Value replacementValue = rewriterImpl.
mapping.lookupOrDefault(origArg);
1073 bool isDroppedArg = replacementValue == origArg;
1075 builder.setInsertionPointAfterValue(replacementValue);
1078 newArg = converter->materializeSourceConversion(
1082 "materialization hook did not provide a value of the expected "
1088 <<
"failed to materialize conversion for block argument #"
1089 << it.index() <<
" that remained live after conversion, type was "
1092 diag <<
", with target type " << replacementValue.
getType();
1094 <<
"see existing live user here: " << *liveUser;
1097 rewriterImpl.
mapping.map(origArg, newArg);
1102 void ReplaceBlockArgRewrite::commit(
RewriterBase &rewriter) {
1103 Value repl = rewriterImpl.
mapping.lookupOrNull(arg, arg.getType());
1107 if (isa<BlockArgument>(repl)) {
1115 Operation *replOp = cast<OpResult>(repl).getOwner();
1123 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.
mapping.erase(arg); }
1125 void ReplaceOperationRewrite::commit(
RewriterBase &rewriter) {
1127 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener());
1132 return rewriterImpl.mapping.lookupOrNull(result, result.getType());
1137 listener->notifyOperationReplaced(op, replacements);
1140 for (
auto [result, newValue] :
1141 llvm::zip_equal(op->
getResults(), replacements))
1147 if (getConfig().unlegalizedOps)
1148 getConfig().unlegalizedOps->erase(op);
1154 [&](
Operation *op) { listener->notifyOperationErased(op); });
1162 void ReplaceOperationRewrite::rollback() {
1164 rewriterImpl.
mapping.erase(result);
1167 void ReplaceOperationRewrite::cleanup(
RewriterBase &rewriter) {
1171 void CreateOperationRewrite::rollback() {
1173 while (!region.getBlocks().empty())
1174 region.getBlocks().remove(region.getBlocks().begin());
1180 void UnresolvedMaterializationRewrite::rollback() {
1181 if (getMaterializationKind() == MaterializationKind::Target) {
1183 rewriterImpl.
mapping.erase(input);
1188 void UnresolvedMaterializationRewrite::cleanup(
RewriterBase &rewriter) {
1201 rewrite->cleanup(eraseRewriter);
1216 while (
ignoredOps.size() != state.numIgnoredOperations)
1219 while (
replacedOps.size() != state.numReplacedOps)
1225 llvm::reverse(llvm::drop_begin(
rewrites, numRewritesToKeep)))
1227 rewrites.resize(numRewritesToKeep);
1231 StringRef valueDiagTag, std::optional<Location> inputLoc,
1234 remapped.reserve(llvm::size(values));
1238 Value operand = it.value();
1250 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1251 << it.index() <<
", type was " << origType;
1257 if (legalTypes.size() == 1)
1258 desiredType = legalTypes.front();
1267 Value newOperand =
mapping.lookupOrDefault(operand, desiredType);
1277 newOperand = castValue;
1279 remapped.push_back(newOperand);
1319 if (!region->
empty())
1330 if (region->
empty())
1337 rewriter, ®ion->
front(), &converter, entryConversion);
1346 if (region->
empty())
1351 assert((blockConversions.empty() ||
1352 blockConversions.size() == region->
getBlocks().size() - 1) &&
1353 "expected either to provide no SignatureConversions at all or to "
1354 "provide a SignatureConversion for each non-entry block");
1357 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1359 blockConversions.empty()
1362 &blockConversions[blockIdx++]);
1386 for (
unsigned i = 0; i < origArgCount; ++i) {
1388 if (!inputMap || inputMap->replacementValue)
1391 for (
unsigned j = 0;
j < inputMap->size; ++
j)
1392 newLocs[inputMap->inputNo +
j] = origLoc;
1399 convertedTypes, newLocs);
1409 appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->
end());
1412 while (!block->
empty())
1422 argInfo.resize(origArgCount);
1424 for (
unsigned i = 0; i != origArgCount; ++i) {
1432 if (inputMap->replacementValue) {
1433 assert(inputMap->size == 0 &&
1434 "invalid to provide a replacement value when the argument isn't "
1436 mapping.map(origArg, inputMap->replacementValue);
1437 appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1443 newBlock->
getArguments().slice(inputMap->inputNo, inputMap->size);
1455 if (replArgs.size() == 1 &&
1456 (!converter || replArgs[0].getType() == origArg.
getType())) {
1457 newArg = replArgs.front();
1462 Type outputType = origOutputType;
1464 outputType = legalOutputType;
1467 newBlock, origArg.
getLoc(), replArgs, origOutputType, outputType,
1472 appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1473 argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
1476 appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
1497 if (inputs.size() == 1 && inputs.front().
getType() == outputType)
1498 return inputs.front();
1502 OpBuilder builder(insertBlock, insertPt);
1504 builder.
create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1505 appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1507 return convertOp.getResult(0);
1513 block->
begin(), loc, inputs, outputType,
1514 origOutputType, converter);
1521 if (
OpResult inputRes = dyn_cast<OpResult>(input))
1522 insertPt = ++inputRes.getOwner()->getIterator();
1525 insertBlock, insertPt, loc, input,
1526 outputType, outputType, converter);
1535 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"'(" << op
1539 "attempting to insert into a block within a replaced/erased op");
1541 if (!previous.
isSet()) {
1543 appendRewrite<CreateOperationRewrite>(op);
1549 appendRewrite<MoveOperationRewrite>(op, previous.
getBlock(), prevOp);
1555 assert(!
ignoredOps.contains(op) &&
"operation was already replaced");
1558 bool resultChanged =
false;
1561 for (
auto [newValue, result] : llvm::zip(newValues, op->
getResults())) {
1563 resultChanged =
true;
1567 mapping.map(result, newValue);
1568 resultChanged |= (newValue.getType() != result.
getType());
1580 Block *origNextBlock = block->getNextNode();
1581 appendRewrite<EraseBlockRewrite>(block, region, origNextBlock);
1587 "attempting to insert into a region within a replaced/erased op");
1592 logger.startLine() <<
"** Insert Block into : '" << parent->
getName()
1593 <<
"'(" << parent <<
")\n";
1596 <<
"** Insert Block into detached Region (nullptr parent op)'";
1602 appendRewrite<CreateBlockRewrite>(block);
1605 Block *prevBlock = previousIt == previous->
end() ? nullptr : &*previousIt;
1606 appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1611 appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
1618 reasonCallback(
diag);
1619 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
1629 ConversionPatternRewriter::ConversionPatternRewriter(
1633 setListener(
impl.get());
1639 assert(op && newOp &&
"expected non-null op");
1645 "incorrect # of replacement values");
1647 impl->logger.startLine()
1648 <<
"** Replace : '" << op->
getName() <<
"'(" << op <<
")\n";
1650 impl->notifyOpReplaced(op, newValues);
1655 impl->logger.startLine()
1656 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
1659 impl->notifyOpReplaced(op, nullRepls);
1664 "attempting to erase a block within a replaced/erased op");
1674 impl->notifyBlockIsBeingErased(block);
1682 "attempting to apply a signature conversion to a block within a "
1683 "replaced/erased op");
1684 return impl->applySignatureConversion(*
this, region, conversion, converter);
1691 "attempting to apply a signature conversion to a block within a "
1692 "replaced/erased op");
1693 return impl->convertRegionTypes(*
this, region, converter, entryConversion);
1700 "attempting to apply a signature conversion to a block within a "
1701 "replaced/erased op");
1702 return impl->convertNonEntryRegionTypes(*
this, region, converter,
1710 impl->logger.startLine() <<
"** Replace Argument : '" << from
1711 <<
"'(in region of '" << parentOp->
getName()
1714 impl->appendRewrite<ReplaceBlockArgRewrite>(from.
getOwner(), from);
1715 impl->mapping.map(
impl->mapping.lookupOrDefault(from), to);
1720 if (
failed(
impl->remapValues(
"value", std::nullopt, *
this, key,
1723 return remappedValues.front();
1731 return impl->remapValues(
"value", std::nullopt, *
this, keys,
1740 "incorrect # of argument replacement values");
1742 "attempting to inline a block from a replaced/erased op");
1744 "attempting to inline a block into a replaced/erased op");
1745 auto opIgnored = [&](
Operation *op) {
return impl->isOpIgnored(op); };
1748 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
1749 "expected 'source' to have no predecessors");
1758 bool fastPath = !
impl->config.listener;
1761 impl->notifyBlockBeingInlined(dest, source, before);
1764 for (
auto it : llvm::zip(source->
getArguments(), argValues))
1765 replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1772 while (!source->
empty())
1773 moveOpBefore(&source->
front(), dest, before);
1781 assert(!
impl->wasOpReplaced(op) &&
1782 "attempting to modify a replaced/erased op");
1784 impl->pendingRootUpdates.insert(op);
1786 impl->appendRewrite<ModifyOperationRewrite>(op);
1790 assert(!
impl->wasOpReplaced(op) &&
1791 "attempting to modify a replaced/erased op");
1796 assert(
impl->pendingRootUpdates.erase(op) &&
1797 "operation did not have a pending in-place update");
1803 assert(
impl->pendingRootUpdates.erase(op) &&
1804 "operation did not have a pending in-place update");
1807 auto it = llvm::find_if(
1808 llvm::reverse(
impl->rewrites), [&](std::unique_ptr<IRRewrite> &
rewrite) {
1809 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1810 return modifyRewrite && modifyRewrite->getOperation() == op;
1812 assert(it !=
impl->rewrites.rend() &&
"no root update started on op");
1814 int updateIdx = std::prev(
impl->rewrites.rend()) - it;
1815 impl->rewrites.erase(
impl->rewrites.begin() + updateIdx);
1830 auto &rewriterImpl = dialectRewriter.getImpl();
1834 getTypeConverter());
1842 return matchAndRewrite(op, operands, dialectRewriter);
1854 class OperationLegalizer {
1890 RewriterState &curState);
1894 legalizePatternBlockRewrites(
Operation *op,
1897 RewriterState &state, RewriterState &newState);
1900 RewriterState &state, RewriterState &newState);
1903 RewriterState &state,
1904 RewriterState &newState);
1914 void buildLegalizationGraph(
1915 LegalizationPatterns &anyOpLegalizerPatterns,
1926 void computeLegalizationGraphBenefit(
1927 LegalizationPatterns &anyOpLegalizerPatterns,
1932 unsigned computeOpLegalizationDepth(
1939 unsigned applyCostModelToPatterns(
1940 LegalizationPatterns &patterns,
1961 : target(targetInfo), applicator(patterns), config(config) {
1965 LegalizationPatterns anyOpLegalizerPatterns;
1967 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1968 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1971 bool OperationLegalizer::isIllegal(
Operation *op)
const {
1972 return target.isIllegal(op);
1976 OperationLegalizer::legalize(
Operation *op,
1979 const char *logLineComment =
1980 "//===-------------------------------------------===//\n";
1985 logger.getOStream() <<
"\n";
1986 logger.startLine() << logLineComment;
1987 logger.startLine() <<
"Legalizing operation : '" << op->
getName() <<
"'("
1993 op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1994 logger.getOStream() <<
"\n\n";
1999 if (
auto legalityInfo = target.isLegal(op)) {
2002 logger,
"operation marked legal by the target{0}",
2003 legalityInfo->isRecursivelyLegal
2004 ?
"; NOTE: operation is recursively legal; skipping internals"
2006 logger.startLine() << logLineComment;
2011 if (legalityInfo->isRecursivelyLegal) {
2024 logSuccess(logger,
"operation marked 'ignored' during conversion");
2025 logger.startLine() << logLineComment;
2033 if (
succeeded(legalizeWithFold(op, rewriter))) {
2036 logger.startLine() << logLineComment;
2042 if (
succeeded(legalizeWithPattern(op, rewriter))) {
2045 logger.startLine() << logLineComment;
2051 logFailure(logger,
"no matched legalization pattern");
2052 logger.startLine() << logLineComment;
2058 OperationLegalizer::legalizeWithFold(
Operation *op,
2060 auto &rewriterImpl = rewriter.
getImpl();
2064 rewriterImpl.
logger.startLine() <<
"* Fold {\n";
2065 rewriterImpl.
logger.indent();
2077 if (replacementValues.empty())
2078 return legalize(op, rewriter);
2081 rewriter.
replaceOp(op, replacementValues);
2084 for (
unsigned i = curState.numRewrites, e = rewriterImpl.
rewrites.size();
2087 dyn_cast<CreateOperationRewrite>(rewriterImpl.
rewrites[i].get());
2090 if (
failed(legalize(createOp->getOperation(), rewriter))) {
2092 "failed to legalize generated constant '{0}'",
2093 createOp->getOperation()->getName()));
2104 OperationLegalizer::legalizeWithPattern(
Operation *op,
2106 auto &rewriterImpl = rewriter.
getImpl();
2109 auto canApply = [&](
const Pattern &pattern) {
2110 bool canApply = canApplyPattern(op, pattern, rewriter);
2111 if (canApply && config.listener)
2112 config.listener->notifyPatternBegin(pattern, op);
2118 auto onFailure = [&](
const Pattern &pattern) {
2124 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
2130 if (config.listener)
2131 config.listener->notifyPatternEnd(pattern,
failure());
2133 appliedPatterns.erase(&pattern);
2138 auto onSuccess = [&](
const Pattern &pattern) {
2140 auto result = legalizePatternResult(op, pattern, rewriter, curState);
2141 appliedPatterns.erase(&pattern);
2144 if (config.listener)
2145 config.listener->notifyPatternEnd(pattern, result);
2150 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2154 bool OperationLegalizer::canApplyPattern(
Operation *op,
const Pattern &pattern,
2157 auto &os = rewriter.
getImpl().logger;
2158 os.getOStream() <<
"\n";
2159 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
2161 os.getOStream() <<
")' {\n";
2168 !appliedPatterns.insert(&pattern).second) {
2177 OperationLegalizer::legalizePatternResult(
Operation *op,
const Pattern &pattern,
2179 RewriterState &curState) {
2183 assert(
impl.pendingRootUpdates.empty() &&
"dangling root updates");
2185 auto newRewrites = llvm::drop_begin(
impl.rewrites, curState.numRewrites);
2186 auto replacedRoot = [&] {
2187 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2189 auto updatedRootInPlace = [&] {
2190 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2192 assert((replacedRoot() || updatedRootInPlace()) &&
2193 "expected pattern to replace the root operation");
2197 RewriterState newState =
impl.getCurrentState();
2198 if (
failed(legalizePatternBlockRewrites(op, rewriter,
impl, curState,
2200 failed(legalizePatternRootUpdates(rewriter,
impl, curState, newState)) ||
2201 failed(legalizePatternCreatedOperations(rewriter,
impl, curState,
2206 LLVM_DEBUG(
logSuccess(
impl.logger,
"pattern applied successfully"));
2210 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2213 RewriterState &newState) {
2218 for (
int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2219 BlockRewrite *
rewrite = dyn_cast<BlockRewrite>(
impl.rewrites[i].get());
2223 if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2224 ReplaceBlockArgRewrite>(
rewrite))
2233 if (
auto *converter =
impl.regionToConverter.lookup(block->
getParent())) {
2234 if (
failed(
impl.convertBlockSignature(rewriter, block, converter))) {
2235 LLVM_DEBUG(
logFailure(
impl.logger,
"failed to convert types of moved "
2246 if (operationsToIgnore.empty()) {
2247 for (
unsigned i = state.numRewrites, e =
impl.rewrites.size(); i != e;
2250 dyn_cast<CreateOperationRewrite>(
impl.rewrites[i].get());
2253 operationsToIgnore.insert(createOp->getOperation());
2258 if (operationsToIgnore.insert(parentOp).second &&
2259 failed(legalize(parentOp, rewriter))) {
2261 "operation '{0}'({1}) became illegal after rewrite",
2262 parentOp->
getName(), parentOp));
2269 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2271 RewriterState &state, RewriterState &newState) {
2272 for (
int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2273 auto *createOp = dyn_cast<CreateOperationRewrite>(
impl.rewrites[i].get());
2276 Operation *op = createOp->getOperation();
2277 if (
failed(legalize(op, rewriter))) {
2279 "failed to legalize generated operation '{0}'({1})",
2287 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2289 RewriterState &state, RewriterState &newState) {
2290 for (
int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2291 auto *
rewrite = dyn_cast<ModifyOperationRewrite>(
impl.rewrites[i].get());
2295 if (
failed(legalize(op, rewriter))) {
2297 impl.logger,
"failed to legalize operation updated in-place '{0}'",
2308 void OperationLegalizer::buildLegalizationGraph(
2309 LegalizationPatterns &anyOpLegalizerPatterns,
2320 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2321 std::optional<OperationName> root = pattern.
getRootKind();
2327 anyOpLegalizerPatterns.push_back(&pattern);
2332 if (target.getOpAction(*root) == LegalizationAction::Legal)
2337 invalidPatterns[*root].insert(&pattern);
2339 parentOps[op].insert(*root);
2342 patternWorklist.insert(&pattern);
2350 if (!anyOpLegalizerPatterns.empty()) {
2351 for (
const Pattern *pattern : patternWorklist)
2352 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2356 while (!patternWorklist.empty()) {
2357 auto *pattern = patternWorklist.pop_back_val();
2361 std::optional<LegalizationAction> action = target.getOpAction(op);
2362 return !legalizerPatterns.count(op) &&
2363 (!action || action == LegalizationAction::Illegal);
2369 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2370 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
2374 for (
auto op : parentOps[*pattern->
getRootKind()])
2375 patternWorklist.set_union(invalidPatterns[op]);
2379 void OperationLegalizer::computeLegalizationGraphBenefit(
2380 LegalizationPatterns &anyOpLegalizerPatterns,
2386 for (
auto &opIt : legalizerPatterns)
2387 if (!minOpPatternDepth.count(opIt.first))
2388 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2394 if (!anyOpLegalizerPatterns.empty())
2395 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2401 applicator.applyCostModel([&](
const Pattern &pattern) {
2403 if (std::optional<OperationName> rootName = pattern.
getRootKind())
2404 orderedPatternList = legalizerPatterns[*rootName];
2406 orderedPatternList = anyOpLegalizerPatterns;
2409 auto *it = llvm::find(orderedPatternList, &pattern);
2410 if (it == orderedPatternList.end())
2414 return PatternBenefit(std::distance(it, orderedPatternList.end()));
2418 unsigned OperationLegalizer::computeOpLegalizationDepth(
2422 auto depthIt = minOpPatternDepth.find(op);
2423 if (depthIt != minOpPatternDepth.end())
2424 return depthIt->second;
2428 auto opPatternsIt = legalizerPatterns.find(op);
2429 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2438 unsigned minDepth = applyCostModelToPatterns(
2439 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2440 minOpPatternDepth[op] = minDepth;
2444 unsigned OperationLegalizer::applyCostModelToPatterns(
2445 LegalizationPatterns &patterns,
2452 patternsByDepth.reserve(patterns.size());
2453 for (
const Pattern *pattern : patterns) {
2456 unsigned generatedOpDepth = computeOpLegalizationDepth(
2457 generatedOp, minOpPatternDepth, legalizerPatterns);
2458 depth =
std::max(depth, generatedOpDepth + 1);
2460 patternsByDepth.emplace_back(pattern, depth);
2463 minDepth =
std::min(minDepth, depth);
2468 if (patternsByDepth.size() == 1)
2472 std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2473 [](
const std::pair<const Pattern *, unsigned> &lhs,
2474 const std::pair<const Pattern *, unsigned> &rhs) {
2477 if (lhs.second != rhs.second)
2478 return lhs.second < rhs.second;
2481 auto lhsBenefit = lhs.first->getBenefit();
2482 auto rhsBenefit = rhs.first->getBenefit();
2483 return lhsBenefit > rhsBenefit;
2488 for (
auto &patternIt : patternsByDepth)
2489 patterns.push_back(patternIt.first);
2497 enum OpConversionMode {
2520 OpConversionMode mode)
2521 : config(config), opLegalizer(target, patterns, this->config),
2563 OperationLegalizer opLegalizer;
2566 OpConversionMode mode;
2573 if (
failed(opLegalizer.legalize(op, rewriter))) {
2576 if (mode == OpConversionMode::Full)
2578 <<
"failed to legalize operation '" << op->
getName() <<
"'";
2582 if (mode == OpConversionMode::Partial) {
2583 if (opLegalizer.isIllegal(op))
2585 <<
"failed to legalize operation '" << op->
getName()
2586 <<
"' that was explicitly marked illegal";
2590 }
else if (mode == OpConversionMode::Analysis) {
2607 for (
auto *op : ops) {
2610 toConvert.push_back(op);
2613 auto legalityInfo = target.
isLegal(op);
2614 if (legalityInfo && legalityInfo->isRecursivelyLegal)
2624 for (
auto *op : toConvert)
2625 if (
failed(convert(rewriter, op)))
2631 if (
failed(finalize(rewriter)))
2636 if (mode == OpConversionMode::Analysis) {
2646 std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
2648 if (
failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2650 failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2654 for (
unsigned i = 0; i < rewriterImpl.
rewrites.size(); ++i) {
2655 auto *opReplacement =
2656 dyn_cast<ReplaceOperationRewrite>(rewriterImpl.
rewrites[i].get());
2657 if (!opReplacement || !opReplacement->hasChangedResults())
2659 Operation *op = opReplacement->getOperation();
2661 Value newValue = rewriterImpl.
mapping.lookupOrNull(result);
2666 if (
failed(legalizeErasedResult(op, result, rewriterImpl)))
2676 if (!inverseMapping)
2677 inverseMapping = rewriterImpl.
mapping.getInverse();
2681 if (
failed(legalizeChangedResultType(
2682 op, result, newValue, opReplacement->getConverter(), rewriter,
2683 rewriterImpl, *inverseMapping)))
2690 LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2695 auto findLiveUser = [&](
Value val) {
2696 auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](
Operation *user) {
2697 return rewriterImpl.isOpIgnored(user);
2699 return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2702 for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.
rewrites.size());
2705 if (
auto *blockTypeConversionRewrite =
2706 dyn_cast<BlockTypeConversionRewrite>(
rewrite.get()))
2707 if (
failed(blockTypeConversionRewrite->materializeLiveConversions(
2723 for (
auto [matResult, newValue] : llvm::zip(matResults, values)) {
2724 auto inverseMapIt = inverseMapping.find(matResult);
2725 if (inverseMapIt == inverseMapping.end())
2734 for (
Value inverseMapVal : inverseMapIt->second)
2735 if (!rewriterImpl.
mapping.tryMap(inverseMapVal, newValue))
2736 rewriterImpl.
mapping.erase(inverseMapVal);
2744 &materializationOps,
2749 auto isLive = [&](
Value value) {
2751 auto matIt = materializationOps.find(user);
2752 if (matIt != materializationOps.end())
2753 return !necessaryMaterializations.count(matIt->second);
2757 for (
Value inv : inverseMapping.lookup(value))
2758 if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
2768 Value remappedValue = rewriterImpl.
mapping.lookupOrDefault(value, type);
2769 if (remappedValue.
getType() == type && remappedValue != invalidRoot)
2770 return remappedValue;
2775 auto inputCastOp = value.
getDefiningOp<UnrealizedConversionCastOp>();
2776 if (inputCastOp && inputCastOp->getNumOperands() == 1)
2777 return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
2785 auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(
rewrite.get());
2788 materializationOps.try_emplace(mat->getOperation(), mat);
2789 worklist.insert(mat);
2791 while (!worklist.empty()) {
2792 UnresolvedMaterializationRewrite *mat = worklist.pop_back_val();
2793 UnrealizedConversionCastOp op = mat->getOperation();
2796 assert(op->
getNumResults() == 1 &&
"unexpected materialization type");
2804 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
2807 if (castOp->getResultTypes() == inputOperands.
getTypes()) {
2810 necessaryMaterializations.remove(materializationOps.lookup(user));
2816 if (inputOperands.size() == 1) {
2819 Value remappedValue =
2820 lookupRemappedValue(opResult, inputOperands[0], outputType);
2821 if (remappedValue && remappedValue != opResult) {
2824 necessaryMaterializations.remove(mat);
2832 if (llvm::any_of(op->
getOperands(), llvm::IsaPred<BlockArgument>) ||
2833 llvm::any_of(inverseMapping[op->
getResult(0)],
2834 llvm::IsaPred<BlockArgument>)) {
2844 bool isMaterializationLive = isLive(opResult);
2846 isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive);
2847 if (!isMaterializationLive)
2849 if (!necessaryMaterializations.insert(mat))
2853 for (
Value input : inputOperands) {
2854 if (
auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
2855 if (
auto *mat = materializationOps.lookup(parentOp))
2856 worklist.insert(mat);
2865 UnresolvedMaterializationRewrite &mat,
2867 &materializationOps,
2871 auto findLiveUser = [&](
auto &&users) {
2872 auto liveUserIt = llvm::find_if_not(
2874 return liveUserIt == users.end() ? nullptr : *liveUserIt;
2881 Value remappedValue = rewriterImpl.
mapping.lookupOrDefault(value, type);
2882 if (remappedValue.
getType() == type)
2883 return remappedValue;
2887 UnrealizedConversionCastOp op = mat.getOperation();
2899 auto valueCast = value.
getDefiningOp<UnrealizedConversionCastOp>();
2903 auto matIt = materializationOps.find(valueCast);
2904 if (matIt != materializationOps.end())
2906 *matIt->second, materializationOps, rewriter, rewriterImpl,
2914 if (inputOperands.size() == 1) {
2917 Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
2918 if (remappedValue && remappedValue != opResult) {
2931 if (inputOperands.size() == 1)
2936 Value newMaterialization;
2937 switch (mat.getMaterializationKind()) {
2947 newMaterialization = converter->materializeArgumentConversion(
2948 rewriter, op->
getLoc(), mat.getOrigOutputType(), inputOperands);
2949 if (newMaterialization)
2955 case MaterializationKind::Target:
2956 newMaterialization = converter->materializeTargetConversion(
2957 rewriter, op->
getLoc(), outputType, inputOperands);
2960 if (newMaterialization) {
2968 <<
"failed to legalize unresolved materialization "
2970 << inputOperands.
getTypes() <<
" to " << outputType
2971 <<
" that remained live after conversion";
2974 <<
"see existing live user here: " << *liveUser;
2979 LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
2983 inverseMapping = rewriterImpl.
mapping.getInverse();
2990 *inverseMapping, necessaryMaterializations);
2993 for (
auto *mat : necessaryMaterializations) {
2995 *mat, materializationOps, rewriter, rewriterImpl, *inverseMapping)))
3007 return rewriterImpl.isOpIgnored(user);
3009 if (liveUserIt != result.
user_end()) {
3011 << op->
getName() <<
"' marked as erased";
3012 diag.attachNote(liveUserIt->getLoc())
3026 while (!worklist.empty()) {
3027 Value value = worklist.pop_back_val();
3032 return rewriterImpl.isOpIgnored(user);
3034 if (liveUserIt != value.
user_end())
3036 auto mapIt = inverseMapping.find(value);
3037 if (mapIt != inverseMapping.end())
3038 worklist.append(mapIt->second);
3043 LogicalResult OperationConverter::legalizeChangedResultType(
3054 auto emitConversionError = [&] {
3056 <<
"failed to materialize conversion for result #"
3059 <<
"' that remained live after conversion";
3061 <<
"see existing live user here: " << *liveUser;
3068 return emitConversionError();
3073 rewriter, op->
getLoc(), resultType, newValue);
3074 if (!convertedValue)
3075 return emitConversionError();
3077 rewriterImpl.
mapping.map(result, convertedValue);
3087 assert(!types.empty() &&
"expected valid types");
3088 remapInput(origInputNo, argTypes.size(), types.size());
3093 assert(!types.empty() &&
3094 "1->0 type remappings don't need to be added explicitly");
3095 argTypes.append(types.begin(), types.end());
3099 unsigned newInputNo,
3100 unsigned newInputCount) {
3101 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3102 assert(newInputCount != 0 &&
"expected valid input count");
3103 remappedInputs[origInputNo] =
3104 InputMapping{newInputNo, newInputCount,
nullptr};
3108 Value replacementValue) {
3109 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3110 remappedInputs[origInputNo] =
3117 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3120 cacheReadLock.lock();
3121 auto existingIt = cachedDirectConversions.find(t);
3122 if (existingIt != cachedDirectConversions.end()) {
3123 if (existingIt->second)
3124 results.push_back(existingIt->second);
3125 return success(existingIt->second !=
nullptr);
3127 auto multiIt = cachedMultiConversions.find(t);
3128 if (multiIt != cachedMultiConversions.end()) {
3129 results.append(multiIt->second.begin(), multiIt->second.end());
3135 size_t currentCount = results.size();
3137 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3140 for (
const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
3141 if (std::optional<LogicalResult> result = converter(t, results)) {
3143 cacheWriteLock.lock();
3145 cachedDirectConversions.try_emplace(t,
nullptr);
3148 auto newTypes =
ArrayRef<Type>(results).drop_front(currentCount);
3149 if (newTypes.size() == 1)
3150 cachedDirectConversions.try_emplace(t, newTypes.front());
3152 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3166 return results.size() == 1 ? results.front() :
nullptr;
3172 for (
Type type : types)
3186 return llvm::all_of(*region, [
this](
Block &block) {
3192 return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
3204 if (convertedTypes.empty())
3208 result.
addInputs(inputNo, convertedTypes);
3214 unsigned origInputOffset)
const {
3215 for (
unsigned i = 0, e = types.size(); i != e; ++i)
3221 Value TypeConverter::materializeConversion(
3224 for (
const MaterializationCallbackFn &fn : llvm::reverse(materializations))
3225 if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
3230 std::optional<TypeConverter::SignatureConversion>
3234 return std::nullopt;
3257 return impl.getInt() == resultTag;
3261 return impl.getInt() == naTag;
3265 return impl.getInt() == abortTag;
3269 assert(hasResult() &&
"Cannot get result from N/A or abort");
3270 return impl.getPointer();
3273 std::optional<Attribute>
3275 for (
const TypeAttributeConversionCallbackFn &fn :
3276 llvm::reverse(typeAttributeConversions)) {
3281 return std::nullopt;
3283 return std::nullopt;
3293 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3303 typeConverter, &result)))
3320 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3328 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3333 struct AnyFunctionOpInterfaceSignatureConversion
3349 assert(op &&
"Invalid op");
3363 return rewriter.
create(newOp);
3369 patterns.
add<FunctionOpInterfaceSignatureConversion>(
3370 functionLikeOpName, patterns.
getContext(), converter);
3375 patterns.
add<AnyFunctionOpInterfaceSignatureConversion>(
3385 legalOperations[op].action = action;
3390 for (StringRef dialect : dialectNames)
3391 legalDialects[dialect] = action;
3395 -> std::optional<LegalizationAction> {
3396 std::optional<LegalizationInfo> info = getOpInfo(op);
3397 return info ? info->action : std::optional<LegalizationAction>();
3401 -> std::optional<LegalOpDetails> {
3402 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3404 return std::nullopt;
3407 auto isOpLegal = [&] {
3409 if (info->action == LegalizationAction::Dynamic) {
3410 std::optional<bool> result = info->legalityFn(op);
3416 return info->action == LegalizationAction::Legal;
3419 return std::nullopt;
3423 if (info->isRecursivelyLegal) {
3424 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
3425 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3427 legalityFnIt->second(op).value_or(
true);
3432 return legalityDetails;
3436 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3440 if (info->action == LegalizationAction::Dynamic) {
3441 std::optional<bool> result = info->legalityFn(op);
3448 return info->action == LegalizationAction::Illegal;
3457 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3459 if (std::optional<bool> result = newCl(op))
3467 void ConversionTarget::setLegalityCallback(
3468 OperationName name,
const DynamicLegalityCallbackFn &callback) {
3469 assert(callback &&
"expected valid legality callback");
3470 auto *infoIt = legalOperations.find(name);
3471 assert(infoIt != legalOperations.end() &&
3472 infoIt->second.action == LegalizationAction::Dynamic &&
3473 "expected operation to already be marked as dynamically legal");
3474 infoIt->second.legalityFn =
3480 auto *infoIt = legalOperations.find(name);
3481 assert(infoIt != legalOperations.end() &&
3482 infoIt->second.action != LegalizationAction::Illegal &&
3483 "expected operation to already be marked as legal");
3484 infoIt->second.isRecursivelyLegal =
true;
3487 std::move(opRecursiveLegalityFns[name]), callback);
3489 opRecursiveLegalityFns.erase(name);
3492 void ConversionTarget::setLegalityCallback(
3494 assert(callback &&
"expected valid legality callback");
3495 for (StringRef dialect : dialects)
3497 std::move(dialectLegalityFns[dialect]), callback);
3500 void ConversionTarget::setLegalityCallback(
3501 const DynamicLegalityCallbackFn &callback) {
3502 assert(callback &&
"expected valid legality callback");
3507 -> std::optional<LegalizationInfo> {
3509 const auto *it = legalOperations.find(op);
3510 if (it != legalOperations.end())
3513 auto dialectIt = legalDialects.find(op.getDialectNamespace());
3514 if (dialectIt != legalDialects.end()) {
3515 DynamicLegalityCallbackFn callback;
3516 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3517 if (dialectFn != dialectLegalityFns.end())
3518 callback = dialectFn->second;
3519 return LegalizationInfo{dialectIt->second,
false,
3523 if (unknownLegalityFn)
3524 return LegalizationInfo{LegalizationAction::Dynamic,
3525 false, unknownLegalityFn};
3526 return std::nullopt;
3529 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3535 auto &rewriterImpl =
3541 auto &rewriterImpl =
3553 return std::move(mappedValues);
3564 return results->front();
3574 auto &rewriterImpl =
3578 if (
Type newType = converter->convertType(type))
3588 auto &rewriterImpl =
3597 return std::move(remappedTypes);
3613 OpConversionMode::Partial);
3631 OpConversionMode::Full);
3648 OpConversionMode::Analysis);
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult legalizeUnresolvedMaterialization(UnresolvedMaterializationRewrite &mat, DenseMap< Operation *, UnresolvedMaterializationRewrite * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Legalize the given unresolved materialization.
static RewriteTy * findSingleRewrite(R &&rewrites, Block *block)
Find the single rewrite object of the specified type and block among the given rewrites.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static Operation * findLiveUserOfReplaced(Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, const DenseMap< Value, SmallVector< Value >> &inverseMapping)
Finds a user of the given value, or of any other value that the given value replaced,...
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static void replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl, ResultRange matResults, ValueRange values, DenseMap< Value, SmallVector< Value >> &inverseMapping)
Replace the results of a materialization operation with the given values.
static void computeNecessaryMaterializations(DenseMap< Operation *, UnresolvedMaterializationRewrite * > &materializationOps, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, DenseMap< Value, SmallVector< Value >> &inverseMapping, SetVector< UnresolvedMaterializationRewrite * > &necessaryMaterializations)
Compute all of the unresolved materializations that will persist beyond the conversion process,...
static bool hasRewrite(R &&rewrites, Operation *op)
Return "true" if there is an operation rewrite that matches the specified rewrite type and operation ...
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
OpListType & getOperations()
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to the entry block of the given region.
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
LogicalResult convertNonEntryRegionTypes(Region *region, const TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions)
Convert the types of block arguments within the given region except for the entry region.
~ConversionPatternRewriter() override
Base class for the conversion patterns.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true is operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class provides support for representing a failure result, or a valid value of type T.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Block::iterator getPoint() const
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within results.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
type_range getTypes() const
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
OpResult getOpResult(unsigned idx)
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setSuccessor(Block *block, unsigned index)
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
user_range getUsers()
Returns a range of all users.
result_range getResults()
int getPropertiesStorageSize() const
Returns the properties storage size.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
This class implements the result iterators for the Operation class.
std::enable_if_t<!std::is_convertible< ValuesT, Operation * >::value > replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this range with the provided 'values'.
MLIRContext * getContext() const
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
The general result of a type attribute conversion callback, allowing for early termination.
Attribute getResult() const
static AttributeConversionResult abort()
static AttributeConversionResult na()
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_iterator user_end() const
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
Detect if any of the given parameter types has a sub-element handler.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
This class represents an efficient way to signal success or failure.
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
DenseSet< void * > erased
Pointers to all erased operations and blocks.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
ConversionValueMapping mapping
Value buildUnresolvedArgumentMaterialization(Block *block, Location loc, ValueRange inputs, Type origOutputType, Type outputType, const TypeConverter *converter)
void applyRewrites()
Apply all requested operation rewrites.
void undoRewrites(unsigned numRewritesToKeep=0)
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
RewriterState getCurrentState()
Return the current state of the rewriter.
void notifyOpReplaced(Operation *op, ValueRange newValues)
Notifies that an op is about to be replaced with the given values.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
LogicalResult convertNonEntryRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions={})
Convert the types of non-entry block arguments within the given region.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter)
Apply a signature conversion on the given region, using converter for materializations if not null.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
void appendRewrite(Args &&...args)
Append a rewrite.
Value buildUnresolvedMaterialization(MaterializationKind kind, Block *insertBlock, Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType, Type origOutputType, const TypeConverter *converter)
Build an unresolved materialization operation given an output type and set of input operands.
FailureOr< Block * > convertBlockSignature(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion *conversion=nullptr)
Attempt to convert the signature of the given block, if successful a new block is returned containing...
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVectorImpl< Value > &remapped)
Remap the given values to those with potentially different types.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
MLIRContext * context
MLIR context.
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
Value buildUnresolvedTargetMaterialization(Location loc, Value input, Type outputType, const TypeConverter *converter)
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.