10 #include "mlir/Config/mlir-config.h"
20 #include "llvm/ADT/ScopeExit.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/DebugLog.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include "llvm/Support/SaveAndRestore.h"
26 #include "llvm/Support/ScopedPrinter.h"
32 #define DEBUG_TYPE "dialect-conversion"
35 template <
typename... Args>
36 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
39 os.startLine() <<
"} -> SUCCESS";
41 os.getOStream() <<
" : "
42 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
43 os.getOStream() <<
"\n";
48 template <
typename... Args>
49 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
52 os.startLine() <<
"} -> FAILURE : "
53 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
63 if (
OpResult inputRes = dyn_cast<OpResult>(value))
64 insertPt = ++inputRes.getOwner()->getIterator();
71 assert(!vals.empty() &&
"expected at least one value");
74 for (
Value v : vals.drop_front()) {
88 assert(dom &&
"unable to find valid insertion point");
106 struct ValueVectorMapInfo {
109 static ::llvm::hash_code getHashValue(
const ValueVector &val) {
110 return ::llvm::hash_combine_range(val);
119 struct ConversionValueMapping {
/// Return true if `value` is present in the `mappedTo` container.
122 bool isMappedTo(
Value value)
const {
return mappedTo.contains(value); }
/// Type trait: derives from std::true_type when `T`, after std::decay_t
/// (references and cv-qualifiers stripped), is exactly `ValueVector`, and
/// from std::false_type otherwise.
127 template <
typename T>
128 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
131 template <
typename OldVal,
typename NewVal>
132 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
133 map(OldVal &&oldVal, NewVal &&newVal) {
137 assert(next != oldVal &&
"inserting cyclic mapping");
138 auto it = mapping.find(next);
139 if (it == mapping.end())
144 mappedTo.insert_range(newVal);
146 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
150 template <
typename OldVal,
typename NewVal>
151 std::enable_if_t<!IsValueVector<OldVal>::value ||
152 !IsValueVector<NewVal>::value>
153 map(OldVal &&oldVal, NewVal &&newVal) {
154 if constexpr (IsValueVector<OldVal>{}) {
155 map(std::forward<OldVal>(oldVal),
ValueVector{newVal});
156 }
else if constexpr (IsValueVector<NewVal>{}) {
157 map(
ValueVector{oldVal}, std::forward<NewVal>(newVal));
163 void map(
Value oldVal, SmallVector<Value> &&newVal) {
/// Remove the entry keyed by `value` from `mapping`. Only `mapping` is
/// modified by this method.
168 void erase(
const ValueVector &value) { mapping.erase(value); }
188 assert(!values.empty() &&
"expected non-empty value vector");
189 Operation *op = values.front().getDefiningOp();
190 for (
Value v : llvm::drop_begin(values)) {
191 if (v.getDefiningOp() != op)
201 assert(!values.empty() &&
"expected non-empty value vector");
207 auto it = mapping.find(from);
208 if (it == mapping.end()) {
221 struct RewriterState {
/// Construct a state snapshot from three counters; each argument is copied
/// verbatim into the member of the same name.
222 RewriterState(
unsigned numRewrites,
unsigned numIgnoredOperations,
223 unsigned numReplacedOps)
224 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
225 numReplacedOps(numReplacedOps) {}
228 unsigned numRewrites;
231 unsigned numIgnoredOperations;
234 unsigned numReplacedOps;
246 notifyIRErased(listener, op);
255 notifyIRErased(listener, b);
286 UnresolvedMaterialization
289 virtual ~IRRewrite() =
default;
292 virtual void rollback() = 0;
/// Return the `Kind` tag stored in this rewrite. The `classof` functions of
/// the derived rewrite classes compare against this value.
311 Kind getKind()
const {
return kind; }
/// `classof` hook for the base rewrite class: accepts every `IRRewrite`, so
/// it unconditionally returns true (the argument is not inspected).
313 static bool classof(
const IRRewrite *
rewrite) {
return true; }
317 :
kind(
kind), rewriterImpl(rewriterImpl) {}
326 class BlockRewrite :
public IRRewrite {
/// Return the block stored in this rewrite's `block` member.
329 Block *getBlock()
const {
return block; }
331 static bool classof(
const IRRewrite *
rewrite) {
332 return rewrite->getKind() >= Kind::CreateBlock &&
333 rewrite->getKind() <= Kind::ReplaceBlockArg;
339 : IRRewrite(
kind, rewriterImpl), block(block) {}
348 class CreateBlockRewrite :
public BlockRewrite {
351 : BlockRewrite(
Kind::CreateBlock, rewriterImpl, block) {}
353 static bool classof(
const IRRewrite *
rewrite) {
354 return rewrite->getKind() == Kind::CreateBlock;
363 void rollback()
override {
366 auto &blockOps = block->getOperations();
367 while (!blockOps.empty())
368 blockOps.remove(blockOps.begin());
369 block->dropAllUses();
370 if (block->getParent())
381 class EraseBlockRewrite :
public BlockRewrite {
384 : BlockRewrite(
Kind::EraseBlock, rewriterImpl, block),
385 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
387 static bool classof(
const IRRewrite *
rewrite) {
388 return rewrite->getKind() == Kind::EraseBlock;
391 ~EraseBlockRewrite()
override {
393 "rewrite was neither rolled back nor committed/cleaned up");
396 void rollback()
override {
399 assert(block &&
"expected block");
400 auto &blockList = region->getBlocks();
404 blockList.insert(before, block);
409 assert(block &&
"expected block");
413 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
414 notifyIRErased(listener, *block);
419 for (
auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
421 assert(block->empty() &&
"expected empty block");
424 block->dropAllDefinedValueUses();
435 Block *insertBeforeBlock;
441 class InlineBlockRewrite :
public BlockRewrite {
445 : BlockRewrite(
Kind::InlineBlock, rewriterImpl, block),
446 sourceBlock(sourceBlock),
447 firstInlinedInst(sourceBlock->empty() ? nullptr
448 : &sourceBlock->front()),
449 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
455 assert(!getConfig().listener &&
456 "InlineBlockRewrite not supported if listener is attached");
459 static bool classof(
const IRRewrite *
rewrite) {
460 return rewrite->getKind() == Kind::InlineBlock;
463 void rollback()
override {
466 if (firstInlinedInst) {
467 assert(lastInlinedInst &&
"expected operation");
487 class MoveBlockRewrite :
public BlockRewrite {
491 : BlockRewrite(
Kind::MoveBlock, rewriterImpl, block),
492 region(previousRegion),
493 insertBeforeBlock(previousIt == previousRegion->end() ? nullptr
496 static bool classof(
const IRRewrite *
rewrite) {
497 return rewrite->getKind() == Kind::MoveBlock;
510 void rollback()
override {
523 Block *insertBeforeBlock;
527 class BlockTypeConversionRewrite :
public BlockRewrite {
531 : BlockRewrite(
Kind::BlockTypeConversion, rewriterImpl, origBlock),
532 newBlock(newBlock) {}
534 static bool classof(
const IRRewrite *
rewrite) {
535 return rewrite->getKind() == Kind::BlockTypeConversion;
/// Return the original block. This is the inherited `block` member, which
/// the constructor initializes from its `origBlock` argument.
538 Block *getOrigBlock()
const {
return block; }
/// Return the block stored in the `newBlock` member.
540 Block *getNewBlock()
const {
return newBlock; }
544 void rollback()
override;
554 class ReplaceBlockArgRewrite :
public BlockRewrite {
559 : BlockRewrite(
Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
560 converter(converter) {}
562 static bool classof(
const IRRewrite *
rewrite) {
563 return rewrite->getKind() == Kind::ReplaceBlockArg;
568 void rollback()
override;
578 class OperationRewrite :
public IRRewrite {
/// Return the operation stored in this rewrite's `op` member.
581 Operation *getOperation()
const {
return op; }
583 static bool classof(
const IRRewrite *
rewrite) {
584 return rewrite->getKind() >= Kind::MoveOperation &&
585 rewrite->getKind() <= Kind::UnresolvedMaterialization;
591 : IRRewrite(
kind, rewriterImpl), op(op) {}
598 class MoveOperationRewrite :
public OperationRewrite {
602 : OperationRewrite(
Kind::MoveOperation, rewriterImpl, op),
603 block(previous.getBlock()),
604 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
606 : &*previous.getPoint()) {}
608 static bool classof(
const IRRewrite *
rewrite) {
609 return rewrite->getKind() == Kind::MoveOperation;
623 void rollback()
override {
641 class ModifyOperationRewrite :
public OperationRewrite {
645 : OperationRewrite(
Kind::ModifyOperation, rewriterImpl, op),
646 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
647 operands(op->operand_begin(), op->operand_end()),
648 successors(op->successor_begin(), op->successor_end()) {
653 name.initOpProperties(propCopy, prop);
657 static bool classof(
const IRRewrite *
rewrite) {
658 return rewrite->getKind() == Kind::ModifyOperation;
661 ~ModifyOperationRewrite()
override {
662 assert(!propertiesStorage &&
663 "rewrite was neither committed nor rolled back");
669 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
672 if (propertiesStorage) {
676 name.destroyOpProperties(propCopy);
677 operator delete(propertiesStorage);
678 propertiesStorage =
nullptr;
682 void rollback()
override {
688 if (propertiesStorage) {
691 name.destroyOpProperties(propCopy);
692 operator delete(propertiesStorage);
693 propertiesStorage =
nullptr;
700 DictionaryAttr attrs;
701 SmallVector<Value, 8> operands;
702 SmallVector<Block *, 2> successors;
703 void *propertiesStorage =
nullptr;
710 class ReplaceOperationRewrite :
public OperationRewrite {
714 : OperationRewrite(
Kind::ReplaceOperation, rewriterImpl, op),
715 converter(converter) {}
717 static bool classof(
const IRRewrite *
rewrite) {
718 return rewrite->getKind() == Kind::ReplaceOperation;
723 void rollback()
override;
733 class CreateOperationRewrite :
public OperationRewrite {
737 : OperationRewrite(
Kind::CreateOperation, rewriterImpl, op) {}
739 static bool classof(
const IRRewrite *
rewrite) {
740 return rewrite->getKind() == Kind::CreateOperation;
749 void rollback()
override;
753 enum MaterializationKind {
764 class UnresolvedMaterializationInfo {
766 UnresolvedMaterializationInfo() =
default;
/// Construct the info record: `converter` and `kind` are stored together in
/// `converterAndKind` (read back through getPointer()/getInt() by the
/// accessors), and `originalType` is stored as given.
767 UnresolvedMaterializationInfo(
const TypeConverter *converter,
768 MaterializationKind
kind,
Type originalType)
769 : converterAndKind(converter,
kind), originalType(originalType) {}
773 return converterAndKind.getPointer();
777 MaterializationKind getMaterializationKind()
const {
778 return converterAndKind.getInt();
/// Return the `originalType` recorded at construction time.
782 Type getOriginalType()
const {
return originalType; }
787 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
798 class UnresolvedMaterializationRewrite :
public OperationRewrite {
801 UnrealizedConversionCastOp op,
803 : OperationRewrite(
Kind::UnresolvedMaterialization, rewriterImpl, op),
804 mappedValues(std::move(mappedValues)) {}
806 static bool classof(
const IRRewrite *
rewrite) {
807 return rewrite->getKind() == Kind::UnresolvedMaterialization;
810 void rollback()
override;
812 UnrealizedConversionCastOp getOperation()
const {
813 return cast<UnrealizedConversionCastOp>(op);
823 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
826 template <
typename RewriteTy,
typename R>
827 static bool hasRewrite(R &&rewrites,
Operation *op) {
828 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
829 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
830 return rewriteTy && rewriteTy->getOperation() == op;
836 template <
typename RewriteTy,
typename R>
837 static bool hasRewrite(R &&rewrites,
Block *block) {
838 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
839 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
840 return rewriteTy && rewriteTy->getBlock() == block;
861 RewriterState getCurrentState();
865 void applyRewrites();
870 void resetState(RewriterState state, StringRef patternName =
"");
874 template <
typename RewriteTy,
typename... Args>
876 assert(
config.allowPatternRollback &&
"appending rewrites is not allowed");
878 std::make_unique<RewriteTy>(*
this, std::forward<Args>(args)...));
884 void undoRewrites(
unsigned numRewritesToKeep = 0, StringRef patternName =
"");
890 LogicalResult remapValues(StringRef valueDiagTag,
891 std::optional<Location> inputLoc,
ValueRange values,
908 bool skipPureTypeConversions =
false)
const;
930 Block *applySignatureConversion(
941 void replaceOp(
Operation *op, SmallVector<SmallVector<Value>> &&newValues);
949 void eraseBlock(
Block *block);
987 Value findOrBuildReplacementValue(
Value value,
995 void notifyOperationInserted(
Operation *op,
999 void notifyBlockInserted(
Block *block,
Region *previous,
1018 std::function<
void(
Operation *)> opErasedCallback =
nullptr)
1020 opErasedCallback(opErasedCallback) {}
1032 if (wasErased(block))
1034 assert(block->
empty() &&
"expected empty block");
/// Return true if `ptr` is recorded in the `erased` container. The pointer
/// is untyped; the visible caller passes a `Block *`.
1039 bool wasErased(
void *ptr)
const {
return erased.contains(ptr); }
1043 if (opErasedCallback)
1044 opErasedCallback(op);
1054 std::function<void(
Operation *)> opErasedCallback;
1143 llvm::impl::raw_ldbg_ostream os{(Twine(
"[") +
DEBUG_TYPE +
":1] ").str(),
1147 llvm::ScopedPrinter logger{os};
1155 return rewriterImpl.
config;
1158 void BlockTypeConversionRewrite::commit(
RewriterBase &rewriter) {
1162 if (
auto *listener =
1163 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
1164 for (
Operation *op : getNewBlock()->getUsers())
1168 void BlockTypeConversionRewrite::rollback() {
1169 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1174 if (isa<BlockArgument>(repl)) {
1182 Operation *replOp = cast<OpResult>(repl).getOwner();
1190 void ReplaceBlockArgRewrite::commit(
RewriterBase &rewriter) {
/// Undo the block-argument replacement by erasing the entry keyed by the
/// one-element vector `{arg}` from `rewriterImpl.mapping`.
1197 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.
mapping.erase({arg}); }
1199 void ReplaceOperationRewrite::commit(
RewriterBase &rewriter) {
1201 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener());
1204 SmallVector<Value> replacements =
1206 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1214 for (
auto [result, newValue] :
1215 llvm::zip_equal(op->
getResults(), replacements))
1221 if (getConfig().unlegalizedOps)
1222 getConfig().unlegalizedOps->erase(op);
1226 notifyIRErased(listener, *op);
1233 void ReplaceOperationRewrite::rollback() {
1235 rewriterImpl.
mapping.erase({result});
1238 void ReplaceOperationRewrite::cleanup(
RewriterBase &rewriter) {
1242 void CreateOperationRewrite::rollback() {
1244 while (!region.getBlocks().empty())
1245 region.getBlocks().remove(region.getBlocks().begin());
1251 void UnresolvedMaterializationRewrite::rollback() {
1252 if (!mappedValues.empty())
1253 rewriterImpl.
mapping.erase(mappedValues);
1264 for (
size_t i = 0; i <
rewrites.size(); ++i)
1270 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1271 unresolvedMaterializations.erase(castOp);
1274 rewrite->cleanup(eraseRewriter);
1282 Value from,
TypeRange desiredTypes,
bool skipPureTypeConversions)
const {
1285 assert(!values.empty() &&
"expected non-empty value vector");
1290 return mapping.lookup(values);
1297 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1302 if (castOp.getOutputs() != values)
1304 return castOp.getInputs();
1313 for (
Value v : values) {
1316 llvm::append_range(next, r);
1321 if (next != values) {
1350 if (skipPureTypeConversions) {
1353 match &= !pureConversion;
1356 if (!pureConversion)
1357 lastNonMaterialization = current;
1360 desiredValue = current;
1366 current = std::move(next);
1371 if (!desiredTypes.empty())
1372 return desiredValue;
1373 if (skipPureTypeConversions)
1374 return lastNonMaterialization;
1393 StringRef patternName) {
1398 while (
ignoredOps.size() != state.numIgnoredOperations)
1401 while (
replacedOps.size() != state.numReplacedOps)
1406 StringRef patternName) {
1408 llvm::reverse(llvm::drop_begin(
rewrites, numRewritesToKeep)))
1410 rewrites.resize(numRewritesToKeep);
1414 StringRef valueDiagTag, std::optional<Location> inputLoc,
ValueRange values,
1416 remapped.reserve(llvm::size(values));
1419 Value operand = it.value();
1438 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1439 << it.index() <<
", type was " << origType;
1444 if (legalTypes.empty()) {
1445 remapped.push_back({});
1454 remapped.push_back(std::move(repl));
1463 repl, repl, legalTypes,
1465 remapped.push_back(castValues);
1488 if (region->
empty())
1493 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1495 std::optional<TypeConverter::SignatureConversion> conversion =
1505 if (entryConversion)
1508 std::optional<TypeConverter::SignatureConversion> conversion =
1518 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1520 if (hasRewrite<BlockTypeConversionRewrite>(
rewrites, block))
1521 llvm::report_fatal_error(
"block was already converted");
1535 for (
unsigned i = 0; i < origArgCount; ++i) {
1537 if (!inputMap || inputMap->replacedWithValues())
1540 for (
unsigned j = 0;
j < inputMap->size; ++
j)
1541 newLocs[inputMap->inputNo +
j] = origLoc;
1548 convertedTypes, newLocs);
1559 appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->
end());
1562 while (!block->
empty())
1569 for (
unsigned i = 0; i != origArgCount; ++i) {
1573 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1580 MaterializationKind::Source,
1584 origArgType,
Type(), converter,
1591 if (inputMap->replacedWithValues()) {
1593 assert(inputMap->size == 0 &&
1594 "invalid to provide a replacement value when the argument isn't "
1603 newBlock->
getArguments().slice(inputMap->inputNo, inputMap->size);
1608 appendRewrite<BlockTypeConversionRewrite>(block, newBlock);
1628 assert((!originalType ||
kind == MaterializationKind::Target) &&
1629 "original type is valid only for target materializations");
1630 assert(
TypeRange(inputs) != outputTypes &&
1631 "materialization is not necessary");
1635 OpBuilder builder(outputTypes.front().getContext());
1637 UnrealizedConversionCastOp convertOp =
1638 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1641 kind == MaterializationKind::Source ?
"source" :
"target";
1642 convertOp->setAttr(
"__kind__", builder.
getStringAttr(kindStr));
1649 UnresolvedMaterializationInfo(converter,
kind, originalType);
1651 if (!valuesToMap.empty())
1652 mapping.map(valuesToMap, convertOp.getResults());
1653 appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
1654 std::move(valuesToMap));
1658 return convertOp.getResults();
1664 "this code path is valid only in rollback mode");
1671 return repl.front();
1678 [&](
Operation *op) { return replacedOps.contains(op); }) &&
1719 bool wasDetached = !previous.
isSet();
1721 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"' (" << op
1724 logger.getOStream() <<
" (was detached)";
1725 logger.getOStream() <<
"\n";
1731 "attempting to insert into a block within a replaced/erased op");
1748 appendRewrite<CreateOperationRewrite>(op);
1759 appendRewrite<MoveOperationRewrite>(op, previous);
1766 const SmallVector<SmallVector<Value>> &toRange,
1768 assert(!
impl.config.allowPatternRollback &&
1769 "this code path is valid only in 'no rollback' mode");
1770 SmallVector<Value> repls;
1771 for (
auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1774 repls.push_back(
Value());
1781 Value srcMat =
impl.buildUnresolvedMaterialization(
1786 repls.push_back(srcMat);
1792 repls.push_back(to[0]);
1801 Value srcMat =
impl.buildUnresolvedMaterialization(
1804 Type(), converter)[0];
1805 repls.push_back(srcMat);
1814 "incorrect number of replacement values");
1825 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1840 assert(!
ignoredOps.contains(op) &&
"operation was already replaced");
1844 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1849 "attempting to replace/erase an unresolved materialization");
1853 for (
auto [repl, result] : llvm::zip_equal(newValues, op->
getResults())) {
1868 mapping.map(
static_cast<Value>(result), std::move(repl));
1883 Value repl = repls.front();
1900 "attempting to replace a block argument that was already replaced");
1904 appendRewrite<ReplaceBlockArgRewrite>(from.
getOwner(), from, converter);
1916 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1932 "attempting to erase a block within a replaced/erased op");
1933 appendRewrite<EraseBlockRewrite>(block);
1948 bool wasDetached = !previous;
1954 logger.startLine() <<
"** Insert Block into : '" << parent->
getName()
1955 <<
"' (" << parent <<
")";
1958 <<
"** Insert Block into detached Region (nullptr parent op)";
1961 logger.getOStream() <<
" (was detached)";
1962 logger.getOStream() <<
"\n";
1968 "attempting to insert into a region within a replaced/erased op");
1983 appendRewrite<CreateBlockRewrite>(block);
1994 appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
2000 appendRewrite<InlineBlockRewrite>(dest, source, before);
2007 reasonCallback(
diag);
2008 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
2018 ConversionPatternRewriter::ConversionPatternRewriter(
2022 setListener(
impl.get());
2028 return impl->config;
2032 assert(op && newOp &&
"expected non-null op");
2038 "incorrect # of replacement values");
2040 impl->logger.startLine()
2041 <<
"** Replace : '" << op->
getName() <<
"'(" << op <<
")\n";
2046 if (getInsertionPoint() == op->getIterator())
2053 impl->replaceOp(op, std::move(newVals));
2059 "incorrect # of replacement values");
2061 impl->logger.startLine()
2062 <<
"** Replace : '" << op->
getName() <<
"'(" << op <<
")\n";
2067 if (getInsertionPoint() == op->getIterator())
2070 impl->replaceOp(op, std::move(newValues));
2075 impl->logger.startLine()
2076 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
2081 if (getInsertionPoint() == op->getIterator())
2085 impl->replaceOp(op, std::move(nullRepls));
2089 impl->eraseBlock(block);
2096 "attempting to apply a signature conversion to a block within a "
2097 "replaced/erased op");
2098 return impl->applySignatureConversion(block, converter, conversion);
2105 "attempting to apply a signature conversion to a block within a "
2106 "replaced/erased op");
2107 return impl->convertRegionTypes(region, converter, entryConversion);
2113 impl->logger.startLine() <<
"** Replace Argument : '" << from <<
"'";
2115 impl->logger.getOStream() <<
" (in region of '" << parentOp->getName()
2116 <<
"' (" << parentOp <<
")\n";
2118 impl->logger.getOStream() <<
" (unlinked block)\n";
2121 impl->replaceUsesOfBlockArgument(from, to,
impl->currentTypeConverter);
2126 if (
failed(
impl->remapValues(
"value", std::nullopt, key,
2129 assert(remappedValues.front().size() == 1 &&
"1:N conversion not supported");
2130 return remappedValues.front().front();
2139 if (
failed(
impl->remapValues(
"value", std::nullopt, keys,
2142 for (
const auto &values : remapped) {
2143 assert(values.size() == 1 &&
"1:N conversion not supported");
2144 results.push_back(values.front());
2154 "incorrect # of argument replacement values");
2156 "attempting to inline a block from a replaced/erased op");
2158 "attempting to inline a block into a replaced/erased op");
2159 auto opIgnored = [&](
Operation *op) {
return impl->isOpIgnored(op); };
2162 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
2163 "expected 'source' to have no predecessors");
2172 bool fastPath = !getConfig().listener;
2174 if (fastPath &&
impl->config.allowPatternRollback)
2175 impl->inlineBlockBefore(source, dest, before);
2178 for (
auto it : llvm::zip(source->
getArguments(), argValues))
2179 replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
2186 while (!source->
empty())
2187 moveOpBefore(&source->
front(), dest, before);
2192 if (getInsertionBlock() == source)
2193 setInsertionPoint(dest, getInsertionPoint());
2200 if (!
impl->config.allowPatternRollback) {
2205 assert(!
impl->wasOpReplaced(op) &&
2206 "attempting to modify a replaced/erased op");
2208 impl->pendingRootUpdates.insert(op);
2210 impl->appendRewrite<ModifyOperationRewrite>(op);
2214 impl->patternModifiedOps.insert(op);
2215 if (!
impl->config.allowPatternRollback) {
2217 if (getConfig().listener)
2218 getConfig().listener->notifyOperationModified(op);
2225 assert(!
impl->wasOpReplaced(op) &&
2226 "attempting to modify a replaced/erased op");
2227 assert(
impl->pendingRootUpdates.erase(op) &&
2228 "operation did not have a pending in-place update");
2233 if (!
impl->config.allowPatternRollback) {
2238 assert(
impl->pendingRootUpdates.erase(op) &&
2239 "operation did not have a pending in-place update");
2242 auto it = llvm::find_if(
2243 llvm::reverse(
impl->rewrites), [&](std::unique_ptr<IRRewrite> &
rewrite) {
2244 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2245 return modifyRewrite && modifyRewrite->getOperation() == op;
2247 assert(it !=
impl->rewrites.rend() &&
"no root update started on op");
2249 int updateIdx = std::prev(
impl->rewrites.rend()) - it;
2250 impl->rewrites.erase(
impl->rewrites.begin() + updateIdx);
2264 oneToOneOperands.reserve(operands.size());
2266 if (operand.size() != 1)
2269 oneToOneOperands.push_back(operand.front());
2271 return std::move(oneToOneOperands);
2278 auto &rewriterImpl = dialectRewriter.getImpl();
2282 getTypeConverter());
2291 llvm::to_vector_of<ValueRange>(remapped);
2292 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2304 class OperationLegalizer {
2324 LogicalResult legalizeWithFold(
Operation *op);
2328 LogicalResult legalizeWithPattern(
Operation *op);
2336 const RewriterState &curState,
2343 legalizePatternBlockRewrites(
Operation *op,
2359 void buildLegalizationGraph(
2360 LegalizationPatterns &anyOpLegalizerPatterns,
2371 void computeLegalizationGraphBenefit(
2372 LegalizationPatterns &anyOpLegalizerPatterns,
2377 unsigned computeOpLegalizationDepth(
2384 unsigned applyCostModelToPatterns(
2406 : rewriter(rewriter), target(targetInfo), applicator(
patterns) {
2410 LegalizationPatterns anyOpLegalizerPatterns;
2412 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2413 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2416 bool OperationLegalizer::isIllegal(
Operation *op)
const {
2417 return target.isIllegal(op);
2420 LogicalResult OperationLegalizer::legalize(
Operation *op) {
2422 const char *logLineComment =
2423 "//===-------------------------------------------===//\n";
2425 auto &logger = rewriter.getImpl().logger;
2429 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2432 logger.getOStream() <<
"\n";
2433 logger.startLine() << logLineComment;
2434 logger.startLine() <<
"Legalizing operation : ";
2439 logger.getOStream() <<
"'" << op->
getName() <<
"' ";
2440 logger.getOStream() <<
"(" << op <<
") {\n";
2445 logger.startLine() << OpWithFlags(op,
2446 OpPrintingFlags().printGenericOpForm())
2453 logSuccess(logger,
"operation marked 'ignored' during conversion");
2454 logger.startLine() << logLineComment;
2460 if (
auto legalityInfo = target.isLegal(op)) {
2463 logger,
"operation marked legal by the target{0}",
2464 legalityInfo->isRecursivelyLegal
2465 ?
"; NOTE: operation is recursively legal; skipping internals"
2467 logger.startLine() << logLineComment;
2472 if (legalityInfo->isRecursivelyLegal) {
2475 rewriter.getImpl().ignoredOps.
insert(nested);
2486 if (succeeded(legalizeWithFold(op))) {
2489 logger.startLine() << logLineComment;
2496 if (succeeded(legalizeWithPattern(op))) {
2499 logger.startLine() << logLineComment;
2507 if (succeeded(legalizeWithFold(op))) {
2510 logger.startLine() << logLineComment;
2517 logFailure(logger,
"no matched legalization pattern");
2518 logger.startLine() << logLineComment;
2525 template <
typename T>
2527 T result = std::move(obj);
2532 LogicalResult OperationLegalizer::legalizeWithFold(
Operation *op) {
2533 auto &rewriterImpl = rewriter.getImpl();
2535 rewriterImpl.
logger.startLine() <<
"* Fold {\n";
2536 rewriterImpl.
logger.indent();
2541 auto cleanup = llvm::make_scope_exit([&]() {
2552 SmallVector<Value, 2> replacementValues;
2553 SmallVector<Operation *, 2> newOps;
2556 if (
failed(rewriter.
tryFold(op, replacementValues, &newOps))) {
2565 if (replacementValues.empty())
2566 return legalize(op);
2569 rewriter.
replaceOp(op, replacementValues);
2573 if (
failed(legalize(newOp))) {
2575 "failed to legalize generated constant '{0}'",
2577 if (!rewriter.getConfig().allowPatternRollback) {
2579 llvm::report_fatal_error(
2581 "' folder rollback of IR modifications requested");
2600 auto newOpNames = llvm::map_range(
2602 auto modifiedOpNames = llvm::map_range(
2604 StringRef detachedBlockStr =
"(detached block)";
2605 auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](
Block *block) {
2608 return detachedBlockStr;
2610 llvm::report_fatal_error(
2612 "' produced IR that could not be legalized. " +
"new ops: {" +
2613 llvm::join(newOpNames,
", ") +
"}, " +
"modified ops: {" +
2614 llvm::join(modifiedOpNames,
", ") +
"}, " +
"inserted block into ops: {" +
2615 llvm::join(insertedBlockNames,
", ") +
"}");
2618 LogicalResult OperationLegalizer::legalizeWithPattern(
Operation *op) {
2619 auto &rewriterImpl = rewriter.getImpl();
2622 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2624 std::optional<OperationFingerPrint> topLevelFingerPrint;
2638 rewriterImpl.
logger.startLine()
2639 <<
"WARNING: Multi-threadeding is enabled. Some dialect "
2640 "conversion expensive checks are skipped in multithreading "
2649 auto canApply = [&](
const Pattern &pattern) {
2650 bool canApply = canApplyPattern(op, pattern);
2651 if (canApply &&
config.listener)
2652 config.listener->notifyPatternBegin(pattern, op);
2658 auto onFailure = [&](
const Pattern &pattern) {
2667 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2671 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2672 llvm::report_fatal_error(
"pattern '" + pattern.getDebugName() +
2673 "' returned failure but IR did change");
2684 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
2691 config.listener->notifyPatternEnd(pattern, failure());
2692 rewriterImpl.
resetState(curState, pattern.getDebugName());
2693 appliedPatterns.erase(&pattern);
2698 auto onSuccess = [&](
const Pattern &pattern) {
2715 auto result = legalizePatternResult(op, pattern, curState, newOps,
2716 modifiedOps, insertedBlocks);
2717 appliedPatterns.erase(&pattern);
2722 rewriterImpl.
resetState(curState, pattern.getDebugName());
2725 config.listener->notifyPatternEnd(pattern, result);
2730 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2734 bool OperationLegalizer::canApplyPattern(
Operation *op,
2737 auto &os = rewriter.getImpl().logger;
2738 os.getOStream() <<
"\n";
2739 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
2741 os.getOStream() <<
")' {\n";
2748 !appliedPatterns.insert(&pattern).second) {
2750 logFailure(rewriter.getImpl().logger,
"pattern was already applied"));
2756 LogicalResult OperationLegalizer::legalizePatternResult(
2761 [[maybe_unused]]
auto &
impl = rewriter.getImpl();
2762 assert(
impl.pendingRootUpdates.empty() &&
"dangling root updates");
2764 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2766 auto newRewrites = llvm::drop_begin(
impl.rewrites, curState.numRewrites);
2767 auto replacedRoot = [&] {
2768 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2770 auto updatedRootInPlace = [&] {
2771 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2773 if (!replacedRoot() && !updatedRootInPlace())
2774 llvm::report_fatal_error(
2775 "expected pattern to replace the root operation or modify it in place");
2779 if (
failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
2780 failed(legalizePatternRootUpdates(modifiedOps)) ||
2781 failed(legalizePatternCreatedOperations(newOps))) {
2785 LLVM_DEBUG(
logSuccess(
impl.logger,
"pattern applied successfully"));
2789 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2797 for (
Block *block : insertedBlocks) {
2798 if (
impl.erasedBlocks.contains(block))
2808 if (
auto *converter =
impl.regionToConverter.lookup(block->
getParent())) {
2809 std::optional<TypeConverter::SignatureConversion> conversion =
2810 converter->convertBlockSignature(block);
2812 LLVM_DEBUG(
logFailure(
impl.logger,
"failed to convert types of moved "
2816 impl.applySignatureConversion(block, converter, *conversion);
2824 if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
2825 if (
failed(legalize(parentOp))) {
2827 impl.logger,
"operation '{0}'({1}) became illegal after rewrite",
2828 parentOp->
getName(), parentOp));
2836 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2839 if (
failed(legalize(op))) {
2840 LLVM_DEBUG(
logFailure(rewriter.getImpl().logger,
2841 "failed to legalize generated operation '{0}'({1})",
2849 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2852 if (
failed(legalize(op))) {
2855 "failed to legalize operation updated in-place '{0}'",
2867 void OperationLegalizer::buildLegalizationGraph(
2868 LegalizationPatterns &anyOpLegalizerPatterns,
2879 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2880 std::optional<OperationName> root = pattern.
getRootKind();
2886 anyOpLegalizerPatterns.push_back(&pattern);
2891 if (target.getOpAction(*root) == LegalizationAction::Legal)
2896 invalidPatterns[*root].insert(&pattern);
2898 parentOps[op].insert(*root);
2901 patternWorklist.insert(&pattern);
2909 if (!anyOpLegalizerPatterns.empty()) {
2910 for (
const Pattern *pattern : patternWorklist)
2911 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2915 while (!patternWorklist.empty()) {
2916 auto *pattern = patternWorklist.pop_back_val();
2920 std::optional<LegalizationAction> action = target.getOpAction(op);
2921 return !legalizerPatterns.count(op) &&
2922 (!action || action == LegalizationAction::Illegal);
2928 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2929 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
2933 for (
auto op : parentOps[*pattern->
getRootKind()])
2934 patternWorklist.set_union(invalidPatterns[op]);
2938 void OperationLegalizer::computeLegalizationGraphBenefit(
2939 LegalizationPatterns &anyOpLegalizerPatterns,
2945 for (
auto &opIt : legalizerPatterns)
2946 if (!minOpPatternDepth.count(opIt.first))
2947 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2953 if (!anyOpLegalizerPatterns.empty())
2954 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2960 applicator.applyCostModel([&](
const Pattern &pattern) {
2962 if (std::optional<OperationName> rootName = pattern.
getRootKind())
2963 orderedPatternList = legalizerPatterns[*rootName];
2965 orderedPatternList = anyOpLegalizerPatterns;
2968 auto *it = llvm::find(orderedPatternList, &pattern);
2969 if (it == orderedPatternList.end())
2973 return PatternBenefit(std::distance(it, orderedPatternList.end()));
2977 unsigned OperationLegalizer::computeOpLegalizationDepth(
2981 auto depthIt = minOpPatternDepth.find(op);
2982 if (depthIt != minOpPatternDepth.end())
2983 return depthIt->second;
2987 auto opPatternsIt = legalizerPatterns.find(op);
2988 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2997 unsigned minDepth = applyCostModelToPatterns(
2998 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2999 minOpPatternDepth[op] = minDepth;
3003 unsigned OperationLegalizer::applyCostModelToPatterns(
3010 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3011 patternsByDepth.reserve(
patterns.size());
3015 unsigned generatedOpDepth = computeOpLegalizationDepth(
3016 generatedOp, minOpPatternDepth, legalizerPatterns);
3017 depth =
std::max(depth, generatedOpDepth + 1);
3019 patternsByDepth.emplace_back(pattern, depth);
3022 minDepth =
std::min(minDepth, depth);
3027 if (patternsByDepth.size() == 1)
3031 llvm::stable_sort(patternsByDepth,
3032 [](
const std::pair<const Pattern *, unsigned> &lhs,
3033 const std::pair<const Pattern *, unsigned> &rhs) {
3036 if (lhs.second != rhs.second)
3037 return lhs.second < rhs.second;
3040 auto lhsBenefit = lhs.first->getBenefit();
3041 auto rhsBenefit = rhs.first->getBenefit();
3042 return lhsBenefit > rhsBenefit;
3047 for (
auto &patternIt : patternsByDepth)
3048 patterns.push_back(patternIt.first);
3056 enum OpConversionMode {
3079 OpConversionMode mode)
3080 : rewriter(ctx,
config), opLegalizer(rewriter, target,
patterns),
3094 OperationLegalizer opLegalizer;
3097 OpConversionMode mode;
3101 LogicalResult OperationConverter::convert(
Operation *op) {
3105 if (
failed(opLegalizer.legalize(op))) {
3108 if (mode == OpConversionMode::Full)
3110 <<
"failed to legalize operation '" << op->
getName() <<
"'";
3114 if (mode == OpConversionMode::Partial) {
3115 if (opLegalizer.isIllegal(op))
3117 <<
"failed to legalize operation '" << op->
getName()
3118 <<
"' that was explicitly marked illegal";
3119 if (
config.unlegalizedOps)
3120 config.unlegalizedOps->insert(op);
3122 }
else if (mode == OpConversionMode::Analysis) {
3126 if (
config.legalizableOps)
3127 config.legalizableOps->insert(op);
3132 static LogicalResult
3134 UnrealizedConversionCastOp op,
3135 const UnresolvedMaterializationInfo &info) {
3136 assert(!op.use_empty() &&
3137 "expected that dead materializations have already been DCE'd");
3143 SmallVector<Value> newMaterialization;
3144 switch (info.getMaterializationKind()) {
3145 case MaterializationKind::Target:
3146 newMaterialization = converter->materializeTargetConversion(
3147 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3148 info.getOriginalType());
3150 case MaterializationKind::Source:
3151 assert(op->getNumResults() == 1 &&
"expected single result");
3152 Value sourceMat = converter->materializeSourceConversion(
3153 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3155 newMaterialization.push_back(sourceMat);
3158 if (!newMaterialization.empty()) {
3160 ValueRange newMaterializationRange(newMaterialization);
3161 assert(
TypeRange(newMaterializationRange) == op.getResultTypes() &&
3162 "materialization callback produced value of incorrect type");
3164 rewriter.
replaceOp(op, newMaterialization);
3170 <<
"failed to legalize unresolved materialization "
3172 << inputOperands.
getTypes() <<
") to ("
3173 << op.getResultTypes()
3174 <<
") that remained live after conversion";
3175 diag.attachNote(op->getUsers().begin()->getLoc())
3176 <<
"see existing live user here: " << *op->getUsers().begin();
3185 for (
auto *op : ops) {
3188 toConvert.push_back(op);
3191 auto legalityInfo = target.
isLegal(op);
3192 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3201 for (
auto *op : toConvert) {
3202 if (
failed(convert(op))) {
3223 for (
auto it : materializations)
3224 allCastOps.push_back(it.first);
3233 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3237 if (rewriter.getConfig().buildMaterializations) {
3242 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3243 auto it = materializations.find(castOp);
3244 assert(it != materializations.end() &&
"inconsistent state");
3267 auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
3268 for (
Value v : castOp.getInputs())
3269 if (
auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
3270 worklist.insert(inputCastOp);
3277 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3278 if (castOp.getInputs().empty())
3281 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3284 if (inputCastOp.getOutputs() != castOp.getInputs())
3290 while (!worklist.empty()) {
3291 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3292 if (castOp->use_empty()) {
3295 enqueueOperands(castOp);
3296 if (remainingCastOps)
3297 erasedOps.insert(castOp.getOperation());
3304 UnrealizedConversionCastOp nextCast = castOp;
3306 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3310 enqueueOperands(castOp);
3311 castOp.replaceAllUsesWith(nextCast.getInputs());
3312 if (remainingCastOps)
3313 erasedOps.insert(castOp.getOperation());
3317 nextCast = getInputCast(nextCast);
3321 if (remainingCastOps)
3322 for (UnrealizedConversionCastOp op : castOps)
3323 if (!erasedOps.contains(op.getOperation()))
3324 remainingCastOps->push_back(op);
3333 assert(!types.empty() &&
"expected valid types");
3334 remapInput(origInputNo, argTypes.size(), types.size());
3339 assert(!types.empty() &&
3340 "1->0 type remappings don't need to be added explicitly");
3341 argTypes.append(types.begin(), types.end());
3345 unsigned newInputNo,
3346 unsigned newInputCount) {
3347 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3348 assert(newInputCount != 0 &&
"expected valid input count");
3349 remappedInputs[origInputNo] =
3350 InputMapping{newInputNo, newInputCount, {}};
3355 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3363 assert(t &&
"expected non-null type");
3366 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3369 cacheReadLock.lock();
3370 auto existingIt = cachedDirectConversions.find(t);
3371 if (existingIt != cachedDirectConversions.end()) {
3372 if (existingIt->second)
3373 results.push_back(existingIt->second);
3374 return success(existingIt->second !=
nullptr);
3376 auto multiIt = cachedMultiConversions.find(t);
3377 if (multiIt != cachedMultiConversions.end()) {
3378 results.append(multiIt->second.begin(), multiIt->second.end());
3384 size_t currentCount = results.size();
3386 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3389 for (
const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
3390 if (std::optional<LogicalResult> result = converter(t, results)) {
3392 cacheWriteLock.lock();
3393 if (!succeeded(*result)) {
3394 assert(results.size() == currentCount &&
3395 "failed type conversion should not change results");
3396 cachedDirectConversions.try_emplace(t,
nullptr);
3399 auto newTypes =
ArrayRef<Type>(results).drop_front(currentCount);
3400 if (newTypes.size() == 1)
3401 cachedDirectConversions.try_emplace(t, newTypes.front());
3403 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3406 assert(results.size() == currentCount &&
3407 "failed type conversion should not change results");
3415 assert(v &&
"expected non-null value");
3419 if (!hasContextAwareTypeConversions)
3424 for (
const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
3425 if (std::optional<LogicalResult> result = converter(v, results)) {
3426 if (!succeeded(*result))
3441 return results.size() == 1 ? results.front() :
nullptr;
3451 return results.size() == 1 ? results.front() :
nullptr;
3457 for (
Type type : types)
3466 for (
Value value : values)
3485 return llvm::all_of(
3492 if (!
isLegal(ty.getResults()))
3506 if (convertedTypes.empty())
3510 result.
addInputs(inputNo, convertedTypes);
3516 unsigned origInputOffset)
const {
3517 for (
unsigned i = 0, e = types.size(); i != e; ++i)
3531 if (convertedTypes.empty())
3535 result.
addInputs(inputNo, convertedTypes);
3541 unsigned origInputOffset)
const {
3542 for (
unsigned i = 0, e = values.size(); i != e; ++i)
3551 for (
const SourceMaterializationCallbackFn &fn :
3552 llvm::reverse(sourceMaterializations))
3553 if (
Value result = fn(builder, resultType, inputs, loc))
3561 Type originalType)
const {
3563 builder, loc,
TypeRange(resultType), inputs, originalType);
3566 assert(result.size() == 1 &&
"expected single result");
3567 return result.front();
3572 Type originalType)
const {
3573 for (
const TargetMaterializationCallbackFn &fn :
3574 llvm::reverse(targetMaterializations)) {
3576 fn(builder, resultTypes, inputs, loc, originalType);
3580 "callback produced incorrect number of values or values with "
3587 std::optional<TypeConverter::SignatureConversion>
3591 return std::nullopt;
3614 return impl.getInt() == resultTag;
3618 return impl.getInt() == naTag;
3622 return impl.getInt() == abortTag;
3626 assert(hasResult() &&
"Cannot get result from N/A or abort");
3627 return impl.getPointer();
3630 std::optional<Attribute>
3632 for (
const TypeAttributeConversionCallbackFn &fn :
3633 llvm::reverse(typeAttributeConversions)) {
3638 return std::nullopt;
3640 return std::nullopt;
3650 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3656 SmallVector<Type, 1> newResults;
3660 typeConverter, &result)))
3677 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3685 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3690 struct AnyFunctionOpInterfaceSignatureConversion
3702 FailureOr<Operation *>
3706 assert(op &&
"Invalid op");
3720 return rewriter.
create(newOp);
3726 patterns.add<FunctionOpInterfaceSignatureConversion>(
3727 functionLikeOpName,
patterns.getContext(), converter);
3732 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3742 legalOperations[op].action = action;
3747 for (StringRef dialect : dialectNames)
3748 legalDialects[dialect] = action;
3752 -> std::optional<LegalizationAction> {
3753 std::optional<LegalizationInfo> info = getOpInfo(op);
3754 return info ? info->action : std::optional<LegalizationAction>();
3758 -> std::optional<LegalOpDetails> {
3759 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3761 return std::nullopt;
3764 auto isOpLegal = [&] {
3766 if (info->action == LegalizationAction::Dynamic) {
3767 std::optional<bool> result = info->legalityFn(op);
3773 return info->action == LegalizationAction::Legal;
3776 return std::nullopt;
3780 if (info->isRecursivelyLegal) {
3781 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
3782 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3784 legalityFnIt->second(op).value_or(
true);
3789 return legalityDetails;
3793 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3797 if (info->action == LegalizationAction::Dynamic) {
3798 std::optional<bool> result = info->legalityFn(op);
3805 return info->action == LegalizationAction::Illegal;
3814 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3816 if (std::optional<bool> result = newCl(op))
3824 void ConversionTarget::setLegalityCallback(
3825 OperationName name,
const DynamicLegalityCallbackFn &callback) {
3826 assert(callback &&
"expected valid legality callback");
3827 auto *infoIt = legalOperations.find(name);
3828 assert(infoIt != legalOperations.end() &&
3829 infoIt->second.action == LegalizationAction::Dynamic &&
3830 "expected operation to already be marked as dynamically legal");
3831 infoIt->second.legalityFn =
3837 auto *infoIt = legalOperations.find(name);
3838 assert(infoIt != legalOperations.end() &&
3839 infoIt->second.action != LegalizationAction::Illegal &&
3840 "expected operation to already be marked as legal");
3841 infoIt->second.isRecursivelyLegal =
true;
3844 std::move(opRecursiveLegalityFns[name]), callback);
3846 opRecursiveLegalityFns.erase(name);
3849 void ConversionTarget::setLegalityCallback(
3851 assert(callback &&
"expected valid legality callback");
3852 for (StringRef dialect : dialects)
3854 std::move(dialectLegalityFns[dialect]), callback);
3857 void ConversionTarget::setLegalityCallback(
3858 const DynamicLegalityCallbackFn &callback) {
3859 assert(callback &&
"expected valid legality callback");
3864 -> std::optional<LegalizationInfo> {
3866 const auto *it = legalOperations.find(op);
3867 if (it != legalOperations.end())
3870 auto dialectIt = legalDialects.find(op.getDialectNamespace());
3871 if (dialectIt != legalDialects.end()) {
3872 DynamicLegalityCallbackFn callback;
3873 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3874 if (dialectFn != dialectLegalityFns.end())
3875 callback = dialectFn->second;
3876 return LegalizationInfo{dialectIt->second,
false,
3880 if (unknownLegalityFn)
3881 return LegalizationInfo{LegalizationAction::Dynamic,
3882 false, unknownLegalityFn};
3883 return std::nullopt;
3886 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3892 auto &rewriterImpl =
3898 auto &rewriterImpl =
3905 static FailureOr<SmallVector<Value>>
3907 SmallVector<Value> mappedValues;
3910 return std::move(mappedValues);
3914 patterns.getPDLPatterns().registerRewriteFunction(
3921 return results->front();
3923 patterns.getPDLPatterns().registerRewriteFunction(
3928 patterns.getPDLPatterns().registerRewriteFunction(
3931 auto &rewriterImpl =
3935 if (
Type newType = converter->convertType(type))
3941 patterns.getPDLPatterns().registerRewriteFunction(
3944 TypeRange types) -> FailureOr<SmallVector<Type>> {
3945 auto &rewriterImpl =
3954 return std::move(remappedTypes);
3969 static constexpr StringLiteral tag =
"apply-conversion";
3970 static constexpr StringLiteral desc =
3971 "Encapsulate the application of a dialect conversion";
3973 void print(raw_ostream &os)
const override { os << tag; }
3980 OpConversionMode mode) {
3984 LogicalResult status = success();
3985 SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
4004 OpConversionMode::Partial);
4044 "expected top-level op to be isolated from above");
4047 "expected ops to have a common ancestor");
4056 for (
Operation *op : ops.drop_front()) {
4060 assert(commonAncestor &&
4061 "expected to find a common isolated from above ancestor");
4065 return commonAncestor;
4072 if (
config.legalizableOps)
4073 assert(
config.legalizableOps->empty() &&
"expected empty set");
4083 inverseOperationMap[it.second] = it.first;
4089 OpConversionMode::Analysis);
4093 if (
config.legalizableOps) {
4096 originalLegalizableOps.insert(inverseOperationMap[op]);
4097 *
config.legalizableOps = std::move(originalLegalizableOps);
4101 clonedAncestor->
erase();
static void setInsertionPointAfter(OpBuilder &b, Value value)
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value >> &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg, Value repl)
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps, const SetVector< Block * > &insertedBlocks)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1242::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
This is the type of Action that is dispatched when a conversion is applied.
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
OpListType & getOperations()
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
const ConversionConfig & getConfig() const
Return the configuration of the current dialect conversion.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={}) override
PatternRewriter hook for inlining the ops of a block into another block.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to)
Replace all the uses of the block argument from with to.
~ConversionPatternRewriter() override
Base class for the conversion patterns.
FailureOr< SmallVector< Value > > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
auto lookup(T from) const
Lookup a mapped value within the map.
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Block::iterator getPoint() const
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Listener * listener
The optional listener for events of this builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
type_range getTypes() const
A unique fingerprint for a specific operation, and all of its internal operations (if includeNested ...
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Values.
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
result_range getResults()
int getPropertiesStorageSize() const
Returns the properties storage size.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e. this pattern may generate IR that also matches this pattern, but is known to bound the recursion.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this pattern.
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp, which may be in the same or another block in the same function.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen.
The general result of a type attribute conversion callback, allowing for early termination.
Attribute getResult() const
static AttributeConversionResult abort()
static AttributeConversionResult na()
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, ArrayRef< Value > replacements)
Remap an input of the original signature to replacements values.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert the attribute `attr`, present within the type `type`, using the registered conversion functions.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of some kind.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e. the type converts to itself.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for each argument.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
bool use_empty() const
Returns true if this value has no uses.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
CRTP Implementation of an action.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the given type converter.
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
@ AfterPatterns
Only attempt to fold not legal operations after applying patterns.
@ BeforePatterns
Only attempt to fold not legal operations before applying patterns.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOps (if provided).
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a partial conversion on the given operations, and all nested operations. This is one of several entry points defined for operation conversion.
Dialect conversion configuration.
bool allowPatternRollback
If set to "true", pattern rollback is allowed.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
bool attachDebugMaterializationKind
If set to "true", the materialization kind ("source" or "target") will be attached to "builtin.unrealized_conversion_cast" ops.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with the results of another operation.
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
bool wasErased(void *ptr) const
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to, const TypeConverter *converter)
Replace the given block argument with the given values.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks within that region.
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites remain.
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input operands.
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
DenseSet< BlockArgument > replacedArgs
A set of replaced block arguments.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
ConversionValueMapping mapping
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config)
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Look up the most recently mapped values with the desired types in the mapping, taking into account only replacements.
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Block * > patternInsertedBlocks
A set of blocks that were inserted (newly-created blocks or moved blocks) by the current pattern.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
void replaceOp(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the results of the given operation with the given values and erase the operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.