10 #include "mlir/Config/mlir-config.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/SaveAndRestore.h"
24 #include "llvm/Support/ScopedPrinter.h"
30 #define DEBUG_TYPE "dialect-conversion"
33 template <
typename... Args>
34 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
37 os.startLine() <<
"} -> SUCCESS";
39 os.getOStream() <<
" : "
40 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
41 os.getOStream() <<
"\n";
46 template <
typename... Args>
47 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
50 os.startLine() <<
"} -> FAILURE : "
51 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
61 if (
OpResult inputRes = dyn_cast<OpResult>(value))
62 insertPt = ++inputRes.getOwner()->getIterator();
73 struct ConversionValueMapping {
84 Value lookupOrDefault(
Value from,
Type desiredType =
nullptr)
const;
94 for (
Value it = newVal; it; it = mapping.lookupOrNull(it))
95 assert(it != oldVal &&
"inserting cyclic mapping");
97 mapping.map(oldVal, newVal);
/// Remove any mapping recorded for `value`. The rollback paths of the
/// rewrite objects use this to undo a previously recorded replacement.
106 void erase(
Value value) { mapping.erase(value); }
111 for (
auto &it : mapping.getValueMap())
112 inverse[it.second].push_back(it.first);
122 Value ConversionValueMapping::lookupOrDefault(
Value from,
123 Type desiredType)
const {
128 if (!desiredType || from.
getType() == desiredType)
131 Value mappedValue = mapping.lookupOrNull(from);
138 return desiredValue ? desiredValue : from;
141 Value ConversionValueMapping::lookupOrNull(
Value from,
Type desiredType)
const {
142 Value result = lookupOrDefault(from, desiredType);
143 if (result == from || (desiredType && result.
getType() != desiredType))
148 bool ConversionValueMapping::tryMap(
Value oldVal,
Value newVal) {
149 for (
Value it = newVal; it; it = mapping.lookupOrNull(it))
162 struct RewriterState {
163 RewriterState(
unsigned numRewrites,
unsigned numIgnoredOperations,
164 unsigned numReplacedOps)
165 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
166 numReplacedOps(numReplacedOps) {}
169 unsigned numRewrites;
172 unsigned numIgnoredOperations;
175 unsigned numReplacedOps;
207 UnresolvedMaterialization
210 virtual ~IRRewrite() =
default;
213 virtual void rollback() = 0;
/// Return the kind of this rewrite. The subclasses' `classof` hooks dispatch
/// on this value, so `isa`/`dyn_cast` over rewrites depend on it.
232 Kind getKind()
const {
return kind; }
/// LLVM-style RTTI hook for the root of the rewrite hierarchy: every rewrite
/// is an IRRewrite, so this unconditionally returns true.
234 static bool classof(
const IRRewrite *
rewrite) {
return true; }
238 : kind(kind), rewriterImpl(rewriterImpl) {}
247 class BlockRewrite :
public IRRewrite {
250 Block *getBlock()
const {
return block; }
252 static bool classof(
const IRRewrite *
rewrite) {
253 return rewrite->getKind() >= Kind::CreateBlock &&
254 rewrite->getKind() <= Kind::ReplaceBlockArg;
260 : IRRewrite(kind, rewriterImpl), block(block) {}
269 class CreateBlockRewrite :
public BlockRewrite {
272 : BlockRewrite(
Kind::CreateBlock, rewriterImpl, block) {}
274 static bool classof(
const IRRewrite *
rewrite) {
275 return rewrite->getKind() == Kind::CreateBlock;
281 listener->notifyBlockInserted(block, {}, {});
284 void rollback()
override {
287 auto &blockOps = block->getOperations();
288 while (!blockOps.empty())
289 blockOps.remove(blockOps.begin());
290 block->dropAllUses();
291 if (block->getParent())
302 class EraseBlockRewrite :
public BlockRewrite {
305 : BlockRewrite(
Kind::EraseBlock, rewriterImpl, block),
306 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
308 static bool classof(
const IRRewrite *
rewrite) {
309 return rewrite->getKind() == Kind::EraseBlock;
312 ~EraseBlockRewrite()
override {
314 "rewrite was neither rolled back nor committed/cleaned up");
317 void rollback()
override {
320 assert(block &&
"expected block");
321 auto &blockList = region->getBlocks();
325 blockList.insert(before, block);
331 assert(block &&
"expected block");
332 assert(block->empty() &&
"expected empty block");
336 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
337 listener->notifyBlockErased(block);
342 block->dropAllDefinedValueUses();
353 Block *insertBeforeBlock;
359 class InlineBlockRewrite :
public BlockRewrite {
363 : BlockRewrite(
Kind::InlineBlock, rewriterImpl, block),
364 sourceBlock(sourceBlock),
365 firstInlinedInst(sourceBlock->empty() ? nullptr
366 : &sourceBlock->front()),
367 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
373 assert(!getConfig().listener &&
374 "InlineBlockRewrite not supported if listener is attached");
377 static bool classof(
const IRRewrite *
rewrite) {
378 return rewrite->getKind() == Kind::InlineBlock;
381 void rollback()
override {
384 if (firstInlinedInst) {
385 assert(lastInlinedInst &&
"expected operation");
405 class MoveBlockRewrite :
public BlockRewrite {
409 : BlockRewrite(
Kind::MoveBlock, rewriterImpl, block), region(region),
410 insertBeforeBlock(insertBeforeBlock) {}
412 static bool classof(
const IRRewrite *
rewrite) {
413 return rewrite->getKind() == Kind::MoveBlock;
421 listener->notifyBlockInserted(block, region,
426 void rollback()
override {
439 Block *insertBeforeBlock;
443 class BlockTypeConversionRewrite :
public BlockRewrite {
448 : BlockRewrite(
Kind::BlockTypeConversion, rewriterImpl, block),
449 origBlock(origBlock), converter(converter) {}
451 static bool classof(
const IRRewrite *
rewrite) {
452 return rewrite->getKind() == Kind::BlockTypeConversion;
/// Return the original (pre-conversion) block that was passed to the
/// constructor as `origBlock`. Its arguments are the values that were
/// replaced by the signature conversion.
455 Block *getOrigBlock()
const {
return origBlock; }
/// Return the type converter stored when this signature-conversion rewrite
/// was created.
457 const TypeConverter *getConverter()
const {
return converter; }
461 void rollback()
override;
474 class ReplaceBlockArgRewrite :
public BlockRewrite {
478 : BlockRewrite(
Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
480 static bool classof(
const IRRewrite *
rewrite) {
481 return rewrite->getKind() == Kind::ReplaceBlockArg;
486 void rollback()
override;
493 class OperationRewrite :
public IRRewrite {
496 Operation *getOperation()
const {
return op; }
498 static bool classof(
const IRRewrite *
rewrite) {
499 return rewrite->getKind() >= Kind::MoveOperation &&
500 rewrite->getKind() <= Kind::UnresolvedMaterialization;
506 : IRRewrite(kind, rewriterImpl), op(op) {}
513 class MoveOperationRewrite :
public OperationRewrite {
517 : OperationRewrite(
Kind::MoveOperation, rewriterImpl, op), block(block),
518 insertBeforeOp(insertBeforeOp) {}
520 static bool classof(
const IRRewrite *
rewrite) {
521 return rewrite->getKind() == Kind::MoveOperation;
529 listener->notifyOperationInserted(
535 void rollback()
override {
553 class ModifyOperationRewrite :
public OperationRewrite {
557 : OperationRewrite(
Kind::ModifyOperation, rewriterImpl, op),
558 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
559 operands(op->operand_begin(), op->operand_end()),
560 successors(op->successor_begin(), op->successor_end()) {
565 name.initOpProperties(propCopy, prop);
569 static bool classof(
const IRRewrite *
rewrite) {
570 return rewrite->getKind() == Kind::ModifyOperation;
573 ~ModifyOperationRewrite()
override {
574 assert(!propertiesStorage &&
575 "rewrite was neither committed nor rolled back");
581 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
582 listener->notifyOperationModified(op);
584 if (propertiesStorage) {
588 name.destroyOpProperties(propCopy);
589 operator delete(propertiesStorage);
590 propertiesStorage =
nullptr;
594 void rollback()
override {
600 if (propertiesStorage) {
603 name.destroyOpProperties(propCopy);
604 operator delete(propertiesStorage);
605 propertiesStorage =
nullptr;
612 DictionaryAttr attrs;
615 void *propertiesStorage =
nullptr;
622 class ReplaceOperationRewrite :
public OperationRewrite {
626 : OperationRewrite(
Kind::ReplaceOperation, rewriterImpl, op),
627 converter(converter) {}
629 static bool classof(
const IRRewrite *
rewrite) {
630 return rewrite->getKind() == Kind::ReplaceOperation;
635 void rollback()
override;
/// Return the type converter stored when this op replacement was recorded.
/// It is paired with the op's results when materializations are legalized.
639 const TypeConverter *getConverter()
const {
return converter; }
647 class CreateOperationRewrite :
public OperationRewrite {
651 : OperationRewrite(
Kind::CreateOperation, rewriterImpl, op) {}
653 static bool classof(
const IRRewrite *
rewrite) {
654 return rewrite->getKind() == Kind::CreateOperation;
660 listener->notifyOperationInserted(op, {});
663 void rollback()
override;
667 enum MaterializationKind {
684 class UnresolvedMaterializationRewrite :
public OperationRewrite {
686 UnresolvedMaterializationRewrite(
688 UnrealizedConversionCastOp op,
const TypeConverter *converter =
nullptr,
689 MaterializationKind kind = MaterializationKind::Target);
691 static bool classof(
const IRRewrite *
rewrite) {
692 return rewrite->getKind() == Kind::UnresolvedMaterialization;
695 void rollback()
override;
697 UnrealizedConversionCastOp getOperation()
const {
698 return cast<UnrealizedConversionCastOp>(op);
703 return converterAndKind.getPointer();
707 MaterializationKind getMaterializationKind()
const {
708 return converterAndKind.getInt();
714 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
721 template <
typename RewriteTy,
typename R>
723 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
724 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
725 return rewriteTy && rewriteTy->getOperation() == op;
737 : context(ctx), eraseRewriter(ctx), config(config) {}
744 RewriterState getCurrentState();
748 void applyRewrites();
751 void resetState(RewriterState state);
755 template <
typename RewriteTy,
typename... Args>
758 std::make_unique<RewriteTy>(*
this, std::forward<Args>(args)...));
763 void undoRewrites(
unsigned numRewritesToKeep = 0);
769 LogicalResult remapValues(StringRef valueDiagTag,
770 std::optional<Location> inputLoc,
797 Block *applySignatureConversion(
808 Value buildUnresolvedMaterialization(MaterializationKind kind,
818 void notifyOperationInserted(
Operation *op,
825 void notifyBlockIsBeingErased(
Block *block);
828 void notifyBlockInserted(
Block *block,
Region *previous,
832 void notifyBlockBeingInlined(
Block *block,
Block *srcBlock,
862 if (wasErased(block))
864 assert(block->
empty() &&
"expected empty block");
/// Return true if `ptr` is recorded in the `erased` set. The parameter is an
/// untyped pointer, so any IR object pointer (e.g. a Block*) can be queried.
869 bool wasErased(
void *ptr)
const {
return erased.contains(ptr); }
933 llvm::ScopedPrinter logger{llvm::dbgs()};
940 return rewriterImpl.
config;
943 void BlockTypeConversionRewrite::commit(
RewriterBase &rewriter) {
948 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
950 listener->notifyOperationModified(op);
953 void BlockTypeConversionRewrite::rollback() {
957 void ReplaceBlockArgRewrite::commit(
RewriterBase &rewriter) {
958 Value repl = rewriterImpl.
mapping.lookupOrNull(arg, arg.getType());
962 if (isa<BlockArgument>(repl)) {
970 Operation *replOp = cast<OpResult>(repl).getOwner();
/// Undo the block-argument replacement by dropping the mapping entry that was
/// recorded for `arg`.
978 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.
mapping.erase(arg); }
980 void ReplaceOperationRewrite::commit(
RewriterBase &rewriter) {
982 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener());
987 return rewriterImpl.mapping.lookupOrNull(result, result.getType());
992 listener->notifyOperationReplaced(op, replacements);
995 for (
auto [result, newValue] :
996 llvm::zip_equal(op->
getResults(), replacements))
1002 if (getConfig().unlegalizedOps)
1003 getConfig().unlegalizedOps->erase(op);
1009 [&](
Operation *op) { listener->notifyOperationErased(op); });
1017 void ReplaceOperationRewrite::rollback() {
1019 rewriterImpl.
mapping.erase(result);
1022 void ReplaceOperationRewrite::cleanup(
RewriterBase &rewriter) {
1026 void CreateOperationRewrite::rollback() {
1028 while (!region.getBlocks().empty())
1029 region.getBlocks().remove(region.getBlocks().begin());
1035 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1038 : OperationRewrite(
Kind::UnresolvedMaterialization, rewriterImpl, op),
1039 converterAndKind(converter, kind) {
1043 void UnresolvedMaterializationRewrite::rollback() {
1044 if (getMaterializationKind() == MaterializationKind::Target) {
1046 rewriterImpl.
mapping.erase(input);
1075 while (
ignoredOps.size() != state.numIgnoredOperations)
1078 while (
replacedOps.size() != state.numReplacedOps)
1084 llvm::reverse(llvm::drop_begin(
rewrites, numRewritesToKeep)))
1086 rewrites.resize(numRewritesToKeep);
1090 StringRef valueDiagTag, std::optional<Location> inputLoc,
1093 remapped.reserve(llvm::size(values));
1096 Value operand = it.value();
1104 remapped.push_back(
mapping.lookupOrDefault(operand));
1112 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1113 << it.index() <<
", type was " << origType;
1118 if (legalTypes.size() != 1) {
1126 remapped.push_back(
mapping.lookupOrDefault(operand));
1131 Type desiredType = legalTypes.front();
1134 Value newOperand =
mapping.lookupOrDefault(operand, desiredType);
1135 if (newOperand.
getType() != desiredType) {
1141 operandLoc, newOperand, desiredType,
1143 mapping.map(newOperand, castValue);
1144 newOperand = castValue;
1146 remapped.push_back(newOperand);
1169 if (region->
empty())
1174 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1176 std::optional<TypeConverter::SignatureConversion> conversion =
1186 if (entryConversion)
1189 std::optional<TypeConverter::SignatureConversion> conversion =
1212 for (
unsigned i = 0; i < origArgCount; ++i) {
1214 if (!inputMap || inputMap->replacementValue)
1217 for (
unsigned j = 0;
j < inputMap->size; ++
j)
1218 newLocs[inputMap->inputNo +
j] = origLoc;
1225 convertedTypes, newLocs);
1235 appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->
end());
1238 while (!block->
empty())
1245 for (
unsigned i = 0; i != origArgCount; ++i) {
1249 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1255 MaterializationKind::Source,
1258 origArgType, converter);
1260 appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1264 if (
Value repl = inputMap->replacementValue) {
1266 assert(inputMap->size == 0 &&
1267 "invalid to provide a replacement value when the argument isn't "
1270 appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1279 newBlock->
getArguments().slice(inputMap->inputNo, inputMap->size);
1283 replArgs, origArgType, converter);
1286 Type legalOutputType;
1288 legalOutputType = converter->
convertType(origArgType);
1289 }
else if (replArgs.size() == 1) {
1297 legalOutputType = replArgs[0].getType();
1299 if (legalOutputType && legalOutputType != origArgType) {
1302 origArg.
getLoc(), argMat, legalOutputType, converter);
1303 mapping.map(argMat, targetMat);
1305 appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1308 appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
1327 if (inputs.size() == 1 && inputs.front().
getType() == outputType)
1328 return inputs.front();
1335 builder.
create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1336 appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
1337 return convertOp.getResult(0);
1346 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"'(" << op
1350 "attempting to insert into a block within a replaced/erased op");
1352 if (!previous.
isSet()) {
1354 appendRewrite<CreateOperationRewrite>(op);
1360 appendRewrite<MoveOperationRewrite>(op, previous.
getBlock(), prevOp);
1366 assert(!
ignoredOps.contains(op) &&
"operation was already replaced");
1369 for (
auto [newValue, result] : llvm::zip(newValues, op->
getResults())) {
1372 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1388 mapping.map(result, newValue);
1398 appendRewrite<EraseBlockRewrite>(block);
1404 "attempting to insert into a region within a replaced/erased op");
1409 logger.startLine() <<
"** Insert Block into : '" << parent->
getName()
1410 <<
"'(" << parent <<
")\n";
1413 <<
"** Insert Block into detached Region (nullptr parent op)'";
1419 appendRewrite<CreateBlockRewrite>(block);
1422 Block *prevBlock = previousIt == previous->
end() ? nullptr : &*previousIt;
1423 appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1428 appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
1435 reasonCallback(
diag);
1436 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
1446 ConversionPatternRewriter::ConversionPatternRewriter(
1450 setListener(
impl.get());
1456 assert(op && newOp &&
"expected non-null op");
1462 "incorrect # of replacement values");
1464 impl->logger.startLine()
1465 <<
"** Replace : '" << op->
getName() <<
"'(" << op <<
")\n";
1467 impl->notifyOpReplaced(op, newValues);
1472 impl->logger.startLine()
1473 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
1476 impl->notifyOpReplaced(op, nullRepls);
1481 "attempting to erase a block within a replaced/erased op");
1491 impl->notifyBlockIsBeingErased(block);
1499 "attempting to apply a signature conversion to a block within a "
1500 "replaced/erased op");
1501 return impl->applySignatureConversion(*
this, block, converter, conversion);
1508 "attempting to apply a signature conversion to a block within a "
1509 "replaced/erased op");
1510 return impl->convertRegionTypes(*
this, region, converter, entryConversion);
1517 impl->logger.startLine() <<
"** Replace Argument : '" << from
1518 <<
"'(in region of '" << parentOp->
getName()
1521 impl->appendRewrite<ReplaceBlockArgRewrite>(from.
getOwner(), from);
1522 impl->mapping.map(
impl->mapping.lookupOrDefault(from), to);
1527 if (failed(
impl->remapValues(
"value", std::nullopt, *
this, key,
1530 return remappedValues.front();
1538 return impl->remapValues(
"value", std::nullopt, *
this, keys,
1547 "incorrect # of argument replacement values");
1549 "attempting to inline a block from a replaced/erased op");
1551 "attempting to inline a block into a replaced/erased op");
1552 auto opIgnored = [&](
Operation *op) {
return impl->isOpIgnored(op); };
1555 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
1556 "expected 'source' to have no predecessors");
1565 bool fastPath = !
impl->config.listener;
1568 impl->notifyBlockBeingInlined(dest, source, before);
1571 for (
auto it : llvm::zip(source->
getArguments(), argValues))
1572 replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1579 while (!source->
empty())
1580 moveOpBefore(&source->
front(), dest, before);
1588 assert(!
impl->wasOpReplaced(op) &&
1589 "attempting to modify a replaced/erased op");
1591 impl->pendingRootUpdates.insert(op);
1593 impl->appendRewrite<ModifyOperationRewrite>(op);
1597 assert(!
impl->wasOpReplaced(op) &&
1598 "attempting to modify a replaced/erased op");
1603 assert(
impl->pendingRootUpdates.erase(op) &&
1604 "operation did not have a pending in-place update");
1610 assert(
impl->pendingRootUpdates.erase(op) &&
1611 "operation did not have a pending in-place update");
1614 auto it = llvm::find_if(
1615 llvm::reverse(
impl->rewrites), [&](std::unique_ptr<IRRewrite> &
rewrite) {
1616 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1617 return modifyRewrite && modifyRewrite->getOperation() == op;
1619 assert(it !=
impl->rewrites.rend() &&
"no root update started on op");
1621 int updateIdx = std::prev(
impl->rewrites.rend()) - it;
1622 impl->rewrites.erase(
impl->rewrites.begin() + updateIdx);
1637 auto &rewriterImpl = dialectRewriter.getImpl();
1641 getTypeConverter());
1649 return matchAndRewrite(op, operands, dialectRewriter);
1661 class OperationLegalizer {
1681 LogicalResult legalizeWithFold(
Operation *op,
1686 LogicalResult legalizeWithPattern(
Operation *op,
1697 RewriterState &curState);
1701 legalizePatternBlockRewrites(
Operation *op,
1704 RewriterState &state, RewriterState &newState);
1705 LogicalResult legalizePatternCreatedOperations(
1707 RewriterState &state, RewriterState &newState);
1710 RewriterState &state,
1711 RewriterState &newState);
1721 void buildLegalizationGraph(
1722 LegalizationPatterns &anyOpLegalizerPatterns,
1733 void computeLegalizationGraphBenefit(
1734 LegalizationPatterns &anyOpLegalizerPatterns,
1739 unsigned computeOpLegalizationDepth(
1746 unsigned applyCostModelToPatterns(
1747 LegalizationPatterns &patterns,
1768 : target(targetInfo), applicator(patterns), config(config) {
1772 LegalizationPatterns anyOpLegalizerPatterns;
1774 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1775 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1778 bool OperationLegalizer::isIllegal(
Operation *op)
const {
1779 return target.isIllegal(op);
1783 OperationLegalizer::legalize(
Operation *op,
1786 const char *logLineComment =
1787 "//===-------------------------------------------===//\n";
1792 logger.getOStream() <<
"\n";
1793 logger.startLine() << logLineComment;
1794 logger.startLine() <<
"Legalizing operation : '" << op->
getName() <<
"'("
1800 op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1801 logger.getOStream() <<
"\n\n";
1806 if (
auto legalityInfo = target.isLegal(op)) {
1809 logger,
"operation marked legal by the target{0}",
1810 legalityInfo->isRecursivelyLegal
1811 ?
"; NOTE: operation is recursively legal; skipping internals"
1813 logger.startLine() << logLineComment;
1818 if (legalityInfo->isRecursivelyLegal) {
1831 logSuccess(logger,
"operation marked 'ignored' during conversion");
1832 logger.startLine() << logLineComment;
1840 if (succeeded(legalizeWithFold(op, rewriter))) {
1843 logger.startLine() << logLineComment;
1849 if (succeeded(legalizeWithPattern(op, rewriter))) {
1852 logger.startLine() << logLineComment;
1858 logFailure(logger,
"no matched legalization pattern");
1859 logger.startLine() << logLineComment;
1865 OperationLegalizer::legalizeWithFold(
Operation *op,
1867 auto &rewriterImpl = rewriter.
getImpl();
1871 rewriterImpl.
logger.startLine() <<
"* Fold {\n";
1872 rewriterImpl.
logger.indent();
1878 if (failed(rewriter.
tryFold(op, replacementValues))) {
1884 if (replacementValues.empty())
1885 return legalize(op, rewriter);
1888 rewriter.
replaceOp(op, replacementValues);
1891 for (
unsigned i = curState.numRewrites, e = rewriterImpl.
rewrites.size();
1894 dyn_cast<CreateOperationRewrite>(rewriterImpl.
rewrites[i].get());
1897 if (failed(legalize(createOp->getOperation(), rewriter))) {
1899 "failed to legalize generated constant '{0}'",
1900 createOp->getOperation()->getName()));
1911 OperationLegalizer::legalizeWithPattern(
Operation *op,
1913 auto &rewriterImpl = rewriter.
getImpl();
1916 auto canApply = [&](
const Pattern &pattern) {
1917 bool canApply = canApplyPattern(op, pattern, rewriter);
1918 if (canApply && config.listener)
1919 config.listener->notifyPatternBegin(pattern, op);
1925 auto onFailure = [&](
const Pattern &pattern) {
1931 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
1937 if (config.listener)
1938 config.listener->notifyPatternEnd(pattern, failure());
1940 appliedPatterns.erase(&pattern);
1945 auto onSuccess = [&](
const Pattern &pattern) {
1947 auto result = legalizePatternResult(op, pattern, rewriter, curState);
1948 appliedPatterns.erase(&pattern);
1951 if (config.listener)
1952 config.listener->notifyPatternEnd(pattern, result);
1957 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
1961 bool OperationLegalizer::canApplyPattern(
Operation *op,
const Pattern &pattern,
1964 auto &os = rewriter.
getImpl().logger;
1965 os.getOStream() <<
"\n";
1966 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
1968 os.getOStream() <<
")' {\n";
1975 !appliedPatterns.insert(&pattern).second) {
1984 OperationLegalizer::legalizePatternResult(
Operation *op,
const Pattern &pattern,
1986 RewriterState &curState) {
1990 assert(
impl.pendingRootUpdates.empty() &&
"dangling root updates");
1992 auto newRewrites = llvm::drop_begin(
impl.rewrites, curState.numRewrites);
1993 auto replacedRoot = [&] {
1994 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
1996 auto updatedRootInPlace = [&] {
1997 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
1999 assert((replacedRoot() || updatedRootInPlace()) &&
2000 "expected pattern to replace the root operation");
2004 RewriterState newState =
impl.getCurrentState();
2005 if (failed(legalizePatternBlockRewrites(op, rewriter,
impl, curState,
2007 failed(legalizePatternRootUpdates(rewriter,
impl, curState, newState)) ||
2008 failed(legalizePatternCreatedOperations(rewriter,
impl, curState,
2013 LLVM_DEBUG(
logSuccess(
impl.logger,
"pattern applied successfully"));
2017 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2020 RewriterState &newState) {
2025 for (
int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2026 BlockRewrite *
rewrite = dyn_cast<BlockRewrite>(
impl.rewrites[i].get());
2030 if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2031 ReplaceBlockArgRewrite>(
rewrite))
2040 if (
auto *converter =
impl.regionToConverter.lookup(block->
getParent())) {
2041 std::optional<TypeConverter::SignatureConversion> conversion =
2042 converter->convertBlockSignature(block);
2044 LLVM_DEBUG(
logFailure(
impl.logger,
"failed to convert types of moved "
2048 impl.applySignatureConversion(rewriter, block, converter, *conversion);
2056 if (operationsToIgnore.empty()) {
2057 for (
unsigned i = state.numRewrites, e =
impl.rewrites.size(); i != e;
2060 dyn_cast<CreateOperationRewrite>(
impl.rewrites[i].get());
2063 operationsToIgnore.insert(createOp->getOperation());
2068 if (operationsToIgnore.insert(parentOp).second &&
2069 failed(legalize(parentOp, rewriter))) {
2071 "operation '{0}'({1}) became illegal after rewrite",
2072 parentOp->
getName(), parentOp));
2079 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2081 RewriterState &state, RewriterState &newState) {
2082 for (
int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2083 auto *createOp = dyn_cast<CreateOperationRewrite>(
impl.rewrites[i].get());
2086 Operation *op = createOp->getOperation();
2087 if (failed(legalize(op, rewriter))) {
2089 "failed to legalize generated operation '{0}'({1})",
2097 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2099 RewriterState &state, RewriterState &newState) {
2100 for (
int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2101 auto *
rewrite = dyn_cast<ModifyOperationRewrite>(
impl.rewrites[i].get());
2105 if (failed(legalize(op, rewriter))) {
2107 impl.logger,
"failed to legalize operation updated in-place '{0}'",
2118 void OperationLegalizer::buildLegalizationGraph(
2119 LegalizationPatterns &anyOpLegalizerPatterns,
2130 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2131 std::optional<OperationName> root = pattern.
getRootKind();
2137 anyOpLegalizerPatterns.push_back(&pattern);
2142 if (target.getOpAction(*root) == LegalizationAction::Legal)
2147 invalidPatterns[*root].insert(&pattern);
2149 parentOps[op].insert(*root);
2152 patternWorklist.insert(&pattern);
2160 if (!anyOpLegalizerPatterns.empty()) {
2161 for (
const Pattern *pattern : patternWorklist)
2162 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2166 while (!patternWorklist.empty()) {
2167 auto *pattern = patternWorklist.pop_back_val();
2171 std::optional<LegalizationAction> action = target.getOpAction(op);
2172 return !legalizerPatterns.count(op) &&
2173 (!action || action == LegalizationAction::Illegal);
2179 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2180 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
2184 for (
auto op : parentOps[*pattern->
getRootKind()])
2185 patternWorklist.set_union(invalidPatterns[op]);
2189 void OperationLegalizer::computeLegalizationGraphBenefit(
2190 LegalizationPatterns &anyOpLegalizerPatterns,
2196 for (
auto &opIt : legalizerPatterns)
2197 if (!minOpPatternDepth.count(opIt.first))
2198 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2204 if (!anyOpLegalizerPatterns.empty())
2205 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2211 applicator.applyCostModel([&](
const Pattern &pattern) {
2213 if (std::optional<OperationName> rootName = pattern.
getRootKind())
2214 orderedPatternList = legalizerPatterns[*rootName];
2216 orderedPatternList = anyOpLegalizerPatterns;
2219 auto *it = llvm::find(orderedPatternList, &pattern);
2220 if (it == orderedPatternList.end())
2224 return PatternBenefit(std::distance(it, orderedPatternList.end()));
2228 unsigned OperationLegalizer::computeOpLegalizationDepth(
2232 auto depthIt = minOpPatternDepth.find(op);
2233 if (depthIt != minOpPatternDepth.end())
2234 return depthIt->second;
2238 auto opPatternsIt = legalizerPatterns.find(op);
2239 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2248 unsigned minDepth = applyCostModelToPatterns(
2249 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2250 minOpPatternDepth[op] = minDepth;
2254 unsigned OperationLegalizer::applyCostModelToPatterns(
2255 LegalizationPatterns &patterns,
2262 patternsByDepth.reserve(patterns.size());
2263 for (
const Pattern *pattern : patterns) {
2266 unsigned generatedOpDepth = computeOpLegalizationDepth(
2267 generatedOp, minOpPatternDepth, legalizerPatterns);
2268 depth =
std::max(depth, generatedOpDepth + 1);
2270 patternsByDepth.emplace_back(pattern, depth);
2273 minDepth =
std::min(minDepth, depth);
2278 if (patternsByDepth.size() == 1)
2282 std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2283 [](
const std::pair<const Pattern *, unsigned> &lhs,
2284 const std::pair<const Pattern *, unsigned> &rhs) {
2287 if (lhs.second != rhs.second)
2288 return lhs.second < rhs.second;
2291 auto lhsBenefit = lhs.first->getBenefit();
2292 auto rhsBenefit = rhs.first->getBenefit();
2293 return lhsBenefit > rhsBenefit;
2298 for (
auto &patternIt : patternsByDepth)
2299 patterns.push_back(patternIt.first);
2307 enum OpConversionMode {
2330 OpConversionMode mode)
2331 : config(config), opLegalizer(target, patterns, this->config),
2349 OperationLegalizer opLegalizer;
2352 OpConversionMode mode;
2359 if (failed(opLegalizer.legalize(op, rewriter))) {
2362 if (mode == OpConversionMode::Full)
2364 <<
"failed to legalize operation '" << op->
getName() <<
"'";
2368 if (mode == OpConversionMode::Partial) {
2369 if (opLegalizer.isIllegal(op))
2371 <<
"failed to legalize operation '" << op->
getName()
2372 <<
"' that was explicitly marked illegal";
2376 }
else if (mode == OpConversionMode::Analysis) {
2386 static LogicalResult
2388 UnresolvedMaterializationRewrite *
rewrite) {
2389 UnrealizedConversionCastOp op =
rewrite->getOperation();
2391 "expected that dead materializations have already been DCE'd");
2398 Value newMaterialization;
2399 switch (
rewrite->getMaterializationKind()) {
2402 newMaterialization = converter->materializeArgumentConversion(
2403 rewriter, op->
getLoc(), outputType, inputOperands);
2404 if (newMaterialization)
2409 case MaterializationKind::Target:
2410 newMaterialization = converter->materializeTargetConversion(
2411 rewriter, op->
getLoc(), outputType, inputOperands);
2413 case MaterializationKind::Source:
2414 newMaterialization = converter->materializeSourceConversion(
2415 rewriter, op->
getLoc(), outputType, inputOperands);
2418 if (newMaterialization) {
2419 assert(newMaterialization.
getType() == outputType &&
2420 "materialization callback produced value of incorrect type");
2421 rewriter.
replaceOp(op, newMaterialization);
2427 <<
"failed to legalize unresolved materialization "
2429 << inputOperands.
getTypes() <<
") to " << outputType
2430 <<
" that remained live after conversion";
2432 <<
"see existing live user here: " << *op->
getUsers().begin();
2443 for (
auto *op : ops) {
2446 toConvert.push_back(op);
2449 auto legalityInfo = target.
isLegal(op);
2450 if (legalityInfo && legalityInfo->isRecursivelyLegal)
2460 for (
auto *op : toConvert)
2461 if (failed(convert(rewriter, op)))
2476 for (
auto it : materializations) {
2479 allCastOps.push_back(it.first);
2491 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2492 auto it = materializations.find(castOp);
2493 assert(it != materializations.end() &&
"inconsistent state");
2508 while (!worklist.empty()) {
2509 Value value = worklist.pop_back_val();
2514 return rewriterImpl.isOpIgnored(user);
2516 if (liveUserIt != value.
user_end())
2518 auto mapIt = inverseMapping.find(value);
2519 if (mapIt != inverseMapping.end())
2520 worklist.append(mapIt->second);
2529 static std::pair<ValueRange, const TypeConverter *>
2531 if (
auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(
rewrite))
2532 return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()};
2533 if (
auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(
rewrite))
2534 return {blockRewrite->getOrigBlock()->getArguments(),
2535 blockRewrite->getConverter()};
2542 rewriterImpl.
mapping.getInverse();
2545 for (
unsigned i = 0, e = rewriterImpl.
rewrites.size(); i < e; ++i) {
2548 std::tie(replacedValues, converter) =
2550 for (
Value originalValue : replacedValues) {
2553 if (rewriterImpl.
mapping.lookupOrNull(originalValue,
2554 originalValue.getType()))
2562 Value newValue = rewriterImpl.
mapping.lookupOrNull(originalValue);
2563 assert(newValue &&
"replacement value not found");
2566 originalValue.getLoc(),
2567 newValue, originalValue.getType(),
2569 rewriterImpl.
mapping.map(originalValue, castValue);
2570 inverseMapping[castValue].push_back(originalValue);
2571 llvm::erase(inverseMapping[newValue], originalValue);
2590 auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2591 for (
Value v : castOp.getInputs())
2592 if (
auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2593 worklist.insert(inputCastOp);
2600 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2601 if (castOp.getInputs().empty())
2604 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2607 if (inputCastOp.getOutputs() != castOp.getInputs())
2613 while (!worklist.empty()) {
2614 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2615 if (castOp->use_empty()) {
2618 enqueueOperands(castOp);
2619 if (remainingCastOps)
2620 erasedOps.insert(castOp.getOperation());
2627 UnrealizedConversionCastOp nextCast = castOp;
2629 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2633 enqueueOperands(castOp);
2634 castOp.replaceAllUsesWith(nextCast.getInputs());
2635 if (remainingCastOps)
2636 erasedOps.insert(castOp.getOperation());
2640 nextCast = getInputCast(nextCast);
2644 if (remainingCastOps)
2645 for (UnrealizedConversionCastOp op : castOps)
2646 if (!erasedOps.contains(op.getOperation()))
2647 remainingCastOps->push_back(op);
2656 assert(!types.empty() &&
"expected valid types");
2657 remapInput(origInputNo, argTypes.size(), types.size());
2662 assert(!types.empty() &&
2663 "1->0 type remappings don't need to be added explicitly");
2664 argTypes.append(types.begin(), types.end());
2668 unsigned newInputNo,
2669 unsigned newInputCount) {
2670 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
2671 assert(newInputCount != 0 &&
"expected valid input count");
2672 remappedInputs[origInputNo] =
2673 InputMapping{newInputNo, newInputCount,
nullptr};
2677 Value replacementValue) {
2678 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
2679 remappedInputs[origInputNo] =
2686 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2689 cacheReadLock.lock();
2690 auto existingIt = cachedDirectConversions.find(t);
2691 if (existingIt != cachedDirectConversions.end()) {
2692 if (existingIt->second)
2693 results.push_back(existingIt->second);
2694 return success(existingIt->second !=
nullptr);
2696 auto multiIt = cachedMultiConversions.find(t);
2697 if (multiIt != cachedMultiConversions.end()) {
2698 results.append(multiIt->second.begin(), multiIt->second.end());
2704 size_t currentCount = results.size();
2706 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2709 for (
const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2710 if (std::optional<LogicalResult> result = converter(t, results)) {
2712 cacheWriteLock.lock();
2713 if (!succeeded(*result)) {
2714 cachedDirectConversions.try_emplace(t,
nullptr);
2717 auto newTypes =
ArrayRef<Type>(results).drop_front(currentCount);
2718 if (newTypes.size() == 1)
2719 cachedDirectConversions.try_emplace(t, newTypes.front());
2721 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2735 return results.size() == 1 ? results.front() :
nullptr;
2741 for (
Type type : types)
2755 return llvm::all_of(*region, [
this](
Block &block) {
2761 return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2773 if (convertedTypes.empty())
2777 result.
addInputs(inputNo, convertedTypes);
2783 unsigned origInputOffset)
const {
2784 for (
unsigned i = 0, e = types.size(); i != e; ++i)
2790 Value TypeConverter::materializeConversion(
2793 for (
const MaterializationCallbackFn &fn : llvm::reverse(materializations))
2794 if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
2799 std::optional<TypeConverter::SignatureConversion>
2803 return std::nullopt;
2826 return impl.getInt() == resultTag;
2830 return impl.getInt() == naTag;
2834 return impl.getInt() == abortTag;
2838 assert(hasResult() &&
"Cannot get result from N/A or abort");
2839 return impl.getPointer();
2842 std::optional<Attribute>
2844 for (
const TypeAttributeConversionCallbackFn &fn :
2845 llvm::reverse(typeAttributeConversions)) {
2850 return std::nullopt;
2852 return std::nullopt;
2862 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
2870 failed(typeConverter.
convertTypes(type.getResults(), newResults)) ||
2872 typeConverter, &result)))
2889 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
2897 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
2902 struct AnyFunctionOpInterfaceSignatureConversion
2914 FailureOr<Operation *>
2918 assert(op &&
"Invalid op");
2932 return rewriter.
create(newOp);
2938 patterns.
add<FunctionOpInterfaceSignatureConversion>(
2939 functionLikeOpName, patterns.
getContext(), converter);
2944 patterns.
add<AnyFunctionOpInterfaceSignatureConversion>(
2954 legalOperations[op].action = action;
2959 for (StringRef dialect : dialectNames)
2960 legalDialects[dialect] = action;
2964 -> std::optional<LegalizationAction> {
2965 std::optional<LegalizationInfo> info = getOpInfo(op);
2966 return info ? info->action : std::optional<LegalizationAction>();
2970 -> std::optional<LegalOpDetails> {
2971 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
2973 return std::nullopt;
2976 auto isOpLegal = [&] {
2978 if (info->action == LegalizationAction::Dynamic) {
2979 std::optional<bool> result = info->legalityFn(op);
2985 return info->action == LegalizationAction::Legal;
2988 return std::nullopt;
2992 if (info->isRecursivelyLegal) {
2993 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
2994 if (legalityFnIt != opRecursiveLegalityFns.end()) {
2996 legalityFnIt->second(op).value_or(
true);
3001 return legalityDetails;
3005 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3009 if (info->action == LegalizationAction::Dynamic) {
3010 std::optional<bool> result = info->legalityFn(op);
3017 return info->action == LegalizationAction::Illegal;
3026 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3028 if (std::optional<bool> result = newCl(op))
3036 void ConversionTarget::setLegalityCallback(
3037 OperationName name,
const DynamicLegalityCallbackFn &callback) {
3038 assert(callback &&
"expected valid legality callback");
3039 auto *infoIt = legalOperations.find(name);
3040 assert(infoIt != legalOperations.end() &&
3041 infoIt->second.action == LegalizationAction::Dynamic &&
3042 "expected operation to already be marked as dynamically legal");
3043 infoIt->second.legalityFn =
3049 auto *infoIt = legalOperations.find(name);
3050 assert(infoIt != legalOperations.end() &&
3051 infoIt->second.action != LegalizationAction::Illegal &&
3052 "expected operation to already be marked as legal");
3053 infoIt->second.isRecursivelyLegal =
true;
3056 std::move(opRecursiveLegalityFns[name]), callback);
3058 opRecursiveLegalityFns.erase(name);
3061 void ConversionTarget::setLegalityCallback(
3063 assert(callback &&
"expected valid legality callback");
3064 for (StringRef dialect : dialects)
3066 std::move(dialectLegalityFns[dialect]), callback);
3069 void ConversionTarget::setLegalityCallback(
3070 const DynamicLegalityCallbackFn &callback) {
3071 assert(callback &&
"expected valid legality callback");
3076 -> std::optional<LegalizationInfo> {
3078 const auto *it = legalOperations.find(op);
3079 if (it != legalOperations.end())
3082 auto dialectIt = legalDialects.find(op.getDialectNamespace());
3083 if (dialectIt != legalDialects.end()) {
3084 DynamicLegalityCallbackFn callback;
3085 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3086 if (dialectFn != dialectLegalityFns.end())
3087 callback = dialectFn->second;
3088 return LegalizationInfo{dialectIt->second,
false,
3092 if (unknownLegalityFn)
3093 return LegalizationInfo{LegalizationAction::Dynamic,
3094 false, unknownLegalityFn};
3095 return std::nullopt;
3098 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3104 auto &rewriterImpl =
3110 auto &rewriterImpl =
3117 static FailureOr<SmallVector<Value>>
3122 return std::move(mappedValues);
3131 if (failed(results))
3133 return results->front();
3143 auto &rewriterImpl =
3156 TypeRange types) -> FailureOr<SmallVector<Type>> {
3157 auto &rewriterImpl =
3164 if (failed(converter->
convertTypes(types, remappedTypes)))
3166 return std::move(remappedTypes);
3182 OpConversionMode::Partial);
3200 OpConversionMode::Full);
3223 "expected top-level op to be isolated from above");
3226 "expected ops to have a common ancestor");
3235 for (
Operation *op : ops.drop_front()) {
3239 assert(commonAncestor &&
3240 "expected to find a common isolated from above ancestor");
3244 return commonAncestor;
3262 inverseOperationMap[it.second] = it.first;
3268 OpConversionMode::Analysis);
3269 LogicalResult status = opConverter.convertOperations(opsToConvert);
3276 originalLegalizableOps.insert(inverseOperationMap[op]);
3281 clonedAncestor->
erase();
static std::pair< ValueRange, const TypeConverter * > getReplacedValues(IRRewrite *rewrite)
Helper function that returns the replaced values and the type converter if the given rewrite object i...
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnresolvedMaterializationRewrite *rewrite)
static Operation * findLiveUserOfReplaced(Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, const DenseMap< Value, SmallVector< Value >> &inverseMapping)
Finds a user of the given value, or of any other value that the given value replaced,...
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static bool hasRewrite(R &&rewrites, Operation *op)
Return "true" if there is an operation rewrite that matches the specified rewrite type and operation ...
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
OpListType & getOperations()
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
~ConversionPatternRewriter() override
Base class for the conversion patterns.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true is operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
auto lookup(T from) const
Lookup a mapped value within the map.
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Block::iterator getPoint() const
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within results.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
type_range getTypes() const
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
user_range getUsers()
Returns a range of all users.
result_range getResults()
int getPropertiesStorageSize() const
Returns the properties storage size.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
MLIRContext * getContext() const
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
The general result of a type attribute conversion callback, allowing for early termination.
Attribute getResult() const
static AttributeConversionResult abort()
static AttributeConversionResult na()
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_iterator user_end() const
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target/ argument materializations t...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
A rewriter that keeps track of erased ops and blocks.
bool wasErased(void *ptr) const
SingleEraseRewriter(MLIRContext *context)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationRewrite * > unresolvedMaterializations
A mapping of all unresolved materializations (UnrealizedConversionCastOp) to the corresponding rewrit...
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
ConversionValueMapping mapping
void applyRewrites()
Apply all requested operation rewrites.
void undoRewrites(unsigned numRewritesToKeep=0)
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
RewriterState getCurrentState()
Return the current state of the rewriter.
void notifyOpReplaced(Operation *op, ValueRange newValues)
Notifies that an op is about to be replaced with the given values.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
Value buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueRange inputs, Type outputType, const TypeConverter *converter)
Build an unresolved materialization operation given an output type and set of input operands.
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVectorImpl< Value > &remapped)
Remap the given values to those with potentially different types.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
SingleEraseRewriter eraseRewriter
A rewriter that keeps track of ops/block that were already erased and skips duplicate op/block erasur...
MLIRContext * context
MLIR context.
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.