10 #include "mlir/Config/mlir-config.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/SaveAndRestore.h"
24 #include "llvm/Support/ScopedPrinter.h"
30 #define DEBUG_TYPE "dialect-conversion"
33 template <
typename... Args>
34 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
37 os.startLine() <<
"} -> SUCCESS";
39 os.getOStream() <<
" : "
40 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
41 os.getOStream() <<
"\n";
46 template <
typename... Args>
47 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
50 os.startLine() <<
"} -> FAILURE : "
51 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
61 if (
OpResult inputRes = dyn_cast<OpResult>(value))
62 insertPt = ++inputRes.getOwner()->getIterator();
73 struct ConversionValueMapping {
// Look up the value that `from` is mapped to, optionally requiring the result
// to have `desiredType`. Returns `from` itself when no suitable mapped value
// is found.
84 Value lookupOrDefault(
Value from,
Type desiredType =
nullptr)
const;
94 for (
Value it = newVal; it; it = mapping.lookupOrNull(it))
95 assert(it != oldVal &&
"inserting cyclic mapping");
97 mapping.map(oldVal, newVal);
// Remove the entry for `value` from the underlying `mapping`.
106 void erase(
Value value) { mapping.erase(value); }
111 for (
auto &it : mapping.getValueMap())
112 inverse[it.second].push_back(it.first);
122 Value ConversionValueMapping::lookupOrDefault(
Value from,
123 Type desiredType)
const {
128 if (!desiredType || from.
getType() == desiredType)
131 Value mappedValue = mapping.lookupOrNull(from);
138 return desiredValue ? desiredValue : from;
141 Value ConversionValueMapping::lookupOrNull(
Value from,
Type desiredType)
const {
142 Value result = lookupOrDefault(from, desiredType);
143 if (result == from || (desiredType && result.
getType() != desiredType))
148 bool ConversionValueMapping::tryMap(
Value oldVal,
Value newVal) {
149 for (
Value it = newVal; it; it = mapping.lookupOrNull(it))
162 struct RewriterState {
163 RewriterState(
unsigned numRewrites,
unsigned numIgnoredOperations,
164 unsigned numReplacedOps)
165 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
166 numReplacedOps(numReplacedOps) {}
169 unsigned numRewrites;
172 unsigned numIgnoredOperations;
175 unsigned numReplacedOps;
207 UnresolvedMaterialization
// Virtual destructor: rewrites are owned and destroyed through
// std::unique_ptr<IRRewrite>, i.e. via a base-class pointer.
210 virtual ~IRRewrite() =
default;
// Undo this rewrite. Pure virtual: every concrete rewrite kind provides its
// own implementation.
213 virtual void rollback() = 0;
// Return the kind of this rewrite. The classof() hooks of the subclasses
// compare against this value to implement LLVM-style isa/dyn_cast.
232 Kind getKind()
const {
return kind; }
// Every rewrite is an IRRewrite, so this RTTI check always succeeds.
234 static bool classof(
const IRRewrite *
rewrite) {
return true; }
238 : kind(kind), rewriterImpl(rewriterImpl) {}
247 class BlockRewrite :
public IRRewrite {
// Return the block this rewrite was recorded for.
250 Block *getBlock()
const {
return block; }
252 static bool classof(
const IRRewrite *
rewrite) {
253 return rewrite->getKind() >= Kind::CreateBlock &&
254 rewrite->getKind() <= Kind::ReplaceBlockArg;
260 : IRRewrite(kind, rewriterImpl), block(block) {}
269 class CreateBlockRewrite :
public BlockRewrite {
272 : BlockRewrite(
Kind::CreateBlock, rewriterImpl, block) {}
274 static bool classof(
const IRRewrite *
rewrite) {
275 return rewrite->getKind() == Kind::CreateBlock;
281 listener->notifyBlockInserted(block, {}, {});
284 void rollback()
override {
287 auto &blockOps = block->getOperations();
288 while (!blockOps.empty())
289 blockOps.remove(blockOps.begin());
290 block->dropAllUses();
291 if (block->getParent())
302 class EraseBlockRewrite :
public BlockRewrite {
305 : BlockRewrite(
Kind::EraseBlock, rewriterImpl, block),
306 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
308 static bool classof(
const IRRewrite *
rewrite) {
309 return rewrite->getKind() == Kind::EraseBlock;
312 ~EraseBlockRewrite()
override {
314 "rewrite was neither rolled back nor committed/cleaned up");
317 void rollback()
override {
320 assert(block &&
"expected block");
321 auto &blockList = region->getBlocks();
325 blockList.insert(before, block);
331 assert(block &&
"expected block");
332 assert(block->empty() &&
"expected empty block");
336 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
337 listener->notifyBlockErased(block);
342 block->dropAllDefinedValueUses();
353 Block *insertBeforeBlock;
359 class InlineBlockRewrite :
public BlockRewrite {
363 : BlockRewrite(
Kind::InlineBlock, rewriterImpl, block),
364 sourceBlock(sourceBlock),
365 firstInlinedInst(sourceBlock->empty() ? nullptr
366 : &sourceBlock->front()),
367 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
373 assert(!getConfig().listener &&
374 "InlineBlockRewrite not supported if listener is attached");
377 static bool classof(
const IRRewrite *
rewrite) {
378 return rewrite->getKind() == Kind::InlineBlock;
381 void rollback()
override {
384 if (firstInlinedInst) {
385 assert(lastInlinedInst &&
"expected operation");
405 class MoveBlockRewrite :
public BlockRewrite {
409 : BlockRewrite(
Kind::MoveBlock, rewriterImpl, block), region(region),
410 insertBeforeBlock(insertBeforeBlock) {}
412 static bool classof(
const IRRewrite *
rewrite) {
413 return rewrite->getKind() == Kind::MoveBlock;
421 listener->notifyBlockInserted(block, region,
426 void rollback()
override {
439 Block *insertBeforeBlock;
443 class BlockTypeConversionRewrite :
public BlockRewrite {
448 : BlockRewrite(
Kind::BlockTypeConversion, rewriterImpl, block),
449 origBlock(origBlock), converter(converter) {}
451 static bool classof(
const IRRewrite *
rewrite) {
452 return rewrite->getKind() == Kind::BlockTypeConversion;
// Return the original block stored when this block type conversion was
// recorded.
455 Block *getOrigBlock()
const {
return origBlock; }
// Return the type converter stored with this rewrite.
// NOTE(review): may be null if no converter was supplied — confirm at callers.
457 const TypeConverter *getConverter()
const {
return converter; }
461 void rollback()
override;
474 class ReplaceBlockArgRewrite :
public BlockRewrite {
478 : BlockRewrite(
Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
480 static bool classof(
const IRRewrite *
rewrite) {
481 return rewrite->getKind() == Kind::ReplaceBlockArg;
486 void rollback()
override;
493 class OperationRewrite :
public IRRewrite {
// Return the operation this rewrite was recorded for.
496 Operation *getOperation()
const {
return op; }
498 static bool classof(
const IRRewrite *
rewrite) {
499 return rewrite->getKind() >= Kind::MoveOperation &&
500 rewrite->getKind() <= Kind::UnresolvedMaterialization;
506 : IRRewrite(kind, rewriterImpl), op(op) {}
513 class MoveOperationRewrite :
public OperationRewrite {
517 : OperationRewrite(
Kind::MoveOperation, rewriterImpl, op), block(block),
518 insertBeforeOp(insertBeforeOp) {}
520 static bool classof(
const IRRewrite *
rewrite) {
521 return rewrite->getKind() == Kind::MoveOperation;
529 listener->notifyOperationInserted(
535 void rollback()
override {
553 class ModifyOperationRewrite :
public OperationRewrite {
557 : OperationRewrite(
Kind::ModifyOperation, rewriterImpl, op),
558 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
559 operands(op->operand_begin(), op->operand_end()),
560 successors(op->successor_begin(), op->successor_end()) {
565 name.initOpProperties(propCopy, prop);
569 static bool classof(
const IRRewrite *
rewrite) {
570 return rewrite->getKind() == Kind::ModifyOperation;
573 ~ModifyOperationRewrite()
override {
574 assert(!propertiesStorage &&
575 "rewrite was neither committed nor rolled back");
581 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
582 listener->notifyOperationModified(op);
584 if (propertiesStorage) {
588 name.destroyOpProperties(propCopy);
589 operator delete(propertiesStorage);
590 propertiesStorage =
nullptr;
594 void rollback()
override {
600 if (propertiesStorage) {
603 name.destroyOpProperties(propCopy);
604 operator delete(propertiesStorage);
605 propertiesStorage =
nullptr;
612 DictionaryAttr attrs;
615 void *propertiesStorage =
nullptr;
622 class ReplaceOperationRewrite :
public OperationRewrite {
626 : OperationRewrite(
Kind::ReplaceOperation, rewriterImpl, op),
627 converter(converter) {}
629 static bool classof(
const IRRewrite *
rewrite) {
630 return rewrite->getKind() == Kind::ReplaceOperation;
635 void rollback()
override;
// Return the type converter stored with this operation replacement.
// NOTE(review): may be null if no converter was supplied — confirm at callers.
639 const TypeConverter *getConverter()
const {
return converter; }
647 class CreateOperationRewrite :
public OperationRewrite {
651 : OperationRewrite(
Kind::CreateOperation, rewriterImpl, op) {}
653 static bool classof(
const IRRewrite *
rewrite) {
654 return rewrite->getKind() == Kind::CreateOperation;
660 listener->notifyOperationInserted(op, {});
663 void rollback()
override;
667 enum MaterializationKind {
684 class UnresolvedMaterializationRewrite :
public OperationRewrite {
687 UnrealizedConversionCastOp op,
689 MaterializationKind kind,
Type originalType);
691 static bool classof(
const IRRewrite *
rewrite) {
692 return rewrite->getKind() == Kind::UnresolvedMaterialization;
695 void rollback()
override;
697 UnrealizedConversionCastOp getOperation()
const {
698 return cast<UnrealizedConversionCastOp>(op);
703 return converterAndKind.getPointer();
707 MaterializationKind getMaterializationKind()
const {
708 return converterAndKind.getInt();
// Return the original type recorded for this materialization. The constructor
// asserts that it is only set for target materializations.
712 Type getOriginalType()
const {
return originalType; }
717 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
728 template <
typename RewriteTy,
typename R>
730 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
731 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
732 return rewriteTy && rewriteTy->getOperation() == op;
744 : context(ctx), eraseRewriter(ctx), config(config) {}
751 RewriterState getCurrentState();
755 void applyRewrites();
758 void resetState(RewriterState state);
762 template <
typename RewriteTy,
typename... Args>
765 std::make_unique<RewriteTy>(*
this, std::forward<Args>(args)...));
// Roll back recorded rewrites in reverse order, keeping only the first
// `numRewritesToKeep` entries (0 by default, i.e. undo everything).
770 void undoRewrites(
unsigned numRewritesToKeep = 0);
776 LogicalResult remapValues(StringRef valueDiagTag,
777 std::optional<Location> inputLoc,
804 Block *applySignatureConversion(
815 Value buildUnresolvedMaterialization(MaterializationKind kind,
826 void notifyOperationInserted(
Operation *op,
833 void notifyBlockIsBeingErased(
Block *block);
836 void notifyBlockInserted(
Block *block,
Region *previous,
840 void notifyBlockBeingInlined(
Block *block,
Block *srcBlock,
870 if (wasErased(block))
872 assert(block->
empty() &&
"expected empty block");
// Return true if `ptr` is contained in the `erased` set.
877 bool wasErased(
void *ptr)
const {
return erased.contains(ptr); }
941 llvm::ScopedPrinter logger{llvm::dbgs()};
948 return rewriterImpl.
config;
951 void BlockTypeConversionRewrite::commit(
RewriterBase &rewriter) {
956 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
958 listener->notifyOperationModified(op);
961 void BlockTypeConversionRewrite::rollback() {
965 void ReplaceBlockArgRewrite::commit(
RewriterBase &rewriter) {
966 Value repl = rewriterImpl.
mapping.lookupOrNull(arg, arg.getType());
970 if (isa<BlockArgument>(repl)) {
978 Operation *replOp = cast<OpResult>(repl).getOwner();
// Rollback: drop the mapping entry recorded for the replaced block argument.
986 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.
mapping.erase(arg); }
988 void ReplaceOperationRewrite::commit(
RewriterBase &rewriter) {
990 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener());
995 return rewriterImpl.mapping.lookupOrNull(result, result.getType());
1000 listener->notifyOperationReplaced(op, replacements);
1003 for (
auto [result, newValue] :
1004 llvm::zip_equal(op->
getResults(), replacements))
1010 if (getConfig().unlegalizedOps)
1011 getConfig().unlegalizedOps->erase(op);
1017 [&](
Operation *op) { listener->notifyOperationErased(op); });
1025 void ReplaceOperationRewrite::rollback() {
1027 rewriterImpl.
mapping.erase(result);
1030 void ReplaceOperationRewrite::cleanup(
RewriterBase &rewriter) {
1034 void CreateOperationRewrite::rollback() {
1036 while (!region.getBlocks().empty())
1037 region.getBlocks().remove(region.getBlocks().begin());
1043 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1046 : OperationRewrite(
Kind::UnresolvedMaterialization, rewriterImpl, op),
1047 converterAndKind(converter, kind), originalType(originalType) {
1048 assert((!originalType || kind == MaterializationKind::Target) &&
1049 "original type is valid only for target materializations");
1053 void UnresolvedMaterializationRewrite::rollback() {
1054 if (getMaterializationKind() == MaterializationKind::Target) {
1056 rewriterImpl.
mapping.erase(input);
1085 while (
ignoredOps.size() != state.numIgnoredOperations)
1088 while (
replacedOps.size() != state.numReplacedOps)
1094 llvm::reverse(llvm::drop_begin(
rewrites, numRewritesToKeep)))
1096 rewrites.resize(numRewritesToKeep);
1100 StringRef valueDiagTag, std::optional<Location> inputLoc,
1103 remapped.reserve(llvm::size(values));
1106 Value operand = it.value();
1114 remapped.push_back(
mapping.lookupOrDefault(operand));
1122 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1123 << it.index() <<
", type was " << origType;
1128 if (legalTypes.size() != 1) {
1136 remapped.push_back(
mapping.lookupOrDefault(operand));
1141 Type desiredType = legalTypes.front();
1144 Value newOperand =
mapping.lookupOrDefault(operand, desiredType);
1145 if (newOperand.
getType() != desiredType) {
1151 operandLoc, newOperand, desiredType,
1153 mapping.map(newOperand, castValue);
1154 newOperand = castValue;
1156 remapped.push_back(newOperand);
1179 if (region->
empty())
1184 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1186 std::optional<TypeConverter::SignatureConversion> conversion =
1196 if (entryConversion)
1199 std::optional<TypeConverter::SignatureConversion> conversion =
1222 for (
unsigned i = 0; i < origArgCount; ++i) {
1224 if (!inputMap || inputMap->replacementValue)
1227 for (
unsigned j = 0;
j < inputMap->size; ++
j)
1228 newLocs[inputMap->inputNo +
j] = origLoc;
1235 convertedTypes, newLocs);
1245 appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->
end());
1248 while (!block->
empty())
1255 for (
unsigned i = 0; i != origArgCount; ++i) {
1259 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1265 MaterializationKind::Source,
1268 origArgType,
Type(), converter);
1270 appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1274 if (
Value repl = inputMap->replacementValue) {
1276 assert(inputMap->size == 0 &&
1277 "invalid to provide a replacement value when the argument isn't "
1280 appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1289 newBlock->
getArguments().slice(inputMap->inputNo, inputMap->size);
1293 replArgs, origArgType,
1297 Type legalOutputType;
1299 legalOutputType = converter->
convertType(origArgType);
1300 }
else if (replArgs.size() == 1) {
1308 legalOutputType = replArgs[0].getType();
1310 if (legalOutputType && legalOutputType != origArgType) {
1313 origArg.
getLoc(), argMat, legalOutputType,
1314 origArgType, converter);
1315 mapping.map(argMat, targetMat);
1317 appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1320 appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
1339 assert((!originalType || kind == MaterializationKind::Target) &&
1340 "original type is valid only for target materializations");
1343 if (inputs.size() == 1 && inputs.front().
getType() == outputType)
1344 return inputs.front();
1351 builder.
create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1352 appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1354 return convertOp.getResult(0);
1363 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"'(" << op
1367 "attempting to insert into a block within a replaced/erased op");
1369 if (!previous.
isSet()) {
1371 appendRewrite<CreateOperationRewrite>(op);
1377 appendRewrite<MoveOperationRewrite>(op, previous.
getBlock(), prevOp);
1383 assert(!
ignoredOps.contains(op) &&
"operation was already replaced");
1386 for (
auto [newValue, result] : llvm::zip(newValues, op->
getResults())) {
1389 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1406 mapping.map(result, newValue);
1416 appendRewrite<EraseBlockRewrite>(block);
1422 "attempting to insert into a region within a replaced/erased op");
1427 logger.startLine() <<
"** Insert Block into : '" << parent->
getName()
1428 <<
"'(" << parent <<
")\n";
1431 <<
"** Insert Block into detached Region (nullptr parent op)'";
1437 appendRewrite<CreateBlockRewrite>(block);
1440 Block *prevBlock = previousIt == previous->
end() ? nullptr : &*previousIt;
1441 appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1446 appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
1453 reasonCallback(
diag);
1454 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
1464 ConversionPatternRewriter::ConversionPatternRewriter(
1468 setListener(
impl.get());
1474 assert(op && newOp &&
"expected non-null op");
1480 "incorrect # of replacement values");
1482 impl->logger.startLine()
1483 <<
"** Replace : '" << op->
getName() <<
"'(" << op <<
")\n";
1485 impl->notifyOpReplaced(op, newValues);
1490 impl->logger.startLine()
1491 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
1494 impl->notifyOpReplaced(op, nullRepls);
1499 "attempting to erase a block within a replaced/erased op");
1509 impl->notifyBlockIsBeingErased(block);
1517 "attempting to apply a signature conversion to a block within a "
1518 "replaced/erased op");
1519 return impl->applySignatureConversion(*
this, block, converter, conversion);
1526 "attempting to apply a signature conversion to a block within a "
1527 "replaced/erased op");
1528 return impl->convertRegionTypes(*
this, region, converter, entryConversion);
1535 impl->logger.startLine() <<
"** Replace Argument : '" << from
1536 <<
"'(in region of '" << parentOp->
getName()
1539 impl->appendRewrite<ReplaceBlockArgRewrite>(from.
getOwner(), from);
1540 impl->mapping.map(
impl->mapping.lookupOrDefault(from), to);
1545 if (failed(
impl->remapValues(
"value", std::nullopt, *
this, key,
1548 return remappedValues.front();
1556 return impl->remapValues(
"value", std::nullopt, *
this, keys,
1565 "incorrect # of argument replacement values");
1567 "attempting to inline a block from a replaced/erased op");
1569 "attempting to inline a block into a replaced/erased op");
1570 auto opIgnored = [&](
Operation *op) {
return impl->isOpIgnored(op); };
1573 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
1574 "expected 'source' to have no predecessors");
1583 bool fastPath = !
impl->config.listener;
1586 impl->notifyBlockBeingInlined(dest, source, before);
1589 for (
auto it : llvm::zip(source->
getArguments(), argValues))
1590 replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1597 while (!source->
empty())
1598 moveOpBefore(&source->
front(), dest, before);
1606 assert(!
impl->wasOpReplaced(op) &&
1607 "attempting to modify a replaced/erased op");
1609 impl->pendingRootUpdates.insert(op);
1611 impl->appendRewrite<ModifyOperationRewrite>(op);
1615 assert(!
impl->wasOpReplaced(op) &&
1616 "attempting to modify a replaced/erased op");
1621 assert(
impl->pendingRootUpdates.erase(op) &&
1622 "operation did not have a pending in-place update");
1628 assert(
impl->pendingRootUpdates.erase(op) &&
1629 "operation did not have a pending in-place update");
1632 auto it = llvm::find_if(
1633 llvm::reverse(
impl->rewrites), [&](std::unique_ptr<IRRewrite> &
rewrite) {
1634 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1635 return modifyRewrite && modifyRewrite->getOperation() == op;
1637 assert(it !=
impl->rewrites.rend() &&
"no root update started on op");
1639 int updateIdx = std::prev(
impl->rewrites.rend()) - it;
1640 impl->rewrites.erase(
impl->rewrites.begin() + updateIdx);
1655 auto &rewriterImpl = dialectRewriter.getImpl();
1659 getTypeConverter());
1667 return matchAndRewrite(op, operands, dialectRewriter);
1679 class OperationLegalizer {
1699 LogicalResult legalizeWithFold(
Operation *op,
1704 LogicalResult legalizeWithPattern(
Operation *op,
1715 RewriterState &curState);
1719 legalizePatternBlockRewrites(
Operation *op,
1722 RewriterState &state, RewriterState &newState);
1723 LogicalResult legalizePatternCreatedOperations(
1725 RewriterState &state, RewriterState &newState);
1728 RewriterState &state,
1729 RewriterState &newState);
1739 void buildLegalizationGraph(
1740 LegalizationPatterns &anyOpLegalizerPatterns,
1751 void computeLegalizationGraphBenefit(
1752 LegalizationPatterns &anyOpLegalizerPatterns,
1757 unsigned computeOpLegalizationDepth(
1764 unsigned applyCostModelToPatterns(
1765 LegalizationPatterns &patterns,
1786 : target(targetInfo), applicator(patterns), config(config) {
1790 LegalizationPatterns anyOpLegalizerPatterns;
1792 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1793 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1796 bool OperationLegalizer::isIllegal(
Operation *op)
const {
1797 return target.isIllegal(op);
1801 OperationLegalizer::legalize(
Operation *op,
1804 const char *logLineComment =
1805 "//===-------------------------------------------===//\n";
1810 logger.getOStream() <<
"\n";
1811 logger.startLine() << logLineComment;
1812 logger.startLine() <<
"Legalizing operation : '" << op->
getName() <<
"'("
1818 op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1819 logger.getOStream() <<
"\n\n";
1824 if (
auto legalityInfo = target.isLegal(op)) {
1827 logger,
"operation marked legal by the target{0}",
1828 legalityInfo->isRecursivelyLegal
1829 ?
"; NOTE: operation is recursively legal; skipping internals"
1831 logger.startLine() << logLineComment;
1836 if (legalityInfo->isRecursivelyLegal) {
1849 logSuccess(logger,
"operation marked 'ignored' during conversion");
1850 logger.startLine() << logLineComment;
1858 if (succeeded(legalizeWithFold(op, rewriter))) {
1861 logger.startLine() << logLineComment;
1867 if (succeeded(legalizeWithPattern(op, rewriter))) {
1870 logger.startLine() << logLineComment;
1876 logFailure(logger,
"no matched legalization pattern");
1877 logger.startLine() << logLineComment;
1883 OperationLegalizer::legalizeWithFold(
Operation *op,
1885 auto &rewriterImpl = rewriter.
getImpl();
1889 rewriterImpl.
logger.startLine() <<
"* Fold {\n";
1890 rewriterImpl.
logger.indent();
1896 if (failed(rewriter.
tryFold(op, replacementValues))) {
1902 if (replacementValues.empty())
1903 return legalize(op, rewriter);
1906 rewriter.
replaceOp(op, replacementValues);
1909 for (
unsigned i = curState.numRewrites, e = rewriterImpl.
rewrites.size();
1912 dyn_cast<CreateOperationRewrite>(rewriterImpl.
rewrites[i].get());
1915 if (failed(legalize(createOp->getOperation(), rewriter))) {
1917 "failed to legalize generated constant '{0}'",
1918 createOp->getOperation()->getName()));
1929 OperationLegalizer::legalizeWithPattern(
Operation *op,
1931 auto &rewriterImpl = rewriter.
getImpl();
1934 auto canApply = [&](
const Pattern &pattern) {
1935 bool canApply = canApplyPattern(op, pattern, rewriter);
1936 if (canApply && config.listener)
1937 config.listener->notifyPatternBegin(pattern, op);
1943 auto onFailure = [&](
const Pattern &pattern) {
1949 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
1955 if (config.listener)
1956 config.listener->notifyPatternEnd(pattern, failure());
1958 appliedPatterns.erase(&pattern);
1963 auto onSuccess = [&](
const Pattern &pattern) {
1965 auto result = legalizePatternResult(op, pattern, rewriter, curState);
1966 appliedPatterns.erase(&pattern);
1969 if (config.listener)
1970 config.listener->notifyPatternEnd(pattern, result);
1975 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
1979 bool OperationLegalizer::canApplyPattern(
Operation *op,
const Pattern &pattern,
1982 auto &os = rewriter.
getImpl().logger;
1983 os.getOStream() <<
"\n";
1984 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
1986 os.getOStream() <<
")' {\n";
1993 !appliedPatterns.insert(&pattern).second) {
2002 OperationLegalizer::legalizePatternResult(
Operation *op,
const Pattern &pattern,
2004 RewriterState &curState) {
2008 assert(
impl.pendingRootUpdates.empty() &&
"dangling root updates");
2010 auto newRewrites = llvm::drop_begin(
impl.rewrites, curState.numRewrites);
2011 auto replacedRoot = [&] {
2012 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2014 auto updatedRootInPlace = [&] {
2015 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2017 assert((replacedRoot() || updatedRootInPlace()) &&
2018 "expected pattern to replace the root operation");
2022 RewriterState newState =
impl.getCurrentState();
2023 if (failed(legalizePatternBlockRewrites(op, rewriter,
impl, curState,
2025 failed(legalizePatternRootUpdates(rewriter,
impl, curState, newState)) ||
2026 failed(legalizePatternCreatedOperations(rewriter,
impl, curState,
2031 LLVM_DEBUG(
logSuccess(
impl.logger,
"pattern applied successfully"));
2035 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2038 RewriterState &newState) {
2043 for (
int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2044 BlockRewrite *
rewrite = dyn_cast<BlockRewrite>(
impl.rewrites[i].get());
2048 if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2049 ReplaceBlockArgRewrite>(
rewrite))
2058 if (
auto *converter =
impl.regionToConverter.lookup(block->
getParent())) {
2059 std::optional<TypeConverter::SignatureConversion> conversion =
2062 LLVM_DEBUG(
logFailure(
impl.logger,
"failed to convert types of moved "
2066 impl.applySignatureConversion(rewriter, block, converter, *conversion);
2074 if (operationsToIgnore.empty()) {
2075 for (
unsigned i = state.numRewrites, e =
impl.rewrites.size(); i != e;
2078 dyn_cast<CreateOperationRewrite>(
impl.rewrites[i].get());
2081 operationsToIgnore.insert(createOp->getOperation());
2086 if (operationsToIgnore.insert(parentOp).second &&
2087 failed(legalize(parentOp, rewriter))) {
2089 "operation '{0}'({1}) became illegal after rewrite",
2090 parentOp->
getName(), parentOp));
2097 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2099 RewriterState &state, RewriterState &newState) {
2100 for (
int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2101 auto *createOp = dyn_cast<CreateOperationRewrite>(
impl.rewrites[i].get());
2104 Operation *op = createOp->getOperation();
2105 if (failed(legalize(op, rewriter))) {
2107 "failed to legalize generated operation '{0}'({1})",
2115 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2117 RewriterState &state, RewriterState &newState) {
2118 for (
int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2119 auto *
rewrite = dyn_cast<ModifyOperationRewrite>(
impl.rewrites[i].get());
2123 if (failed(legalize(op, rewriter))) {
2125 impl.logger,
"failed to legalize operation updated in-place '{0}'",
2136 void OperationLegalizer::buildLegalizationGraph(
2137 LegalizationPatterns &anyOpLegalizerPatterns,
2148 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2149 std::optional<OperationName> root = pattern.
getRootKind();
2155 anyOpLegalizerPatterns.push_back(&pattern);
2160 if (target.getOpAction(*root) == LegalizationAction::Legal)
2165 invalidPatterns[*root].insert(&pattern);
2167 parentOps[op].insert(*root);
2170 patternWorklist.insert(&pattern);
2178 if (!anyOpLegalizerPatterns.empty()) {
2179 for (
const Pattern *pattern : patternWorklist)
2180 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2184 while (!patternWorklist.empty()) {
2185 auto *pattern = patternWorklist.pop_back_val();
2189 std::optional<LegalizationAction> action = target.getOpAction(op);
2190 return !legalizerPatterns.count(op) &&
2191 (!action || action == LegalizationAction::Illegal);
2197 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2198 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
2202 for (
auto op : parentOps[*pattern->
getRootKind()])
2203 patternWorklist.set_union(invalidPatterns[op]);
2207 void OperationLegalizer::computeLegalizationGraphBenefit(
2208 LegalizationPatterns &anyOpLegalizerPatterns,
2214 for (
auto &opIt : legalizerPatterns)
2215 if (!minOpPatternDepth.count(opIt.first))
2216 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2222 if (!anyOpLegalizerPatterns.empty())
2223 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2229 applicator.applyCostModel([&](
const Pattern &pattern) {
2231 if (std::optional<OperationName> rootName = pattern.
getRootKind())
2232 orderedPatternList = legalizerPatterns[*rootName];
2234 orderedPatternList = anyOpLegalizerPatterns;
2237 auto *it = llvm::find(orderedPatternList, &pattern);
2238 if (it == orderedPatternList.end())
2242 return PatternBenefit(std::distance(it, orderedPatternList.end()));
2246 unsigned OperationLegalizer::computeOpLegalizationDepth(
2250 auto depthIt = minOpPatternDepth.find(op);
2251 if (depthIt != minOpPatternDepth.end())
2252 return depthIt->second;
2256 auto opPatternsIt = legalizerPatterns.find(op);
2257 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2266 unsigned minDepth = applyCostModelToPatterns(
2267 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2268 minOpPatternDepth[op] = minDepth;
2272 unsigned OperationLegalizer::applyCostModelToPatterns(
2273 LegalizationPatterns &patterns,
2280 patternsByDepth.reserve(patterns.size());
2281 for (
const Pattern *pattern : patterns) {
2284 unsigned generatedOpDepth = computeOpLegalizationDepth(
2285 generatedOp, minOpPatternDepth, legalizerPatterns);
2286 depth =
std::max(depth, generatedOpDepth + 1);
2288 patternsByDepth.emplace_back(pattern, depth);
2291 minDepth =
std::min(minDepth, depth);
2296 if (patternsByDepth.size() == 1)
2300 std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2301 [](
const std::pair<const Pattern *, unsigned> &lhs,
2302 const std::pair<const Pattern *, unsigned> &rhs) {
2305 if (lhs.second != rhs.second)
2306 return lhs.second < rhs.second;
2309 auto lhsBenefit = lhs.first->getBenefit();
2310 auto rhsBenefit = rhs.first->getBenefit();
2311 return lhsBenefit > rhsBenefit;
2316 for (
auto &patternIt : patternsByDepth)
2317 patterns.push_back(patternIt.first);
2325 enum OpConversionMode {
2348 OpConversionMode mode)
2349 : config(config), opLegalizer(target, patterns, this->config),
2367 OperationLegalizer opLegalizer;
2370 OpConversionMode mode;
2377 if (failed(opLegalizer.legalize(op, rewriter))) {
2380 if (mode == OpConversionMode::Full)
2382 <<
"failed to legalize operation '" << op->
getName() <<
"'";
2386 if (mode == OpConversionMode::Partial) {
2387 if (opLegalizer.isIllegal(op))
2389 <<
"failed to legalize operation '" << op->
getName()
2390 <<
"' that was explicitly marked illegal";
2394 }
else if (mode == OpConversionMode::Analysis) {
2404 static LogicalResult
2406 UnresolvedMaterializationRewrite *
rewrite) {
2407 UnrealizedConversionCastOp op =
rewrite->getOperation();
2409 "expected that dead materializations have already been DCE'd");
2416 Value newMaterialization;
2417 switch (
rewrite->getMaterializationKind()) {
2421 rewriter, op->
getLoc(), outputType, inputOperands);
2422 if (newMaterialization)
2427 case MaterializationKind::Target:
2429 rewriter, op->
getLoc(), outputType, inputOperands,
2432 case MaterializationKind::Source:
2434 rewriter, op->
getLoc(), outputType, inputOperands);
2437 if (newMaterialization) {
2438 assert(newMaterialization.
getType() == outputType &&
2439 "materialization callback produced value of incorrect type");
2440 rewriter.
replaceOp(op, newMaterialization);
2446 <<
"failed to legalize unresolved materialization "
2448 << inputOperands.
getTypes() <<
") to " << outputType
2449 <<
" that remained live after conversion";
2451 <<
"see existing live user here: " << *op->
getUsers().begin();
2462 for (
auto *op : ops) {
2465 toConvert.push_back(op);
2468 auto legalityInfo = target.
isLegal(op);
2469 if (legalityInfo && legalityInfo->isRecursivelyLegal)
2479 for (
auto *op : toConvert)
2480 if (failed(convert(rewriter, op)))
2495 for (
auto it : materializations) {
2498 allCastOps.push_back(it.first);
2510 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2511 auto it = materializations.find(castOp);
2512 assert(it != materializations.end() &&
"inconsistent state");
2527 while (!worklist.empty()) {
2528 Value value = worklist.pop_back_val();
2533 return rewriterImpl.isOpIgnored(user);
2535 if (liveUserIt != value.
user_end())
2537 auto mapIt = inverseMapping.find(value);
2538 if (mapIt != inverseMapping.end())
2539 worklist.append(mapIt->second);
2548 static std::pair<ValueRange, const TypeConverter *>
2550 if (
auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(
rewrite))
2551 return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()};
2552 if (
auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(
rewrite))
2553 return {blockRewrite->getOrigBlock()->getArguments(),
2554 blockRewrite->getConverter()};
2561 rewriterImpl.
mapping.getInverse();
2564 for (
unsigned i = 0, e = rewriterImpl.
rewrites.size(); i < e; ++i) {
2567 std::tie(replacedValues, converter) =
2569 for (
Value originalValue : replacedValues) {
2572 if (rewriterImpl.
mapping.lookupOrNull(originalValue,
2573 originalValue.getType()))
2581 Value newValue = rewriterImpl.
mapping.lookupOrNull(originalValue);
2582 assert(newValue &&
"replacement value not found");
2585 originalValue.getLoc(),
2586 newValue, originalValue.getType(),
2588 rewriterImpl.
mapping.map(originalValue, castValue);
2589 inverseMapping[castValue].push_back(originalValue);
2590 llvm::erase(inverseMapping[newValue], originalValue);
2609 auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2610 for (
Value v : castOp.getInputs())
2611 if (
auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2612 worklist.insert(inputCastOp);
2619 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2620 if (castOp.getInputs().empty())
2623 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2626 if (inputCastOp.getOutputs() != castOp.getInputs())
2632 while (!worklist.empty()) {
2633 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2634 if (castOp->use_empty()) {
2637 enqueueOperands(castOp);
2638 if (remainingCastOps)
2639 erasedOps.insert(castOp.getOperation());
2646 UnrealizedConversionCastOp nextCast = castOp;
2648 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2652 enqueueOperands(castOp);
2653 castOp.replaceAllUsesWith(nextCast.getInputs());
2654 if (remainingCastOps)
2655 erasedOps.insert(castOp.getOperation());
2659 nextCast = getInputCast(nextCast);
2663 if (remainingCastOps)
2664 for (UnrealizedConversionCastOp op : castOps)
2665 if (!erasedOps.contains(op.getOperation()))
2666 remainingCastOps->push_back(op);
2675 assert(!types.empty() &&
"expected valid types");
2676 remapInput(origInputNo, argTypes.size(), types.size());
2681 assert(!types.empty() &&
2682 "1->0 type remappings don't need to be added explicitly");
2683 argTypes.append(types.begin(), types.end());
2687 unsigned newInputNo,
2688 unsigned newInputCount) {
2689 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
2690 assert(newInputCount != 0 &&
"expected valid input count");
2691 remappedInputs[origInputNo] =
2692 InputMapping{newInputNo, newInputCount,
nullptr};
2696 Value replacementValue) {
2697 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
2698 remappedInputs[origInputNo] =
2705 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2708 cacheReadLock.lock();
2709 auto existingIt = cachedDirectConversions.find(t);
2710 if (existingIt != cachedDirectConversions.end()) {
2711 if (existingIt->second)
2712 results.push_back(existingIt->second);
2713 return success(existingIt->second !=
nullptr);
2715 auto multiIt = cachedMultiConversions.find(t);
2716 if (multiIt != cachedMultiConversions.end()) {
2717 results.append(multiIt->second.begin(), multiIt->second.end());
2723 size_t currentCount = results.size();
2725 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2728 for (
const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2729 if (std::optional<LogicalResult> result = converter(t, results)) {
2731 cacheWriteLock.lock();
2732 if (!succeeded(*result)) {
2733 cachedDirectConversions.try_emplace(t,
nullptr);
2736 auto newTypes =
ArrayRef<Type>(results).drop_front(currentCount);
2737 if (newTypes.size() == 1)
2738 cachedDirectConversions.try_emplace(t, newTypes.front());
2740 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2754 return results.size() == 1 ? results.front() :
nullptr;
2760 for (
Type type : types)
2774 return llvm::all_of(*region, [
this](
Block &block) {
2780 return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2792 if (convertedTypes.empty())
2796 result.
addInputs(inputNo, convertedTypes);
2802 unsigned origInputOffset)
const {
2803 for (
unsigned i = 0, e = types.size(); i != e; ++i)
2813 for (
const MaterializationCallbackFn &fn :
2814 llvm::reverse(argumentMaterializations))
2815 if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
2823 for (
const MaterializationCallbackFn &fn :
2824 llvm::reverse(sourceMaterializations))
2825 if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
2833 Type originalType)
const {
2834 for (
const TargetMaterializationCallbackFn &fn :
2835 llvm::reverse(targetMaterializations))
2836 if (std::optional<Value> result =
2837 fn(builder, resultType, inputs, loc, originalType))
2842 std::optional<TypeConverter::SignatureConversion>
2846 return std::nullopt;
2869 return impl.getInt() == resultTag;
2873 return impl.getInt() == naTag;
2877 return impl.getInt() == abortTag;
2881 assert(hasResult() &&
"Cannot get result from N/A or abort");
2882 return impl.getPointer();
2885 std::optional<Attribute>
2887 for (
const TypeAttributeConversionCallbackFn &fn :
2888 llvm::reverse(typeAttributeConversions)) {
2893 return std::nullopt;
2895 return std::nullopt;
2905 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
2913 failed(typeConverter.
convertTypes(type.getResults(), newResults)) ||
2915 typeConverter, &result)))
2932 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
2940 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
2945 struct AnyFunctionOpInterfaceSignatureConversion
2957 FailureOr<Operation *>
2961 assert(op &&
"Invalid op");
2975 return rewriter.
create(newOp);
2981 patterns.
add<FunctionOpInterfaceSignatureConversion>(
2982 functionLikeOpName, patterns.
getContext(), converter);
2987 patterns.
add<AnyFunctionOpInterfaceSignatureConversion>(
2997 legalOperations[op].action = action;
3002 for (StringRef dialect : dialectNames)
3003 legalDialects[dialect] = action;
3007 -> std::optional<LegalizationAction> {
3008 std::optional<LegalizationInfo> info = getOpInfo(op);
3009 return info ? info->action : std::optional<LegalizationAction>();
3013 -> std::optional<LegalOpDetails> {
3014 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3016 return std::nullopt;
3019 auto isOpLegal = [&] {
3021 if (info->action == LegalizationAction::Dynamic) {
3022 std::optional<bool> result = info->legalityFn(op);
3028 return info->action == LegalizationAction::Legal;
3031 return std::nullopt;
3035 if (info->isRecursivelyLegal) {
3036 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
3037 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3039 legalityFnIt->second(op).value_or(
true);
3044 return legalityDetails;
3048 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3052 if (info->action == LegalizationAction::Dynamic) {
3053 std::optional<bool> result = info->legalityFn(op);
3060 return info->action == LegalizationAction::Illegal;
3069 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3071 if (std::optional<bool> result = newCl(op))
3079 void ConversionTarget::setLegalityCallback(
3080 OperationName name,
const DynamicLegalityCallbackFn &callback) {
3081 assert(callback &&
"expected valid legality callback");
3082 auto *infoIt = legalOperations.find(name);
3083 assert(infoIt != legalOperations.end() &&
3084 infoIt->second.action == LegalizationAction::Dynamic &&
3085 "expected operation to already be marked as dynamically legal");
3086 infoIt->second.legalityFn =
3092 auto *infoIt = legalOperations.find(name);
3093 assert(infoIt != legalOperations.end() &&
3094 infoIt->second.action != LegalizationAction::Illegal &&
3095 "expected operation to already be marked as legal");
3096 infoIt->second.isRecursivelyLegal =
true;
3099 std::move(opRecursiveLegalityFns[name]), callback);
3101 opRecursiveLegalityFns.erase(name);
3104 void ConversionTarget::setLegalityCallback(
3106 assert(callback &&
"expected valid legality callback");
3107 for (StringRef dialect : dialects)
3109 std::move(dialectLegalityFns[dialect]), callback);
3112 void ConversionTarget::setLegalityCallback(
3113 const DynamicLegalityCallbackFn &callback) {
3114 assert(callback &&
"expected valid legality callback");
3119 -> std::optional<LegalizationInfo> {
3121 const auto *it = legalOperations.find(op);
3122 if (it != legalOperations.end())
3125 auto dialectIt = legalDialects.find(op.getDialectNamespace());
3126 if (dialectIt != legalDialects.end()) {
3127 DynamicLegalityCallbackFn callback;
3128 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3129 if (dialectFn != dialectLegalityFns.end())
3130 callback = dialectFn->second;
3131 return LegalizationInfo{dialectIt->second,
false,
3135 if (unknownLegalityFn)
3136 return LegalizationInfo{LegalizationAction::Dynamic,
3137 false, unknownLegalityFn};
3138 return std::nullopt;
3141 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3147 auto &rewriterImpl =
3153 auto &rewriterImpl =
3160 static FailureOr<SmallVector<Value>>
3165 return std::move(mappedValues);
3174 if (failed(results))
3176 return results->front();
3186 auto &rewriterImpl =
3199 TypeRange types) -> FailureOr<SmallVector<Type>> {
3200 auto &rewriterImpl =
3207 if (failed(converter->
convertTypes(types, remappedTypes)))
3209 return std::move(remappedTypes);
3225 OpConversionMode::Partial);
3243 OpConversionMode::Full);
3266 "expected top-level op to be isolated from above");
3269 "expected ops to have a common ancestor");
3278 for (
Operation *op : ops.drop_front()) {
3282 assert(commonAncestor &&
3283 "expected to find a common isolated from above ancestor");
3287 return commonAncestor;
3305 inverseOperationMap[it.second] = it.first;
3311 OpConversionMode::Analysis);
3312 LogicalResult status = opConverter.convertOperations(opsToConvert);
3319 originalLegalizableOps.insert(inverseOperationMap[op]);
3324 clonedAncestor->
erase();
static std::pair< ValueRange, const TypeConverter * > getReplacedValues(IRRewrite *rewrite)
Helper function that returns the replaced values and the type converter if the given rewrite object i...
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnresolvedMaterializationRewrite *rewrite)
static Operation * findLiveUserOfReplaced(Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, const DenseMap< Value, SmallVector< Value >> &inverseMapping)
Finds a user of the given value, or of any other value that the given value replaced,...
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static bool hasRewrite(R &&rewrites, Operation *op)
Return "true" if there is an operation rewrite that matches the specified rewrite type and operation ...
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
OpListType & getOperations()
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
~ConversionPatternRewriter() override
Base class for the conversion patterns.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true is operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
auto lookup(T from) const
Lookup a mapped value within the map.
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Block::iterator getPoint() const
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within results.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
type_range getTypes() const
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
user_range getUsers()
Returns a range of all users.
result_range getResults()
int getPropertiesStorageSize() const
Returns the properties storage size.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
MLIRContext * getContext() const
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
The general result of a type attribute conversion callback, allowing for early termination.
Attribute getResult() const
static AttributeConversionResult abort()
static AttributeConversionResult na()
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_iterator user_end() const
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target/ argument materializations t...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
A rewriter that keeps track of erased ops and blocks.
bool wasErased(void *ptr) const
SingleEraseRewriter(MLIRContext *context)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationRewrite * > unresolvedMaterializations
A mapping of all unresolved materializations (UnrealizedConversionCastOp) to the corresponding rewrit...
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
ConversionValueMapping mapping
void applyRewrites()
Apply all requested operation rewrites.
void undoRewrites(unsigned numRewritesToKeep=0)
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
RewriterState getCurrentState()
Return the current state of the rewriter.
void notifyOpReplaced(Operation *op, ValueRange newValues)
Notifies that an op is about to be replaced with the given values.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVectorImpl< Value > &remapped)
Remap the given values to those with potentially different types.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
SingleEraseRewriter eraseRewriter
A rewriter that keeps track of ops/block that were already erased and skips duplicate op/block erasur...
MLIRContext * context
MLIR context.
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
Value buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueRange inputs, Type outputType, Type originalType, const TypeConverter *converter)
Build an unresolved materialization operation given an output type and set of input operands.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.