10 #include "mlir/Config/mlir-config.h"
20 #include "llvm/ADT/ScopeExit.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/DebugLog.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include "llvm/Support/SaveAndRestore.h"
26 #include "llvm/Support/ScopedPrinter.h"
32 #define DEBUG_TYPE "dialect-conversion"
35 template <
typename... Args>
36 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
39 os.startLine() <<
"} -> SUCCESS";
41 os.getOStream() <<
" : "
42 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
43 os.getOStream() <<
"\n";
48 template <
typename... Args>
49 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
52 os.startLine() <<
"} -> FAILURE : "
53 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
63 if (
OpResult inputRes = dyn_cast<OpResult>(value))
64 insertPt = ++inputRes.getOwner()->getIterator();
71 assert(!vals.empty() &&
"expected at least one value");
74 for (
Value v : vals.drop_front()) {
88 assert(dom &&
"unable to find valid insertion point");
106 struct ValueVectorMapInfo {
109 static ::llvm::hash_code getHashValue(
const ValueVector &val) {
110 return ::llvm::hash_combine_range(val);
119 struct ConversionValueMapping {
122 bool isMappedTo(
Value value)
const {
return mappedTo.contains(value); }
127 template <
typename T>
128 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
131 template <
typename OldVal,
typename NewVal>
132 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
133 map(OldVal &&oldVal, NewVal &&newVal) {
137 assert(next != oldVal &&
"inserting cyclic mapping");
138 auto it = mapping.find(next);
139 if (it == mapping.end())
144 mappedTo.insert_range(newVal);
146 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
150 template <
typename OldVal,
typename NewVal>
151 std::enable_if_t<!IsValueVector<OldVal>::value ||
152 !IsValueVector<NewVal>::value>
153 map(OldVal &&oldVal, NewVal &&newVal) {
154 if constexpr (IsValueVector<OldVal>{}) {
155 map(std::forward<OldVal>(oldVal),
ValueVector{newVal});
156 }
else if constexpr (IsValueVector<NewVal>{}) {
157 map(
ValueVector{oldVal}, std::forward<NewVal>(newVal));
163 void map(
Value oldVal, SmallVector<Value> &&newVal) {
168 void erase(
const ValueVector &value) { mapping.erase(value); }
188 assert(!values.empty() &&
"expected non-empty value vector");
189 Operation *op = values.front().getDefiningOp();
190 for (
Value v : llvm::drop_begin(values)) {
191 if (v.getDefiningOp() != op)
201 assert(!values.empty() &&
"expected non-empty value vector");
207 auto it = mapping.find(from);
208 if (it == mapping.end()) {
221 struct RewriterState {
222 RewriterState(
unsigned numRewrites,
unsigned numIgnoredOperations,
223 unsigned numReplacedOps)
224 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
225 numReplacedOps(numReplacedOps) {}
228 unsigned numRewrites;
231 unsigned numIgnoredOperations;
234 unsigned numReplacedOps;
246 notifyIRErased(listener, op);
255 notifyIRErased(listener, b);
285 UnresolvedMaterialization,
290 virtual ~IRRewrite() =
default;
293 virtual void rollback() = 0;
312 Kind getKind()
const {
return kind; }
314 static bool classof(
const IRRewrite *
rewrite) {
return true; }
318 :
kind(
kind), rewriterImpl(rewriterImpl) {}
327 class BlockRewrite :
public IRRewrite {
330 Block *getBlock()
const {
return block; }
332 static bool classof(
const IRRewrite *
rewrite) {
333 return rewrite->getKind() >= Kind::CreateBlock &&
334 rewrite->getKind() <= Kind::BlockTypeConversion;
340 : IRRewrite(
kind, rewriterImpl), block(block) {}
347 class ValueRewrite :
public IRRewrite {
350 Value getValue()
const {
return value; }
352 static bool classof(
const IRRewrite *
rewrite) {
353 return rewrite->getKind() == Kind::ReplaceValue;
359 : IRRewrite(
kind, rewriterImpl), value(value) {}
368 class CreateBlockRewrite :
public BlockRewrite {
371 : BlockRewrite(
Kind::CreateBlock, rewriterImpl, block) {}
373 static bool classof(
const IRRewrite *
rewrite) {
374 return rewrite->getKind() == Kind::CreateBlock;
383 void rollback()
override {
386 auto &blockOps = block->getOperations();
387 while (!blockOps.empty())
388 blockOps.remove(blockOps.begin());
389 block->dropAllUses();
390 if (block->getParent())
401 class EraseBlockRewrite :
public BlockRewrite {
404 : BlockRewrite(
Kind::EraseBlock, rewriterImpl, block),
405 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
407 static bool classof(
const IRRewrite *
rewrite) {
408 return rewrite->getKind() == Kind::EraseBlock;
411 ~EraseBlockRewrite()
override {
413 "rewrite was neither rolled back nor committed/cleaned up");
416 void rollback()
override {
419 assert(block &&
"expected block");
420 auto &blockList = region->getBlocks();
424 blockList.insert(before, block);
429 assert(block &&
"expected block");
433 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
434 notifyIRErased(listener, *block);
439 for (
auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
441 assert(block->empty() &&
"expected empty block");
444 block->dropAllDefinedValueUses();
455 Block *insertBeforeBlock;
461 class InlineBlockRewrite :
public BlockRewrite {
465 : BlockRewrite(
Kind::InlineBlock, rewriterImpl, block),
466 sourceBlock(sourceBlock),
467 firstInlinedInst(sourceBlock->empty() ? nullptr
468 : &sourceBlock->front()),
469 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
475 assert(!getConfig().listener &&
476 "InlineBlockRewrite not supported if listener is attached");
479 static bool classof(
const IRRewrite *
rewrite) {
480 return rewrite->getKind() == Kind::InlineBlock;
483 void rollback()
override {
486 if (firstInlinedInst) {
487 assert(lastInlinedInst &&
"expected operation");
507 class MoveBlockRewrite :
public BlockRewrite {
511 : BlockRewrite(
Kind::MoveBlock, rewriterImpl, block),
512 region(previousRegion),
513 insertBeforeBlock(previousIt == previousRegion->end() ? nullptr
516 static bool classof(
const IRRewrite *
rewrite) {
517 return rewrite->getKind() == Kind::MoveBlock;
530 void rollback()
override {
543 Block *insertBeforeBlock;
547 class BlockTypeConversionRewrite :
public BlockRewrite {
551 : BlockRewrite(
Kind::BlockTypeConversion, rewriterImpl, origBlock),
552 newBlock(newBlock) {}
554 static bool classof(
const IRRewrite *
rewrite) {
555 return rewrite->getKind() == Kind::BlockTypeConversion;
558 Block *getOrigBlock()
const {
return block; }
560 Block *getNewBlock()
const {
return newBlock; }
564 void rollback()
override;
574 class ReplaceValueRewrite :
public ValueRewrite {
578 : ValueRewrite(
Kind::ReplaceValue, rewriterImpl, value),
579 converter(converter) {}
581 static bool classof(
const IRRewrite *
rewrite) {
582 return rewrite->getKind() == Kind::ReplaceValue;
587 void rollback()
override;
595 class OperationRewrite :
public IRRewrite {
598 Operation *getOperation()
const {
return op; }
600 static bool classof(
const IRRewrite *
rewrite) {
601 return rewrite->getKind() >= Kind::MoveOperation &&
602 rewrite->getKind() <= Kind::UnresolvedMaterialization;
608 : IRRewrite(
kind, rewriterImpl), op(op) {}
615 class MoveOperationRewrite :
public OperationRewrite {
619 : OperationRewrite(
Kind::MoveOperation, rewriterImpl, op),
620 block(previous.getBlock()),
621 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
623 : &*previous.getPoint()) {}
625 static bool classof(
const IRRewrite *
rewrite) {
626 return rewrite->getKind() == Kind::MoveOperation;
640 void rollback()
override {
658 class ModifyOperationRewrite :
public OperationRewrite {
662 : OperationRewrite(
Kind::ModifyOperation, rewriterImpl, op),
663 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
664 operands(op->operand_begin(), op->operand_end()),
665 successors(op->successor_begin(), op->successor_end()) {
670 name.initOpProperties(propCopy, prop);
674 static bool classof(
const IRRewrite *
rewrite) {
675 return rewrite->getKind() == Kind::ModifyOperation;
678 ~ModifyOperationRewrite()
override {
679 assert(!propertiesStorage &&
680 "rewrite was neither committed nor rolled back");
686 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
689 if (propertiesStorage) {
693 name.destroyOpProperties(propCopy);
694 operator delete(propertiesStorage);
695 propertiesStorage =
nullptr;
699 void rollback()
override {
705 if (propertiesStorage) {
708 name.destroyOpProperties(propCopy);
709 operator delete(propertiesStorage);
710 propertiesStorage =
nullptr;
717 DictionaryAttr attrs;
718 SmallVector<Value, 8> operands;
719 SmallVector<Block *, 2> successors;
720 void *propertiesStorage =
nullptr;
727 class ReplaceOperationRewrite :
public OperationRewrite {
731 : OperationRewrite(
Kind::ReplaceOperation, rewriterImpl, op),
732 converter(converter) {}
734 static bool classof(
const IRRewrite *
rewrite) {
735 return rewrite->getKind() == Kind::ReplaceOperation;
740 void rollback()
override;
750 class CreateOperationRewrite :
public OperationRewrite {
754 : OperationRewrite(
Kind::CreateOperation, rewriterImpl, op) {}
756 static bool classof(
const IRRewrite *
rewrite) {
757 return rewrite->getKind() == Kind::CreateOperation;
766 void rollback()
override;
770 enum MaterializationKind {
781 class UnresolvedMaterializationInfo {
783 UnresolvedMaterializationInfo() =
default;
784 UnresolvedMaterializationInfo(
const TypeConverter *converter,
785 MaterializationKind
kind,
Type originalType)
786 : converterAndKind(converter,
kind), originalType(originalType) {}
790 return converterAndKind.getPointer();
794 MaterializationKind getMaterializationKind()
const {
795 return converterAndKind.getInt();
799 Type getOriginalType()
const {
return originalType; }
804 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
815 class UnresolvedMaterializationRewrite :
public OperationRewrite {
818 UnrealizedConversionCastOp op,
820 : OperationRewrite(
Kind::UnresolvedMaterialization, rewriterImpl, op),
821 mappedValues(std::move(mappedValues)) {}
823 static bool classof(
const IRRewrite *
rewrite) {
824 return rewrite->getKind() == Kind::UnresolvedMaterialization;
827 void rollback()
override;
829 UnrealizedConversionCastOp getOperation()
const {
830 return cast<UnrealizedConversionCastOp>(op);
840 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
843 template <
typename RewriteTy,
typename R>
844 static bool hasRewrite(R &&rewrites,
Operation *op) {
845 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
846 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
847 return rewriteTy && rewriteTy->getOperation() == op;
853 template <
typename RewriteTy,
typename R>
854 static bool hasRewrite(R &&rewrites,
Block *block) {
855 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
856 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
857 return rewriteTy && rewriteTy->getBlock() == block;
878 RewriterState getCurrentState();
882 void applyRewrites();
887 void resetState(RewriterState state, StringRef patternName =
"");
891 template <
typename RewriteTy,
typename... Args>
893 assert(
config.allowPatternRollback &&
"appending rewrites is not allowed");
895 std::make_unique<RewriteTy>(*
this, std::forward<Args>(args)...));
901 void undoRewrites(
unsigned numRewritesToKeep = 0, StringRef patternName =
"");
907 LogicalResult remapValues(StringRef valueDiagTag,
908 std::optional<Location> inputLoc,
ValueRange values,
925 bool skipPureTypeConversions =
false)
const;
947 Block *applySignatureConversion(
958 void replaceOp(
Operation *op, SmallVector<SmallVector<Value>> &&newValues);
966 void eraseBlock(
Block *block);
1004 Value findOrBuildReplacementValue(
Value value,
1012 void notifyOperationInserted(
Operation *op,
1016 void notifyBlockInserted(
Block *block,
Region *previous,
1035 std::function<
void(
Operation *)> opErasedCallback =
nullptr)
1037 opErasedCallback(opErasedCallback) {}
1049 if (wasErased(block))
1051 assert(block->
empty() &&
"expected empty block");
1056 bool wasErased(
void *ptr)
const {
return erased.contains(ptr); }
1060 if (opErasedCallback)
1061 opErasedCallback(op);
1071 std::function<void(
Operation *)> opErasedCallback;
1159 llvm::impl::raw_ldbg_ostream os{(Twine(
"[") +
DEBUG_TYPE +
":1] ").str(),
1163 llvm::ScopedPrinter logger{os};
1171 return rewriterImpl.
config;
1174 void BlockTypeConversionRewrite::commit(
RewriterBase &rewriter) {
1178 if (
auto *listener =
1179 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
1180 for (
Operation *op : getNewBlock()->getUsers())
1184 void BlockTypeConversionRewrite::rollback() {
1185 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1191 if (isa<BlockArgument>(repl)) {
1228 void ReplaceValueRewrite::commit(
RewriterBase &rewriter) {
1235 void ReplaceValueRewrite::rollback() {
1236 rewriterImpl.
mapping.erase({value});
1242 void ReplaceOperationRewrite::commit(
RewriterBase &rewriter) {
1244 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener());
1247 SmallVector<Value> replacements =
1249 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1257 for (
auto [result, newValue] :
1258 llvm::zip_equal(op->
getResults(), replacements))
1264 if (getConfig().unlegalizedOps)
1265 getConfig().unlegalizedOps->erase(op);
1269 notifyIRErased(listener, *op);
1276 void ReplaceOperationRewrite::rollback() {
1278 rewriterImpl.
mapping.erase({result});
1281 void ReplaceOperationRewrite::cleanup(
RewriterBase &rewriter) {
1285 void CreateOperationRewrite::rollback() {
1287 while (!region.getBlocks().empty())
1288 region.getBlocks().remove(region.getBlocks().begin());
1294 void UnresolvedMaterializationRewrite::rollback() {
1295 if (!mappedValues.empty())
1296 rewriterImpl.
mapping.erase(mappedValues);
1307 for (
size_t i = 0; i <
rewrites.size(); ++i)
1313 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1314 unresolvedMaterializations.erase(castOp);
1317 rewrite->cleanup(eraseRewriter);
1325 Value from,
TypeRange desiredTypes,
bool skipPureTypeConversions)
const {
1328 assert(!values.empty() &&
"expected non-empty value vector");
1333 return mapping.lookup(values);
1340 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1345 if (castOp.getOutputs() != values)
1347 return castOp.getInputs();
1356 for (
Value v : values) {
1359 llvm::append_range(next, r);
1364 if (next != values) {
1393 if (skipPureTypeConversions) {
1396 match &= !pureConversion;
1399 if (!pureConversion)
1400 lastNonMaterialization = current;
1403 desiredValue = current;
1409 current = std::move(next);
1414 if (!desiredTypes.empty())
1415 return desiredValue;
1416 if (skipPureTypeConversions)
1417 return lastNonMaterialization;
1436 StringRef patternName) {
1441 while (
ignoredOps.size() != state.numIgnoredOperations)
1444 while (
replacedOps.size() != state.numReplacedOps)
1449 StringRef patternName) {
1451 llvm::reverse(llvm::drop_begin(
rewrites, numRewritesToKeep)))
1453 rewrites.resize(numRewritesToKeep);
1457 StringRef valueDiagTag, std::optional<Location> inputLoc,
ValueRange values,
1459 remapped.reserve(llvm::size(values));
1462 Value operand = it.value();
1481 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1482 << it.index() <<
", type was " << origType;
1487 if (legalTypes.empty()) {
1488 remapped.push_back({});
1497 remapped.push_back(std::move(repl));
1506 repl, repl, legalTypes,
1508 remapped.push_back(castValues);
1531 if (region->
empty())
1536 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1538 std::optional<TypeConverter::SignatureConversion> conversion =
1548 if (entryConversion)
1551 std::optional<TypeConverter::SignatureConversion> conversion =
1561 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1563 if (hasRewrite<BlockTypeConversionRewrite>(
rewrites, block))
1564 llvm::report_fatal_error(
"block was already converted");
1578 for (
unsigned i = 0; i < origArgCount; ++i) {
1580 if (!inputMap || inputMap->replacedWithValues())
1583 for (
unsigned j = 0;
j < inputMap->size; ++
j)
1584 newLocs[inputMap->inputNo +
j] = origLoc;
1591 convertedTypes, newLocs);
1602 appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->
end());
1605 while (!block->
empty())
1612 for (
unsigned i = 0; i != origArgCount; ++i) {
1616 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1625 MaterializationKind::Source,
1629 origArgType,
Type(), converter,
1636 if (inputMap->replacedWithValues()) {
1638 assert(inputMap->size == 0 &&
1639 "invalid to provide a replacement value when the argument isn't "
1647 newBlock->
getArguments().slice(inputMap->inputNo, inputMap->size);
1652 appendRewrite<BlockTypeConversionRewrite>(block, newBlock);
1672 assert((!originalType ||
kind == MaterializationKind::Target) &&
1673 "original type is valid only for target materializations");
1674 assert(
TypeRange(inputs) != outputTypes &&
1675 "materialization is not necessary");
1679 OpBuilder builder(outputTypes.front().getContext());
1681 UnrealizedConversionCastOp convertOp =
1682 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1685 kind == MaterializationKind::Source ?
"source" :
"target";
1686 convertOp->setAttr(
"__kind__", builder.
getStringAttr(kindStr));
1693 UnresolvedMaterializationInfo(converter,
kind, originalType);
1695 if (!valuesToMap.empty())
1696 mapping.map(valuesToMap, convertOp.getResults());
1697 appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
1698 std::move(valuesToMap));
1702 return convertOp.getResults();
1708 "this code path is valid only in rollback mode");
1715 return repl.front();
1722 [&](
Operation *op) { return replacedOps.contains(op); }) &&
1747 MaterializationKind::Source, ip, value.
getLoc(),
1763 bool wasDetached = !previous.
isSet();
1765 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"' (" << op
1768 logger.getOStream() <<
" (was detached)";
1769 logger.getOStream() <<
"\n";
1775 "attempting to insert into a block within a replaced/erased op");
1792 appendRewrite<CreateOperationRewrite>(op);
1803 appendRewrite<MoveOperationRewrite>(op, previous);
1810 const SmallVector<SmallVector<Value>> &toRange,
1812 assert(!
impl.config.allowPatternRollback &&
1813 "this code path is valid only in 'no rollback' mode");
1814 SmallVector<Value> repls;
1815 for (
auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1818 repls.push_back(
Value());
1825 Value srcMat =
impl.buildUnresolvedMaterialization(
1830 repls.push_back(srcMat);
1836 repls.push_back(to[0]);
1845 Value srcMat =
impl.buildUnresolvedMaterialization(
1848 Type(), converter)[0];
1849 repls.push_back(srcMat);
1858 "incorrect number of replacement values");
1869 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1884 assert(!
ignoredOps.contains(op) &&
"operation was already replaced");
1888 "attempting to replace a value that was already replaced");
1893 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1898 "attempting to replace/erase an unresolved materialization");
1902 for (
auto [repl, result] : llvm::zip_equal(newValues, op->
getResults()))
1903 mapping.map(
static_cast<Value>(result), std::move(repl));
1917 Value repl = repls.front();
1934 "attempting to replace a value that was already replaced");
1936 "attempting to replace a op result that was already replaced");
1941 appendRewrite<ReplaceValueRewrite>(from, converter);
1952 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1968 "attempting to erase a block within a replaced/erased op");
1969 appendRewrite<EraseBlockRewrite>(block);
1984 bool wasDetached = !previous;
1990 logger.startLine() <<
"** Insert Block into : '" << parent->
getName()
1991 <<
"' (" << parent <<
")";
1994 <<
"** Insert Block into detached Region (nullptr parent op)";
1997 logger.getOStream() <<
" (was detached)";
1998 logger.getOStream() <<
"\n";
2004 "attempting to insert into a region within a replaced/erased op");
2019 appendRewrite<CreateBlockRewrite>(block);
2030 appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
2036 appendRewrite<InlineBlockRewrite>(dest, source, before);
2043 reasonCallback(
diag);
2044 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
2054 ConversionPatternRewriter::ConversionPatternRewriter(
2058 setListener(
impl.get());
2064 return impl->config;
2068 assert(op && newOp &&
"expected non-null op");
2074 "incorrect # of replacement values");
2076 impl->logger.startLine()
2077 <<
"** Replace : '" << op->
getName() <<
"'(" << op <<
")\n";
2082 if (getInsertionPoint() == op->getIterator())
2089 impl->replaceOp(op, std::move(newVals));
2095 "incorrect # of replacement values");
2097 impl->logger.startLine()
2098 <<
"** Replace : '" << op->
getName() <<
"'(" << op <<
")\n";
2103 if (getInsertionPoint() == op->getIterator())
2106 impl->replaceOp(op, std::move(newValues));
2111 impl->logger.startLine()
2112 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
2117 if (getInsertionPoint() == op->getIterator())
2121 impl->replaceOp(op, std::move(nullRepls));
2125 impl->eraseBlock(block);
2132 "attempting to apply a signature conversion to a block within a "
2133 "replaced/erased op");
2134 return impl->applySignatureConversion(block, converter, conversion);
2141 "attempting to apply a signature conversion to a block within a "
2142 "replaced/erased op");
2143 return impl->convertRegionTypes(region, converter, entryConversion);
2148 impl->logger.startLine() <<
"** Replace Value : '" << from <<
"'";
2149 if (
auto blockArg = dyn_cast<BlockArgument>(from)) {
2151 impl->logger.getOStream() <<
" (in region of '" << parentOp->getName()
2152 <<
"' (" << parentOp <<
")\n";
2154 impl->logger.getOStream() <<
" (unlinked block)\n";
2158 impl->replaceAllUsesWith(from, to,
impl->currentTypeConverter);
2163 if (
failed(
impl->remapValues(
"value", std::nullopt, key,
2166 assert(remappedValues.front().size() == 1 &&
"1:N conversion not supported");
2167 return remappedValues.front().front();
2176 if (
failed(
impl->remapValues(
"value", std::nullopt, keys,
2179 for (
const auto &values : remapped) {
2180 assert(values.size() == 1 &&
"1:N conversion not supported");
2181 results.push_back(values.front());
2191 "incorrect # of argument replacement values");
2193 "attempting to inline a block from a replaced/erased op");
2195 "attempting to inline a block into a replaced/erased op");
2196 auto opIgnored = [&](
Operation *op) {
return impl->isOpIgnored(op); };
2199 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
2200 "expected 'source' to have no predecessors");
2209 bool fastPath = !getConfig().listener;
2211 if (fastPath &&
impl->config.allowPatternRollback)
2212 impl->inlineBlockBefore(source, dest, before);
2215 for (
auto it : llvm::zip(source->
getArguments(), argValues))
2216 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2223 while (!source->
empty())
2224 moveOpBefore(&source->
front(), dest, before);
2229 if (getInsertionBlock() == source)
2230 setInsertionPoint(dest, getInsertionPoint());
2237 if (!
impl->config.allowPatternRollback) {
2242 assert(!
impl->wasOpReplaced(op) &&
2243 "attempting to modify a replaced/erased op");
2245 impl->pendingRootUpdates.insert(op);
2247 impl->appendRewrite<ModifyOperationRewrite>(op);
2251 impl->patternModifiedOps.insert(op);
2252 if (!
impl->config.allowPatternRollback) {
2254 if (getConfig().listener)
2255 getConfig().listener->notifyOperationModified(op);
2262 assert(!
impl->wasOpReplaced(op) &&
2263 "attempting to modify a replaced/erased op");
2264 assert(
impl->pendingRootUpdates.erase(op) &&
2265 "operation did not have a pending in-place update");
2270 if (!
impl->config.allowPatternRollback) {
2275 assert(
impl->pendingRootUpdates.erase(op) &&
2276 "operation did not have a pending in-place update");
2279 auto it = llvm::find_if(
2280 llvm::reverse(
impl->rewrites), [&](std::unique_ptr<IRRewrite> &
rewrite) {
2281 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2282 return modifyRewrite && modifyRewrite->getOperation() == op;
2284 assert(it !=
impl->rewrites.rend() &&
"no root update started on op");
2286 int updateIdx = std::prev(
impl->rewrites.rend()) - it;
2287 impl->rewrites.erase(
impl->rewrites.begin() + updateIdx);
2301 oneToOneOperands.reserve(operands.size());
2303 if (operand.size() != 1)
2306 oneToOneOperands.push_back(operand.front());
2308 return std::move(oneToOneOperands);
2315 auto &rewriterImpl = dialectRewriter.getImpl();
2319 getTypeConverter());
2328 llvm::to_vector_of<ValueRange>(remapped);
2329 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2341 class OperationLegalizer {
2361 LogicalResult legalizeWithFold(
Operation *op);
2365 LogicalResult legalizeWithPattern(
Operation *op);
2373 const RewriterState &curState,
2380 legalizePatternBlockRewrites(
Operation *op,
2396 void buildLegalizationGraph(
2397 LegalizationPatterns &anyOpLegalizerPatterns,
2408 void computeLegalizationGraphBenefit(
2409 LegalizationPatterns &anyOpLegalizerPatterns,
2414 unsigned computeOpLegalizationDepth(
2421 unsigned applyCostModelToPatterns(
2443 : rewriter(rewriter), target(targetInfo), applicator(
patterns) {
2447 LegalizationPatterns anyOpLegalizerPatterns;
2449 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2450 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2453 bool OperationLegalizer::isIllegal(
Operation *op)
const {
2454 return target.isIllegal(op);
2457 LogicalResult OperationLegalizer::legalize(
Operation *op) {
2459 const char *logLineComment =
2460 "//===-------------------------------------------===//\n";
2462 auto &logger = rewriter.getImpl().logger;
2466 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2469 logger.getOStream() <<
"\n";
2470 logger.startLine() << logLineComment;
2471 logger.startLine() <<
"Legalizing operation : ";
2476 logger.getOStream() <<
"'" << op->
getName() <<
"' ";
2477 logger.getOStream() <<
"(" << op <<
") {\n";
2482 logger.startLine() << OpWithFlags(op,
2483 OpPrintingFlags().printGenericOpForm())
2490 logSuccess(logger,
"operation marked 'ignored' during conversion");
2491 logger.startLine() << logLineComment;
2497 if (
auto legalityInfo = target.isLegal(op)) {
2500 logger,
"operation marked legal by the target{0}",
2501 legalityInfo->isRecursivelyLegal
2502 ?
"; NOTE: operation is recursively legal; skipping internals"
2504 logger.startLine() << logLineComment;
2509 if (legalityInfo->isRecursivelyLegal) {
2512 rewriter.getImpl().ignoredOps.
insert(nested);
2523 if (succeeded(legalizeWithFold(op))) {
2526 logger.startLine() << logLineComment;
2533 if (succeeded(legalizeWithPattern(op))) {
2536 logger.startLine() << logLineComment;
2544 if (succeeded(legalizeWithFold(op))) {
2547 logger.startLine() << logLineComment;
2554 logFailure(logger,
"no matched legalization pattern");
2555 logger.startLine() << logLineComment;
2562 template <
typename T>
2564 T result = std::move(obj);
2569 LogicalResult OperationLegalizer::legalizeWithFold(
Operation *op) {
2570 auto &rewriterImpl = rewriter.getImpl();
2572 rewriterImpl.
logger.startLine() <<
"* Fold {\n";
2573 rewriterImpl.
logger.indent();
2578 auto cleanup = llvm::make_scope_exit([&]() {
2589 SmallVector<Value, 2> replacementValues;
2590 SmallVector<Operation *, 2> newOps;
2593 if (
failed(rewriter.
tryFold(op, replacementValues, &newOps))) {
2602 if (replacementValues.empty())
2603 return legalize(op);
2606 rewriter.
replaceOp(op, replacementValues);
2610 if (
failed(legalize(newOp))) {
2612 "failed to legalize generated constant '{0}'",
2614 if (!rewriter.getConfig().allowPatternRollback) {
2616 llvm::report_fatal_error(
2618 "' folder rollback of IR modifications requested");
2637 auto newOpNames = llvm::map_range(
2639 auto modifiedOpNames = llvm::map_range(
2641 StringRef detachedBlockStr =
"(detached block)";
2642 auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](
Block *block) {
2645 return detachedBlockStr;
2647 llvm::report_fatal_error(
2649 "' produced IR that could not be legalized. " +
"new ops: {" +
2650 llvm::join(newOpNames,
", ") +
"}, " +
"modified ops: {" +
2651 llvm::join(modifiedOpNames,
", ") +
"}, " +
"inserted block into ops: {" +
2652 llvm::join(insertedBlockNames,
", ") +
"}");
2655 LogicalResult OperationLegalizer::legalizeWithPattern(
Operation *op) {
2656 auto &rewriterImpl = rewriter.getImpl();
2659 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2661 std::optional<OperationFingerPrint> topLevelFingerPrint;
2675 rewriterImpl.
logger.startLine()
2676 <<
"WARNING: Multi-threadeding is enabled. Some dialect "
2677 "conversion expensive checks are skipped in multithreading "
2686 auto canApply = [&](
const Pattern &pattern) {
2687 bool canApply = canApplyPattern(op, pattern);
2688 if (canApply &&
config.listener)
2689 config.listener->notifyPatternBegin(pattern, op);
2695 auto onFailure = [&](
const Pattern &pattern) {
2704 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2708 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2709 llvm::report_fatal_error(
"pattern '" + pattern.getDebugName() +
2710 "' returned failure but IR did change");
2721 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
2728 config.listener->notifyPatternEnd(pattern, failure());
2729 rewriterImpl.
resetState(curState, pattern.getDebugName());
2730 appliedPatterns.erase(&pattern);
2735 auto onSuccess = [&](
const Pattern &pattern) {
2752 auto result = legalizePatternResult(op, pattern, curState, newOps,
2753 modifiedOps, insertedBlocks);
2754 appliedPatterns.erase(&pattern);
2759 rewriterImpl.
resetState(curState, pattern.getDebugName());
2762 config.listener->notifyPatternEnd(pattern, result);
2767 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2771 bool OperationLegalizer::canApplyPattern(
Operation *op,
2774 auto &os = rewriter.getImpl().logger;
2775 os.getOStream() <<
"\n";
2776 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
2778 os.getOStream() <<
")' {\n";
2785 !appliedPatterns.insert(&pattern).second) {
2787 logFailure(rewriter.getImpl().logger,
"pattern was already applied"));
2793 LogicalResult OperationLegalizer::legalizePatternResult(
2798 [[maybe_unused]]
auto &
impl = rewriter.getImpl();
2799 assert(
impl.pendingRootUpdates.empty() &&
"dangling root updates");
2801 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2803 auto newRewrites = llvm::drop_begin(
impl.rewrites, curState.numRewrites);
2804 auto replacedRoot = [&] {
2805 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2807 auto updatedRootInPlace = [&] {
2808 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2810 if (!replacedRoot() && !updatedRootInPlace())
2811 llvm::report_fatal_error(
2812 "expected pattern to replace the root operation or modify it in place");
2816 if (
failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
2817 failed(legalizePatternRootUpdates(modifiedOps)) ||
2818 failed(legalizePatternCreatedOperations(newOps))) {
2822 LLVM_DEBUG(
logSuccess(
impl.logger,
"pattern applied successfully"));
2826 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2834 for (
Block *block : insertedBlocks) {
2835 if (
impl.erasedBlocks.contains(block))
2845 if (
auto *converter =
impl.regionToConverter.lookup(block->
getParent())) {
2846 std::optional<TypeConverter::SignatureConversion> conversion =
2847 converter->convertBlockSignature(block);
2849 LLVM_DEBUG(
logFailure(
impl.logger,
"failed to convert types of moved "
2853 impl.applySignatureConversion(block, converter, *conversion);
2861 if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
2862 if (
failed(legalize(parentOp))) {
2864 impl.logger,
"operation '{0}'({1}) became illegal after rewrite",
2865 parentOp->
getName(), parentOp));
2873 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2876 if (
failed(legalize(op))) {
2877 LLVM_DEBUG(
logFailure(rewriter.getImpl().logger,
2878 "failed to legalize generated operation '{0}'({1})",
2886 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2889 if (
failed(legalize(op))) {
2892 "failed to legalize operation updated in-place '{0}'",
2904 void OperationLegalizer::buildLegalizationGraph(
2905 LegalizationPatterns &anyOpLegalizerPatterns,
2916 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2917 std::optional<OperationName> root = pattern.
getRootKind();
2923 anyOpLegalizerPatterns.push_back(&pattern);
2928 if (target.getOpAction(*root) == LegalizationAction::Legal)
2933 invalidPatterns[*root].insert(&pattern);
2935 parentOps[op].insert(*root);
2938 patternWorklist.insert(&pattern);
2946 if (!anyOpLegalizerPatterns.empty()) {
2947 for (
const Pattern *pattern : patternWorklist)
2948 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2952 while (!patternWorklist.empty()) {
2953 auto *pattern = patternWorklist.pop_back_val();
2957 std::optional<LegalizationAction> action = target.getOpAction(op);
2958 return !legalizerPatterns.count(op) &&
2959 (!action || action == LegalizationAction::Illegal);
2965 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2966 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
2970 for (
auto op : parentOps[*pattern->
getRootKind()])
2971 patternWorklist.set_union(invalidPatterns[op]);
2975 void OperationLegalizer::computeLegalizationGraphBenefit(
2976 LegalizationPatterns &anyOpLegalizerPatterns,
2982 for (
auto &opIt : legalizerPatterns)
2983 if (!minOpPatternDepth.count(opIt.first))
2984 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2990 if (!anyOpLegalizerPatterns.empty())
2991 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2997 applicator.applyCostModel([&](
const Pattern &pattern) {
2999 if (std::optional<OperationName> rootName = pattern.
getRootKind())
3000 orderedPatternList = legalizerPatterns[*rootName];
3002 orderedPatternList = anyOpLegalizerPatterns;
3005 auto *it = llvm::find(orderedPatternList, &pattern);
3006 if (it == orderedPatternList.end())
3010 return PatternBenefit(std::distance(it, orderedPatternList.end()));
3014 unsigned OperationLegalizer::computeOpLegalizationDepth(
3018 auto depthIt = minOpPatternDepth.find(op);
3019 if (depthIt != minOpPatternDepth.end())
3020 return depthIt->second;
3024 auto opPatternsIt = legalizerPatterns.find(op);
3025 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3034 unsigned minDepth = applyCostModelToPatterns(
3035 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3036 minOpPatternDepth[op] = minDepth;
3040 unsigned OperationLegalizer::applyCostModelToPatterns(
3047 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3048 patternsByDepth.reserve(
patterns.size());
3052 unsigned generatedOpDepth = computeOpLegalizationDepth(
3053 generatedOp, minOpPatternDepth, legalizerPatterns);
3054 depth =
std::max(depth, generatedOpDepth + 1);
3056 patternsByDepth.emplace_back(pattern, depth);
3059 minDepth =
std::min(minDepth, depth);
3064 if (patternsByDepth.size() == 1)
3068 llvm::stable_sort(patternsByDepth,
3069 [](
const std::pair<const Pattern *, unsigned> &lhs,
3070 const std::pair<const Pattern *, unsigned> &rhs) {
3073 if (lhs.second != rhs.second)
3074 return lhs.second < rhs.second;
3077 auto lhsBenefit = lhs.first->getBenefit();
3078 auto rhsBenefit = rhs.first->getBenefit();
3079 return lhsBenefit > rhsBenefit;
3084 for (
auto &patternIt : patternsByDepth)
3085 patterns.push_back(patternIt.first);
3099 template <
typename RangeT>
3102 function_ref<
bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3111 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3112 if (castOp.getInputs().empty())
3115 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3118 if (inputCastOp.getOutputs() != castOp.getInputs())
3124 while (!worklist.empty()) {
3125 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3129 UnrealizedConversionCastOp nextCast = castOp;
3131 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3132 if (llvm::any_of(nextCast.getInputs(), [&](
Value v) {
3133 return v.getDefiningOp() == castOp;
3141 castOp.replaceAllUsesWith(nextCast.getInputs());
3144 nextCast = getInputCast(nextCast);
3154 auto markOpLive = [&](
Operation *rootOp) {
3155 SmallVector<Operation *> worklist;
3156 worklist.push_back(rootOp);
3157 while (!worklist.empty()) {
3158 Operation *op = worklist.pop_back_val();
3159 if (liveOps.insert(op).second) {
3162 if (
auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
3163 if (isCastOpOfInterestFn(castOp))
3164 worklist.push_back(castOp);
3170 for (UnrealizedConversionCastOp op : castOps) {
3173 if (liveOps.contains(op.getOperation()))
3177 if (llvm::any_of(op->getUsers(), [&](
Operation *user) {
3178 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3179 return !castOp || !isCastOpOfInterestFn(castOp);
3185 for (UnrealizedConversionCastOp op : castOps) {
3186 if (liveOps.contains(op)) {
3188 if (remainingCastOps)
3189 remainingCastOps->push_back(op);
3204 for (UnrealizedConversionCastOp op : castOps)
3205 castOpSet.insert(op);
3213 llvm::make_range(castOps.begin(), castOps.end()),
3214 [&](UnrealizedConversionCastOp castOp) {
3215 return castOps.contains(castOp);
3227 [&](UnrealizedConversionCastOp castOp) {
3228 return castOps.contains(castOp);
3239 enum OpConversionMode {
3262 OpConversionMode mode)
3263 : rewriter(ctx,
config), opLegalizer(rewriter, target,
patterns),
3277 OperationLegalizer opLegalizer;
3280 OpConversionMode mode;
3284 LogicalResult OperationConverter::convert(
Operation *op) {
3288 if (
failed(opLegalizer.legalize(op))) {
3291 if (mode == OpConversionMode::Full)
3293 <<
"failed to legalize operation '" << op->
getName() <<
"'";
3297 if (mode == OpConversionMode::Partial) {
3298 if (opLegalizer.isIllegal(op))
3300 <<
"failed to legalize operation '" << op->
getName()
3301 <<
"' that was explicitly marked illegal";
3302 if (
config.unlegalizedOps)
3303 config.unlegalizedOps->insert(op);
3305 }
else if (mode == OpConversionMode::Analysis) {
3309 if (
config.legalizableOps)
3310 config.legalizableOps->insert(op);
3315 static LogicalResult
3317 UnrealizedConversionCastOp op,
3318 const UnresolvedMaterializationInfo &info) {
3319 assert(!op.use_empty() &&
3320 "expected that dead materializations have already been DCE'd");
3326 SmallVector<Value> newMaterialization;
3327 switch (info.getMaterializationKind()) {
3328 case MaterializationKind::Target:
3329 newMaterialization = converter->materializeTargetConversion(
3330 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3331 info.getOriginalType());
3333 case MaterializationKind::Source:
3334 assert(op->getNumResults() == 1 &&
"expected single result");
3335 Value sourceMat = converter->materializeSourceConversion(
3336 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3338 newMaterialization.push_back(sourceMat);
3341 if (!newMaterialization.empty()) {
3343 ValueRange newMaterializationRange(newMaterialization);
3344 assert(
TypeRange(newMaterializationRange) == op.getResultTypes() &&
3345 "materialization callback produced value of incorrect type");
3347 rewriter.
replaceOp(op, newMaterialization);
3353 <<
"failed to legalize unresolved materialization "
3355 << inputOperands.
getTypes() <<
") to ("
3356 << op.getResultTypes()
3357 <<
") that remained live after conversion";
3358 diag.attachNote(op->getUsers().begin()->getLoc())
3359 <<
"see existing live user here: " << *op->getUsers().begin();
3368 for (
auto *op : ops) {
3371 toConvert.push_back(op);
3374 auto legalityInfo = target.
isLegal(op);
3375 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3384 for (
auto *op : toConvert) {
3385 if (
failed(convert(op))) {
3411 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3415 if (rewriter.getConfig().buildMaterializations) {
3420 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3421 auto it = materializations.find(castOp);
3422 assert(it != materializations.end() &&
"inconsistent state");
3438 assert(!types.empty() &&
"expected valid types");
3439 remapInput(origInputNo, argTypes.size(), types.size());
3444 assert(!types.empty() &&
3445 "1->0 type remappings don't need to be added explicitly");
3446 argTypes.append(types.begin(), types.end());
3450 unsigned newInputNo,
3451 unsigned newInputCount) {
3452 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3453 assert(newInputCount != 0 &&
"expected valid input count");
3454 remappedInputs[origInputNo] =
3455 InputMapping{newInputNo, newInputCount, {}};
3460 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3476 assert(typeOrValue &&
"expected non-null type");
3477 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3478 : cast<Type>(typeOrValue);
3480 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3483 cacheReadLock.lock();
3484 auto existingIt = cachedDirectConversions.find(t);
3485 if (existingIt != cachedDirectConversions.end()) {
3486 if (existingIt->second)
3487 results.push_back(existingIt->second);
3488 return success(existingIt->second !=
nullptr);
3490 auto multiIt = cachedMultiConversions.find(t);
3491 if (multiIt != cachedMultiConversions.end()) {
3492 results.append(multiIt->second.begin(), multiIt->second.end());
3498 size_t currentCount = results.size();
3502 auto isCacheable = [&](
int index) {
3503 int numberOfConversionsUntilContextAware =
3504 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3505 return index < numberOfConversionsUntilContextAware;
3508 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3511 for (
auto indexedConverter :
llvm::enumerate(llvm::reverse(conversions))) {
3512 const ConversionCallbackFn &converter = indexedConverter.value();
3513 std::optional<LogicalResult> result = converter(typeOrValue, results);
3515 assert(results.size() == currentCount &&
3516 "failed type conversion should not change results");
3519 if (!isCacheable(indexedConverter.index()))
3522 cacheWriteLock.lock();
3523 if (!succeeded(*result)) {
3524 assert(results.size() == currentCount &&
3525 "failed type conversion should not change results");
3526 cachedDirectConversions.try_emplace(t,
nullptr);
3529 auto newTypes =
ArrayRef<Type>(results).drop_front(currentCount);
3530 if (newTypes.size() == 1)
3531 cachedDirectConversions.try_emplace(t, newTypes.front());
3533 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3541 return convertTypeImpl(t, results);
3546 return convertTypeImpl(v, results);
3556 return results.size() == 1 ? results.front() :
nullptr;
3566 return results.size() == 1 ? results.front() :
nullptr;
3572 for (
Type type : types)
3581 for (
Value value : values)
3600 return llvm::all_of(
3607 if (!
isLegal(ty.getResults()))
3621 if (convertedTypes.empty())
3625 result.
addInputs(inputNo, convertedTypes);
3631 unsigned origInputOffset)
const {
3632 for (
unsigned i = 0, e = types.size(); i != e; ++i)
3646 if (convertedTypes.empty())
3650 result.
addInputs(inputNo, convertedTypes);
3656 unsigned origInputOffset)
const {
3657 for (
unsigned i = 0, e = values.size(); i != e; ++i)
3666 for (
const SourceMaterializationCallbackFn &fn :
3667 llvm::reverse(sourceMaterializations))
3668 if (
Value result = fn(builder, resultType, inputs, loc))
3676 Type originalType)
const {
3678 builder, loc,
TypeRange(resultType), inputs, originalType);
3681 assert(result.size() == 1 &&
"expected single result");
3682 return result.front();
3687 Type originalType)
const {
3688 for (
const TargetMaterializationCallbackFn &fn :
3689 llvm::reverse(targetMaterializations)) {
3691 fn(builder, resultTypes, inputs, loc, originalType);
3695 "callback produced incorrect number of values or values with "
3702 std::optional<TypeConverter::SignatureConversion>
3706 return std::nullopt;
3729 return impl.getInt() == resultTag;
3733 return impl.getInt() == naTag;
3737 return impl.getInt() == abortTag;
3741 assert(hasResult() &&
"Cannot get result from N/A or abort");
3742 return impl.getPointer();
3745 std::optional<Attribute>
3747 for (
const TypeAttributeConversionCallbackFn &fn :
3748 llvm::reverse(typeAttributeConversions)) {
3753 return std::nullopt;
3755 return std::nullopt;
3765 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3771 SmallVector<Type, 1> newResults;
3775 typeConverter, &result)))
3792 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3801 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3806 struct AnyFunctionOpInterfaceSignatureConversion
3818 FailureOr<Operation *>
3822 assert(op &&
"Invalid op");
3836 return rewriter.
create(newOp);
3842 patterns.add<FunctionOpInterfaceSignatureConversion>(
3843 functionLikeOpName,
patterns.getContext(), converter, benefit);
3849 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3850 converter,
patterns.getContext(), benefit);
3859 legalOperations[op].action = action;
3864 for (StringRef dialect : dialectNames)
3865 legalDialects[dialect] = action;
3869 -> std::optional<LegalizationAction> {
3870 std::optional<LegalizationInfo> info = getOpInfo(op);
3871 return info ? info->action : std::optional<LegalizationAction>();
3875 -> std::optional<LegalOpDetails> {
3876 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3878 return std::nullopt;
3881 auto isOpLegal = [&] {
3883 if (info->action == LegalizationAction::Dynamic) {
3884 std::optional<bool> result = info->legalityFn(op);
3890 return info->action == LegalizationAction::Legal;
3893 return std::nullopt;
3897 if (info->isRecursivelyLegal) {
3898 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
3899 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3901 legalityFnIt->second(op).value_or(
true);
3906 return legalityDetails;
3910 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3914 if (info->action == LegalizationAction::Dynamic) {
3915 std::optional<bool> result = info->legalityFn(op);
3922 return info->action == LegalizationAction::Illegal;
3931 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3933 if (std::optional<bool> result = newCl(op))
3941 void ConversionTarget::setLegalityCallback(
3942 OperationName name,
const DynamicLegalityCallbackFn &callback) {
3943 assert(callback &&
"expected valid legality callback");
3944 auto *infoIt = legalOperations.find(name);
3945 assert(infoIt != legalOperations.end() &&
3946 infoIt->second.action == LegalizationAction::Dynamic &&
3947 "expected operation to already be marked as dynamically legal");
3948 infoIt->second.legalityFn =
3954 auto *infoIt = legalOperations.find(name);
3955 assert(infoIt != legalOperations.end() &&
3956 infoIt->second.action != LegalizationAction::Illegal &&
3957 "expected operation to already be marked as legal");
3958 infoIt->second.isRecursivelyLegal =
true;
3961 std::move(opRecursiveLegalityFns[name]), callback);
3963 opRecursiveLegalityFns.erase(name);
3966 void ConversionTarget::setLegalityCallback(
3968 assert(callback &&
"expected valid legality callback");
3969 for (StringRef dialect : dialects)
3971 std::move(dialectLegalityFns[dialect]), callback);
3974 void ConversionTarget::setLegalityCallback(
3975 const DynamicLegalityCallbackFn &callback) {
3976 assert(callback &&
"expected valid legality callback");
3981 -> std::optional<LegalizationInfo> {
3983 const auto *it = legalOperations.find(op);
3984 if (it != legalOperations.end())
3987 auto dialectIt = legalDialects.find(op.getDialectNamespace());
3988 if (dialectIt != legalDialects.end()) {
3989 DynamicLegalityCallbackFn callback;
3990 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3991 if (dialectFn != dialectLegalityFns.end())
3992 callback = dialectFn->second;
3993 return LegalizationInfo{dialectIt->second,
false,
3997 if (unknownLegalityFn)
3998 return LegalizationInfo{LegalizationAction::Dynamic,
3999 false, unknownLegalityFn};
4000 return std::nullopt;
4003 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4009 auto &rewriterImpl =
4015 auto &rewriterImpl =
4022 static FailureOr<SmallVector<Value>>
4024 SmallVector<Value> mappedValues;
4027 return std::move(mappedValues);
4031 patterns.getPDLPatterns().registerRewriteFunction(
4038 return results->front();
4040 patterns.getPDLPatterns().registerRewriteFunction(
4045 patterns.getPDLPatterns().registerRewriteFunction(
4048 auto &rewriterImpl =
4052 if (
Type newType = converter->convertType(type))
4058 patterns.getPDLPatterns().registerRewriteFunction(
4061 TypeRange types) -> FailureOr<SmallVector<Type>> {
4062 auto &rewriterImpl =
4071 return std::move(remappedTypes);
4086 static constexpr StringLiteral tag =
"apply-conversion";
4087 static constexpr StringLiteral desc =
4088 "Encapsulate the application of a dialect conversion";
4090 void print(raw_ostream &os)
const override { os << tag; }
4097 OpConversionMode mode) {
4101 LogicalResult status = success();
4102 SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
4121 OpConversionMode::Partial);
4161 "expected top-level op to be isolated from above");
4164 "expected ops to have a common ancestor");
4173 for (
Operation *op : ops.drop_front()) {
4177 assert(commonAncestor &&
4178 "expected to find a common isolated from above ancestor");
4182 return commonAncestor;
4189 if (
config.legalizableOps)
4190 assert(
config.legalizableOps->empty() &&
"expected empty set");
4200 inverseOperationMap[it.second] = it.first;
4206 OpConversionMode::Analysis);
4210 if (
config.legalizableOps) {
4213 originalLegalizableOps.insert(inverseOperationMap[op]);
4214 *
config.legalizableOps = std::move(originalLegalizableOps);
4218 clonedAncestor->
erase();
static void setInsertionPointAfter(OpBuilder &b, Value value)
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value >> &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps, const SetVector< Block * > &insertedBlocks)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1243::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
This is the type of Action that is dispatched when a conversion is applied.
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
OpListType & getOperations()
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
const ConversionConfig & getConfig() const
Return the configuration of the current dialect conversion.
void replaceAllUsesWith(Value from, ValueRange to)
Replace all the uses of from with to.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={}) override
PatternRewriter hook for inlining the ops of a block into another block.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
~ConversionPatternRewriter() override
Base class for the conversion patterns.
FailureOr< SmallVector< Value > > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
auto lookup(T from) const
Lookup a mapped value within the map.
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Block::iterator getPoint() const
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Listener * listener
The optional listener for events of this builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
type_range getTypes() const
A unique fingerprint for a specific operation, and all of its internal operations (if includeNested ...
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Values.
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
result_range getResults()
int getPropertiesStorageSize() const
Returns the properties storage size.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e. this pattern may generate IR that also matches this pattern, but is known to bound the recursion.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of `from` and replace them with `to` if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of `from` and replace them with `to`.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
The general result of a type attribute conversion callback, allowing for early termination.
Attribute getResult() const
static AttributeConversionResult abort()
static AttributeConversionResult na()
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, ArrayRef< Value > replacements)
Remap an input of the original signature to the given `replacements` values.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert the attribute `attr`, found within the type `type`, using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e. the type converts to itself.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
CRTP Implementation of an action.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
@ AfterPatterns
Only attempt to fold not legal operations after applying patterns.
@ BeforePatterns
Only attempt to fold not legal operations before applying patterns.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
void reconcileUnrealizedCasts(const DenseSet< UnrealizedConversionCastOp > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
bool allowPatternRollback
If set to "true", pattern rollback is allowed.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
bool attachDebugMaterializationKind
If set to "true", the materialization kind ("source" or "target") will be attached to "builtin....
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
bool wasErased(void *ptr) const
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
ConversionValueMapping mapping
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config)
void replaceAllUsesWith(Value from, ValueRange to, const TypeConverter *converter)
Replace the uses of the given value with the given values.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Block * > patternInsertedBlocks
A set of blocks that were inserted (newly-created blocks or moved blocks) by the current pattern.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
void replaceOp(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the results of the given operation with the given values and erase the operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.