10 #include "mlir/Config/mlir-config.h"
20 #include "llvm/ADT/ScopeExit.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/DebugLog.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include "llvm/Support/SaveAndRestore.h"
26 #include "llvm/Support/ScopedPrinter.h"
32 #define DEBUG_TYPE "dialect-conversion"
35 template <
typename... Args>
36 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
39 os.startLine() <<
"} -> SUCCESS";
41 os.getOStream() <<
" : "
42 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
43 os.getOStream() <<
"\n";
48 template <
typename... Args>
49 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
52 os.startLine() <<
"} -> FAILURE : "
53 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
63 if (
OpResult inputRes = dyn_cast<OpResult>(value))
64 insertPt = ++inputRes.getOwner()->getIterator();
71 assert(!vals.empty() &&
"expected at least one value");
74 for (
Value v : vals.drop_front()) {
88 assert(dom &&
"unable to find valid insertion point");
106 struct ValueVectorMapInfo {
109 static ::llvm::hash_code getHashValue(
const ValueVector &val) {
110 return ::llvm::hash_combine_range(val);
119 struct ConversionValueMapping {
/// Return true if `value` is in `mappedTo`, the set that `map` fills with the
/// replacement (right-hand side) values it records.
122 bool isMappedTo(
Value value)
const {
return mappedTo.contains(value); }
127 template <
typename T>
128 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
131 template <
typename OldVal,
typename NewVal>
132 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
133 map(OldVal &&oldVal, NewVal &&newVal) {
137 assert(next != oldVal &&
"inserting cyclic mapping");
138 auto it = mapping.find(next);
139 if (it == mapping.end())
144 mappedTo.insert_range(newVal);
146 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
150 template <
typename OldVal,
typename NewVal>
151 std::enable_if_t<!IsValueVector<OldVal>::value ||
152 !IsValueVector<NewVal>::value>
153 map(OldVal &&oldVal, NewVal &&newVal) {
154 if constexpr (IsValueVector<OldVal>{}) {
155 map(std::forward<OldVal>(oldVal),
ValueVector{newVal});
156 }
else if constexpr (IsValueVector<NewVal>{}) {
157 map(
ValueVector{oldVal}, std::forward<NewVal>(newVal));
163 void map(
Value oldVal, SmallVector<Value> &&newVal) {
/// Remove the entry keyed by `value` from `mapping`. Only `mapping` is touched
/// here; `mappedTo` is left unchanged.
168 void erase(
const ValueVector &value) { mapping.erase(value); }
188 assert(!values.empty() &&
"expected non-empty value vector");
189 Operation *op = values.front().getDefiningOp();
190 for (
Value v : llvm::drop_begin(values)) {
191 if (v.getDefiningOp() != op)
201 assert(!values.empty() &&
"expected non-empty value vector");
207 auto it = mapping.find(from);
208 if (it == mapping.end()) {
221 struct RewriterState {
/// Record the number of rewrites and the sizes of the ignored-op and
/// replaced-op sets, so the rewriter can later return to these counts via
/// `resetState` (e.g. after a pattern fails).
222 RewriterState(
unsigned numRewrites,
unsigned numIgnoredOperations,
223 unsigned numReplacedOps)
224 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
225 numReplacedOps(numReplacedOps) {}
228 unsigned numRewrites;
231 unsigned numIgnoredOperations;
234 unsigned numReplacedOps;
246 notifyIRErased(listener, op);
255 notifyIRErased(listener, b);
285 UnresolvedMaterialization,
/// Virtual because rewrites are owned and destroyed through
/// `std::unique_ptr<IRRewrite>`.
290 virtual ~IRRewrite() =
default;
/// Undo the IR change recorded by this rewrite. Pure virtual: every concrete
/// rewrite kind supplies its own implementation.
293 virtual void rollback() = 0;
/// Return the kind tag; the subclasses' `classof` overloads dispatch on it.
312 Kind getKind()
const {
return kind; }
/// LLVM-style RTTI hook for the root of the hierarchy: every rewrite is an
/// `IRRewrite`, so this check always succeeds.
314 static bool classof(
const IRRewrite *
rewrite) {
return true; }
318 :
kind(
kind), rewriterImpl(rewriterImpl) {}
327 class BlockRewrite :
public IRRewrite {
/// Return the block this rewrite operates on.
330 Block *getBlock()
const {
return block; }
332 static bool classof(
const IRRewrite *
rewrite) {
333 return rewrite->getKind() >= Kind::CreateBlock &&
334 rewrite->getKind() <= Kind::BlockTypeConversion;
340 : IRRewrite(
kind, rewriterImpl), block(block) {}
347 class ValueRewrite :
public IRRewrite {
/// Return the value this rewrite operates on.
350 Value getValue()
const {
return value; }
352 static bool classof(
const IRRewrite *
rewrite) {
353 return rewrite->getKind() == Kind::ReplaceValue;
359 : IRRewrite(
kind, rewriterImpl), value(value) {}
368 class CreateBlockRewrite :
public BlockRewrite {
371 : BlockRewrite(
Kind::CreateBlock, rewriterImpl, block) {}
373 static bool classof(
const IRRewrite *
rewrite) {
374 return rewrite->getKind() == Kind::CreateBlock;
383 void rollback()
override {
386 auto &blockOps = block->getOperations();
387 while (!blockOps.empty())
388 blockOps.remove(blockOps.begin());
389 block->dropAllUses();
390 if (block->getParent())
401 class EraseBlockRewrite :
public BlockRewrite {
404 : BlockRewrite(
Kind::EraseBlock, rewriterImpl, block),
405 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
407 static bool classof(
const IRRewrite *
rewrite) {
408 return rewrite->getKind() == Kind::EraseBlock;
411 ~EraseBlockRewrite()
override {
413 "rewrite was neither rolled back nor committed/cleaned up");
416 void rollback()
override {
419 assert(block &&
"expected block");
420 auto &blockList = region->getBlocks();
424 blockList.insert(before, block);
429 assert(block &&
"expected block");
433 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
434 notifyIRErased(listener, *block);
439 for (
auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
441 assert(block->empty() &&
"expected empty block");
444 block->dropAllDefinedValueUses();
455 Block *insertBeforeBlock;
461 class InlineBlockRewrite :
public BlockRewrite {
465 : BlockRewrite(
Kind::InlineBlock, rewriterImpl, block),
466 sourceBlock(sourceBlock),
467 firstInlinedInst(sourceBlock->empty() ? nullptr
468 : &sourceBlock->front()),
469 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
475 assert(!getConfig().listener &&
476 "InlineBlockRewrite not supported if listener is attached");
479 static bool classof(
const IRRewrite *
rewrite) {
480 return rewrite->getKind() == Kind::InlineBlock;
483 void rollback()
override {
486 if (firstInlinedInst) {
487 assert(lastInlinedInst &&
"expected operation");
507 class MoveBlockRewrite :
public BlockRewrite {
511 : BlockRewrite(
Kind::MoveBlock, rewriterImpl, block),
512 region(previousRegion),
513 insertBeforeBlock(previousIt == previousRegion->end() ? nullptr
516 static bool classof(
const IRRewrite *
rewrite) {
517 return rewrite->getKind() == Kind::MoveBlock;
530 void rollback()
override {
543 Block *insertBeforeBlock;
547 class BlockTypeConversionRewrite :
public BlockRewrite {
551 : BlockRewrite(
Kind::BlockTypeConversion, rewriterImpl, origBlock),
552 newBlock(newBlock) {}
554 static bool classof(
const IRRewrite *
rewrite) {
555 return rewrite->getKind() == Kind::BlockTypeConversion;
/// Return the original block, i.e. the block stored in the `BlockRewrite` base.
558 Block *getOrigBlock()
const {
return block; }
/// Return the block created for the converted signature. On rollback its uses
/// are redirected back to the original block.
560 Block *getNewBlock()
const {
return newBlock; }
564 void rollback()
override;
574 class ReplaceValueRewrite :
public ValueRewrite {
578 : ValueRewrite(
Kind::ReplaceValue, rewriterImpl, value),
579 converter(converter) {}
581 static bool classof(
const IRRewrite *
rewrite) {
582 return rewrite->getKind() == Kind::ReplaceValue;
587 void rollback()
override;
595 class OperationRewrite :
public IRRewrite {
/// Return the operation this rewrite operates on.
598 Operation *getOperation()
const {
return op; }
600 static bool classof(
const IRRewrite *
rewrite) {
601 return rewrite->getKind() >= Kind::MoveOperation &&
602 rewrite->getKind() <= Kind::UnresolvedMaterialization;
608 : IRRewrite(
kind, rewriterImpl), op(op) {}
615 class MoveOperationRewrite :
public OperationRewrite {
619 : OperationRewrite(
Kind::MoveOperation, rewriterImpl, op),
620 block(previous.getBlock()),
621 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
623 : &*previous.getPoint()) {}
625 static bool classof(
const IRRewrite *
rewrite) {
626 return rewrite->getKind() == Kind::MoveOperation;
640 void rollback()
override {
658 class ModifyOperationRewrite :
public OperationRewrite {
662 : OperationRewrite(
Kind::ModifyOperation, rewriterImpl, op),
663 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
664 operands(op->operand_begin(), op->operand_end()),
665 successors(op->successor_begin(), op->successor_end()) {
670 name.initOpProperties(propCopy, prop);
674 static bool classof(
const IRRewrite *
rewrite) {
675 return rewrite->getKind() == Kind::ModifyOperation;
678 ~ModifyOperationRewrite()
override {
679 assert(!propertiesStorage &&
680 "rewrite was neither committed nor rolled back");
686 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
689 if (propertiesStorage) {
693 name.destroyOpProperties(propCopy);
694 operator delete(propertiesStorage);
695 propertiesStorage =
nullptr;
699 void rollback()
override {
705 if (propertiesStorage) {
708 name.destroyOpProperties(propCopy);
709 operator delete(propertiesStorage);
710 propertiesStorage =
nullptr;
717 DictionaryAttr attrs;
718 SmallVector<Value, 8> operands;
719 SmallVector<Block *, 2> successors;
720 void *propertiesStorage =
nullptr;
727 class ReplaceOperationRewrite :
public OperationRewrite {
731 : OperationRewrite(
Kind::ReplaceOperation, rewriterImpl, op),
732 converter(converter) {}
734 static bool classof(
const IRRewrite *
rewrite) {
735 return rewrite->getKind() == Kind::ReplaceOperation;
740 void rollback()
override;
750 class CreateOperationRewrite :
public OperationRewrite {
754 : OperationRewrite(
Kind::CreateOperation, rewriterImpl, op) {}
756 static bool classof(
const IRRewrite *
rewrite) {
757 return rewrite->getKind() == Kind::CreateOperation;
766 void rollback()
override;
770 enum MaterializationKind {
781 class UnresolvedMaterializationInfo {
/// Compiler-generated default constructor.
783 UnresolvedMaterializationInfo() =
default;
/// Pack the type converter and the materialization kind into a single
/// `PointerIntPair` (`converterAndKind`), and remember the original type.
784 UnresolvedMaterializationInfo(
const TypeConverter *converter,
785 MaterializationKind
kind,
Type originalType)
786 : converterAndKind(converter,
kind), originalType(originalType) {}
790 return converterAndKind.getPointer();
794 MaterializationKind getMaterializationKind()
const {
795 return converterAndKind.getInt();
/// Return the original type recorded at construction. This may be a null
/// `Type`; buildUnresolvedMaterialization only allows a non-null original type
/// for target materializations.
799 Type getOriginalType()
const {
return originalType; }
804 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
815 class UnresolvedMaterializationRewrite :
public OperationRewrite {
818 UnrealizedConversionCastOp op,
820 : OperationRewrite(
Kind::UnresolvedMaterialization, rewriterImpl, op),
821 mappedValues(std::move(mappedValues)) {}
823 static bool classof(
const IRRewrite *
rewrite) {
824 return rewrite->getKind() == Kind::UnresolvedMaterialization;
827 void rollback()
override;
829 UnrealizedConversionCastOp getOperation()
const {
830 return cast<UnrealizedConversionCastOp>(op);
840 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
843 template <
typename RewriteTy,
typename R>
844 static bool hasRewrite(R &&rewrites,
Operation *op) {
845 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
846 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
847 return rewriteTy && rewriteTy->getOperation() == op;
853 template <
typename RewriteTy,
typename R>
854 static bool hasRewrite(R &&rewrites,
Block *block) {
855 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
856 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
857 return rewriteTy && rewriteTy->getBlock() == block;
878 RewriterState getCurrentState();
882 void applyRewrites();
887 void resetState(RewriterState state, StringRef patternName =
"");
891 template <
typename RewriteTy,
typename... Args>
893 assert(
config.allowPatternRollback &&
"appending rewrites is not allowed");
895 std::make_unique<RewriteTy>(*
this, std::forward<Args>(args)...));
901 void undoRewrites(
unsigned numRewritesToKeep = 0, StringRef patternName =
"");
907 LogicalResult remapValues(StringRef valueDiagTag,
908 std::optional<Location> inputLoc,
ValueRange values,
925 bool skipPureTypeConversions =
false)
const;
947 Block *applySignatureConversion(
958 void replaceOp(
Operation *op, SmallVector<SmallVector<Value>> &&newValues);
966 void eraseBlock(
Block *block);
1004 Value findOrBuildReplacementValue(
Value value,
1012 void notifyOperationInserted(
Operation *op,
1016 void notifyBlockInserted(
Block *block,
Region *previous,
1035 std::function<
void(
Operation *)> opErasedCallback =
nullptr)
1037 opErasedCallback(opErasedCallback) {}
1049 if (wasErased(block))
1051 assert(block->
empty() &&
"expected empty block");
/// Return true if `ptr` has been recorded in the `erased` set. It is queried
/// with blocks here; the `void *` parameter presumably lets other IR objects
/// (e.g. operations) share the same set. TODO(review): confirm.
1056 bool wasErased(
void *ptr)
const {
return erased.contains(ptr); }
1060 if (opErasedCallback)
1061 opErasedCallback(op);
1071 std::function<void(
Operation *)> opErasedCallback;
1159 llvm::impl::raw_ldbg_ostream os{(Twine(
"[") +
DEBUG_TYPE +
":1] ").str(),
1163 llvm::ScopedPrinter logger{os};
1171 return rewriterImpl.
config;
1174 void BlockTypeConversionRewrite::commit(
RewriterBase &rewriter) {
1178 if (
auto *listener =
1179 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
1180 for (
Operation *op : getNewBlock()->getUsers())
1184 void BlockTypeConversionRewrite::rollback() {
1185 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1191 if (isa<BlockArgument>(repl)) {
1228 void ReplaceValueRewrite::commit(
RewriterBase &rewriter) {
1235 void ReplaceValueRewrite::rollback() {
1236 rewriterImpl.
mapping.erase({value});
1242 void ReplaceOperationRewrite::commit(
RewriterBase &rewriter) {
1244 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener());
1247 SmallVector<Value> replacements =
1249 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1257 for (
auto [result, newValue] :
1258 llvm::zip_equal(op->
getResults(), replacements))
1264 if (getConfig().unlegalizedOps)
1265 getConfig().unlegalizedOps->erase(op);
1269 notifyIRErased(listener, *op);
1276 void ReplaceOperationRewrite::rollback() {
1278 rewriterImpl.
mapping.erase({result});
1281 void ReplaceOperationRewrite::cleanup(
RewriterBase &rewriter) {
1285 void CreateOperationRewrite::rollback() {
1287 while (!region.getBlocks().empty())
1288 region.getBlocks().remove(region.getBlocks().begin());
1294 void UnresolvedMaterializationRewrite::rollback() {
1295 if (!mappedValues.empty())
1296 rewriterImpl.
mapping.erase(mappedValues);
1307 for (
size_t i = 0; i <
rewrites.size(); ++i)
1313 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1314 unresolvedMaterializations.erase(castOp);
1317 rewrite->cleanup(eraseRewriter);
1325 Value from,
TypeRange desiredTypes,
bool skipPureTypeConversions)
const {
1328 assert(!values.empty() &&
"expected non-empty value vector");
1333 return mapping.lookup(values);
1340 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1345 if (castOp.getOutputs() != values)
1347 return castOp.getInputs();
1356 for (
Value v : values) {
1359 llvm::append_range(next, r);
1364 if (next != values) {
1393 if (skipPureTypeConversions) {
1396 match &= !pureConversion;
1399 if (!pureConversion)
1400 lastNonMaterialization = current;
1403 desiredValue = current;
1409 current = std::move(next);
1414 if (!desiredTypes.empty())
1415 return desiredValue;
1416 if (skipPureTypeConversions)
1417 return lastNonMaterialization;
1436 StringRef patternName) {
1441 while (
ignoredOps.size() != state.numIgnoredOperations)
1444 while (
replacedOps.size() != state.numReplacedOps)
1449 StringRef patternName) {
1451 llvm::reverse(llvm::drop_begin(
rewrites, numRewritesToKeep)))
1453 rewrites.resize(numRewritesToKeep);
1457 StringRef valueDiagTag, std::optional<Location> inputLoc,
ValueRange values,
1459 remapped.reserve(llvm::size(values));
1462 Value operand = it.value();
1481 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1482 << it.index() <<
", type was " << origType;
1487 if (legalTypes.empty()) {
1488 remapped.push_back({});
1497 remapped.push_back(std::move(repl));
1506 repl, repl, legalTypes,
1508 remapped.push_back(castValues);
1531 if (region->
empty())
1536 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1538 std::optional<TypeConverter::SignatureConversion> conversion =
1548 if (entryConversion)
1551 std::optional<TypeConverter::SignatureConversion> conversion =
1561 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1563 if (hasRewrite<BlockTypeConversionRewrite>(
rewrites, block))
1564 llvm::report_fatal_error(
"block was already converted");
1578 for (
unsigned i = 0; i < origArgCount; ++i) {
1580 if (!inputMap || inputMap->replacedWithValues())
1583 for (
unsigned j = 0;
j < inputMap->size; ++
j)
1584 newLocs[inputMap->inputNo +
j] = origLoc;
1591 convertedTypes, newLocs);
1602 appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->
end());
1605 while (!block->
empty())
1612 for (
unsigned i = 0; i != origArgCount; ++i) {
1616 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1625 MaterializationKind::Source,
1629 origArgType,
Type(), converter,
1636 if (inputMap->replacedWithValues()) {
1638 assert(inputMap->size == 0 &&
1639 "invalid to provide a replacement value when the argument isn't "
1647 newBlock->
getArguments().slice(inputMap->inputNo, inputMap->size);
1652 appendRewrite<BlockTypeConversionRewrite>(block, newBlock);
1672 assert((!originalType ||
kind == MaterializationKind::Target) &&
1673 "original type is valid only for target materializations");
1674 assert(
TypeRange(inputs) != outputTypes &&
1675 "materialization is not necessary");
1679 OpBuilder builder(outputTypes.front().getContext());
1681 UnrealizedConversionCastOp convertOp =
1682 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1685 kind == MaterializationKind::Source ?
"source" :
"target";
1686 convertOp->setAttr(
"__kind__", builder.
getStringAttr(kindStr));
1693 UnresolvedMaterializationInfo(converter,
kind, originalType);
1695 if (!valuesToMap.empty())
1696 mapping.map(valuesToMap, convertOp.getResults());
1697 appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
1698 std::move(valuesToMap));
1702 return convertOp.getResults();
1708 "this code path is valid only in rollback mode");
1715 return repl.front();
1722 [&](
Operation *op) { return replacedOps.contains(op); }) &&
1747 MaterializationKind::Source, ip, value.
getLoc(),
1763 bool wasDetached = !previous.
isSet();
1765 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"' (" << op
1768 logger.getOStream() <<
" (was detached)";
1769 logger.getOStream() <<
"\n";
1775 "attempting to insert into a block within a replaced/erased op");
1792 appendRewrite<CreateOperationRewrite>(op);
1803 appendRewrite<MoveOperationRewrite>(op, previous);
1810 const SmallVector<SmallVector<Value>> &toRange,
1812 assert(!
impl.config.allowPatternRollback &&
1813 "this code path is valid only in 'no rollback' mode");
1814 SmallVector<Value> repls;
1815 for (
auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1818 repls.push_back(
Value());
1825 Value srcMat =
impl.buildUnresolvedMaterialization(
1830 repls.push_back(srcMat);
1836 repls.push_back(to[0]);
1845 Value srcMat =
impl.buildUnresolvedMaterialization(
1848 Type(), converter)[0];
1849 repls.push_back(srcMat);
1858 "incorrect number of replacement values");
1860 logger.startLine() <<
"** Replace : '" << op->
getName() <<
"'(" << op
1868 for (
auto [result, repls] :
1869 llvm::zip_equal(op->
getResults(), newValues)) {
1870 Type resultType = result.getType();
1871 auto logProlog = [&, repls = repls]() {
1872 logger.startLine() <<
" Note: Replacing op result of type "
1873 << resultType <<
" with value(s) of type (";
1874 llvm::interleaveComma(repls, logger.getOStream(), [&](Value v) {
1875 logger.getOStream() << v.getType();
1877 logger.getOStream() <<
")";
1883 logger.getOStream() <<
", but the type converter failed to legalize "
1884 "the original type.\n";
1889 logger.getOStream() <<
", but the legalized type(s) is/are (";
1890 llvm::interleaveComma(convertedTypes,
logger.getOStream(),
1891 [&](
Type t) { logger.getOStream() << t; });
1892 logger.getOStream() <<
")\n";
1898 if (!
config.allowPatternRollback) {
1901 *
this, op->
getResults(), newValues, currentTypeConverter);
1905 erasedOps.insert(op);
1906 ignoredOps.remove(op);
1907 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1908 unresolvedMaterializations.erase(castOp);
1909 patternMaterializations.erase(castOp);
1913 if (
config.unlegalizedOps)
1914 config.unlegalizedOps->erase(op);
1916 op->
walk([&](
Block *block) { erasedBlocks.insert(block); });
1918 notifyingRewriter.replaceOp(op, repls);
1922 assert(!ignoredOps.contains(op) &&
"operation was already replaced");
1925 assert(!replacedValues.contains(v) &&
1926 "attempting to replace a value that was already replaced");
1931 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1935 assert(!unresolvedMaterializations.contains(castOp) &&
1936 "attempting to replace/erase an unresolved materialization");
1940 for (
auto [repl, result] : llvm::zip_equal(newValues, op->
getResults()))
1941 mapping.map(
static_cast<Value>(result), std::move(repl));
1943 appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
1948 void ConversionPatternRewriterImpl::replaceAllUsesWith(
1950 if (!
config.allowPatternRollback) {
1955 Value repl = repls.front();
1971 assert(!replacedValues.contains(from) &&
1972 "attempting to replace a value that was already replaced");
1974 "attempting to replace a op result that was already replaced");
1975 replacedValues.insert(from);
1978 mapping.map(from, to);
1979 appendRewrite<ReplaceValueRewrite>(from, converter);
1982 void ConversionPatternRewriterImpl::eraseBlock(
Block *block) {
1983 if (!
config.allowPatternRollback) {
1988 erasedOps.insert(op);
1989 ignoredOps.remove(op);
1990 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1991 unresolvedMaterializations.erase(castOp);
1992 patternMaterializations.erase(castOp);
1996 if (
config.unlegalizedOps)
1997 config.unlegalizedOps->erase(op);
1999 block->
walk([&](
Block *block) { erasedBlocks.insert(block); });
2001 notifyingRewriter.eraseBlock(block);
2006 "attempting to erase a block within a replaced/erased op");
2007 appendRewrite<EraseBlockRewrite>(block);
2016 block->
walk([&](
Operation *op) { replacedOps.insert(op); });
2019 void ConversionPatternRewriterImpl::notifyBlockInserted(
2022 bool wasDetached = !previous;
2028 logger.startLine() <<
"** Insert Block into : '" << parent->
getName()
2029 <<
"' (" << parent <<
")";
2032 <<
"** Insert Block into detached Region (nullptr parent op)";
2035 logger.getOStream() <<
" (was detached)";
2036 logger.getOStream() <<
"\n";
2041 assert(!(
config.allowPatternRollback && wasOpReplaced(newParentOp)) &&
2042 "attempting to insert into a region within a replaced/erased op");
2047 config.listener->notifyBlockInserted(block, previous, previousIt);
2049 patternInsertedBlocks.insert(block);
2053 if (
config.allowPatternRollback) {
2057 appendRewrite<CreateBlockRewrite>(block);
2061 erasedBlocks.erase(block);
2067 if (
config.allowPatternRollback)
2068 appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
2071 void ConversionPatternRewriterImpl::inlineBlockBefore(
Block *source,
2074 appendRewrite<InlineBlockRewrite>(dest, source, before);
2077 void ConversionPatternRewriterImpl::notifyMatchFailure(
2081 reasonCallback(
diag);
2082 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
2083 if (
config.notifyCallback)
2092 ConversionPatternRewriter::ConversionPatternRewriter(
2096 setListener(
impl.get());
2102 return impl->config;
2106 assert(op && newOp &&
"expected non-null op");
2112 "incorrect # of replacement values");
2116 if (getInsertionPoint() == op->getIterator())
2123 impl->replaceOp(op, std::move(newVals));
2129 "incorrect # of replacement values");
2133 if (getInsertionPoint() == op->getIterator())
2136 impl->replaceOp(op, std::move(newValues));
2141 impl->logger.startLine()
2142 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
2147 if (getInsertionPoint() == op->getIterator())
2151 impl->replaceOp(op, std::move(nullRepls));
2155 impl->eraseBlock(block);
2162 "attempting to apply a signature conversion to a block within a "
2163 "replaced/erased op");
2164 return impl->applySignatureConversion(block, converter, conversion);
2171 "attempting to apply a signature conversion to a block within a "
2172 "replaced/erased op");
2173 return impl->convertRegionTypes(region, converter, entryConversion);
2178 impl->logger.startLine() <<
"** Replace Value : '" << from <<
"'";
2179 if (
auto blockArg = dyn_cast<BlockArgument>(from)) {
2181 impl->logger.getOStream() <<
" (in region of '" << parentOp->getName()
2182 <<
"' (" << parentOp <<
")\n";
2184 impl->logger.getOStream() <<
" (unlinked block)\n";
2188 impl->replaceAllUsesWith(from, to,
impl->currentTypeConverter);
2193 if (
failed(
impl->remapValues(
"value", std::nullopt, key,
2196 assert(remappedValues.front().size() == 1 &&
"1:N conversion not supported");
2197 return remappedValues.front().front();
2206 if (
failed(
impl->remapValues(
"value", std::nullopt, keys,
2209 for (
const auto &values : remapped) {
2210 assert(values.size() == 1 &&
"1:N conversion not supported");
2211 results.push_back(values.front());
2221 "incorrect # of argument replacement values");
2223 "attempting to inline a block from a replaced/erased op");
2225 "attempting to inline a block into a replaced/erased op");
2226 auto opIgnored = [&](
Operation *op) {
return impl->isOpIgnored(op); };
2229 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
2230 "expected 'source' to have no predecessors");
2239 bool fastPath = !getConfig().listener;
2241 if (fastPath &&
impl->config.allowPatternRollback)
2242 impl->inlineBlockBefore(source, dest, before);
2245 for (
auto it : llvm::zip(source->
getArguments(), argValues))
2246 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2253 while (!source->
empty())
2254 moveOpBefore(&source->
front(), dest, before);
2259 if (getInsertionBlock() == source)
2260 setInsertionPoint(dest, getInsertionPoint());
2267 if (!
impl->config.allowPatternRollback) {
2272 assert(!
impl->wasOpReplaced(op) &&
2273 "attempting to modify a replaced/erased op");
2275 impl->pendingRootUpdates.insert(op);
2277 impl->appendRewrite<ModifyOperationRewrite>(op);
2281 impl->patternModifiedOps.insert(op);
2282 if (!
impl->config.allowPatternRollback) {
2284 if (getConfig().listener)
2285 getConfig().listener->notifyOperationModified(op);
2292 assert(!
impl->wasOpReplaced(op) &&
2293 "attempting to modify a replaced/erased op");
2294 assert(
impl->pendingRootUpdates.erase(op) &&
2295 "operation did not have a pending in-place update");
2300 if (!
impl->config.allowPatternRollback) {
2305 assert(
impl->pendingRootUpdates.erase(op) &&
2306 "operation did not have a pending in-place update");
2309 auto it = llvm::find_if(
2310 llvm::reverse(
impl->rewrites), [&](std::unique_ptr<IRRewrite> &
rewrite) {
2311 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2312 return modifyRewrite && modifyRewrite->getOperation() == op;
2314 assert(it !=
impl->rewrites.rend() &&
"no root update started on op");
2316 int updateIdx = std::prev(
impl->rewrites.rend()) - it;
2317 impl->rewrites.erase(
impl->rewrites.begin() + updateIdx);
2331 oneToOneOperands.reserve(operands.size());
2333 if (operand.size() != 1)
2336 oneToOneOperands.push_back(operand.front());
2338 return std::move(oneToOneOperands);
2345 auto &rewriterImpl = dialectRewriter.getImpl();
2349 getTypeConverter());
2358 llvm::to_vector_of<ValueRange>(remapped);
2359 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2371 class OperationLegalizer {
2391 LogicalResult legalizeWithFold(
Operation *op);
2395 LogicalResult legalizeWithPattern(
Operation *op);
2403 const RewriterState &curState,
2410 legalizePatternBlockRewrites(
Operation *op,
2426 void buildLegalizationGraph(
2427 LegalizationPatterns &anyOpLegalizerPatterns,
2438 void computeLegalizationGraphBenefit(
2439 LegalizationPatterns &anyOpLegalizerPatterns,
2444 unsigned computeOpLegalizationDepth(
2451 unsigned applyCostModelToPatterns(
2473 : rewriter(rewriter), target(targetInfo), applicator(
patterns) {
2477 LegalizationPatterns anyOpLegalizerPatterns;
2479 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2480 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2483 bool OperationLegalizer::isIllegal(
Operation *op)
const {
2484 return target.isIllegal(op);
2487 LogicalResult OperationLegalizer::legalize(
Operation *op) {
2489 const char *logLineComment =
2490 "//===-------------------------------------------===//\n";
2492 auto &logger = rewriter.getImpl().logger;
2496 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2499 logger.getOStream() <<
"\n";
2500 logger.startLine() << logLineComment;
2501 logger.startLine() <<
"Legalizing operation : ";
2506 logger.getOStream() <<
"'" << op->
getName() <<
"' ";
2507 logger.getOStream() <<
"(" << op <<
") {\n";
2512 logger.startLine() << OpWithFlags(op,
2513 OpPrintingFlags().printGenericOpForm())
2520 logSuccess(logger,
"operation marked 'ignored' during conversion");
2521 logger.startLine() << logLineComment;
2527 if (
auto legalityInfo = target.isLegal(op)) {
2530 logger,
"operation marked legal by the target{0}",
2531 legalityInfo->isRecursivelyLegal
2532 ?
"; NOTE: operation is recursively legal; skipping internals"
2534 logger.startLine() << logLineComment;
2539 if (legalityInfo->isRecursivelyLegal) {
2542 rewriter.getImpl().ignoredOps.
insert(nested);
2553 if (succeeded(legalizeWithFold(op))) {
2556 logger.startLine() << logLineComment;
2563 if (succeeded(legalizeWithPattern(op))) {
2566 logger.startLine() << logLineComment;
2574 if (succeeded(legalizeWithFold(op))) {
2577 logger.startLine() << logLineComment;
2584 logFailure(logger,
"no matched legalization pattern");
2585 logger.startLine() << logLineComment;
2592 template <
typename T>
2594 T result = std::move(obj);
2599 LogicalResult OperationLegalizer::legalizeWithFold(
Operation *op) {
2600 auto &rewriterImpl = rewriter.getImpl();
2602 rewriterImpl.
logger.startLine() <<
"* Fold {\n";
2603 rewriterImpl.
logger.indent();
2608 auto cleanup = llvm::make_scope_exit([&]() {
2619 SmallVector<Value, 2> replacementValues;
2620 SmallVector<Operation *, 2> newOps;
2623 if (
failed(rewriter.
tryFold(op, replacementValues, &newOps))) {
2632 if (replacementValues.empty())
2633 return legalize(op);
2636 rewriter.
replaceOp(op, replacementValues);
2640 if (
failed(legalize(newOp))) {
2642 "failed to legalize generated constant '{0}'",
2644 if (!rewriter.getConfig().allowPatternRollback) {
2646 llvm::report_fatal_error(
2648 "' folder rollback of IR modifications requested");
2667 auto newOpNames = llvm::map_range(
2669 auto modifiedOpNames = llvm::map_range(
2671 StringRef detachedBlockStr =
"(detached block)";
2672 auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](
Block *block) {
2675 return detachedBlockStr;
2677 llvm::report_fatal_error(
2679 "' produced IR that could not be legalized. " +
"new ops: {" +
2680 llvm::join(newOpNames,
", ") +
"}, " +
"modified ops: {" +
2681 llvm::join(modifiedOpNames,
", ") +
"}, " +
"inserted block into ops: {" +
2682 llvm::join(insertedBlockNames,
", ") +
"}");
2685 LogicalResult OperationLegalizer::legalizeWithPattern(
Operation *op) {
2686 auto &rewriterImpl = rewriter.getImpl();
2689 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2691 std::optional<OperationFingerPrint> topLevelFingerPrint;
2705 rewriterImpl.
logger.startLine()
2706 <<
"WARNING: Multi-threadeding is enabled. Some dialect "
2707 "conversion expensive checks are skipped in multithreading "
2716 auto canApply = [&](
const Pattern &pattern) {
2717 bool canApply = canApplyPattern(op, pattern);
2718 if (canApply &&
config.listener)
2719 config.listener->notifyPatternBegin(pattern, op);
2725 auto onFailure = [&](
const Pattern &pattern) {
2734 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2738 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2739 llvm::report_fatal_error(
"pattern '" + pattern.getDebugName() +
2740 "' returned failure but IR did change");
2751 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
2758 config.listener->notifyPatternEnd(pattern, failure());
2759 rewriterImpl.
resetState(curState, pattern.getDebugName());
2760 appliedPatterns.erase(&pattern);
2765 auto onSuccess = [&](
const Pattern &pattern) {
2782 auto result = legalizePatternResult(op, pattern, curState, newOps,
2783 modifiedOps, insertedBlocks);
2784 appliedPatterns.erase(&pattern);
2789 rewriterImpl.
resetState(curState, pattern.getDebugName());
2792 config.listener->notifyPatternEnd(pattern, result);
2797 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2801 bool OperationLegalizer::canApplyPattern(
Operation *op,
2804 auto &os = rewriter.getImpl().logger;
2805 os.getOStream() <<
"\n";
2806 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
2808 os.getOStream() <<
")' {\n";
2815 !appliedPatterns.insert(&pattern).second) {
2817 logFailure(rewriter.getImpl().logger,
"pattern was already applied"));
2823 LogicalResult OperationLegalizer::legalizePatternResult(
2828 [[maybe_unused]]
auto &
impl = rewriter.getImpl();
2829 assert(
impl.pendingRootUpdates.empty() &&
"dangling root updates");
2831 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2833 auto newRewrites = llvm::drop_begin(
impl.rewrites, curState.numRewrites);
2834 auto replacedRoot = [&] {
2835 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2837 auto updatedRootInPlace = [&] {
2838 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2840 if (!replacedRoot() && !updatedRootInPlace())
2841 llvm::report_fatal_error(
2842 "expected pattern to replace the root operation or modify it in place");
2846 if (
failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
2847 failed(legalizePatternRootUpdates(modifiedOps)) ||
2848 failed(legalizePatternCreatedOperations(newOps))) {
2852 LLVM_DEBUG(
logSuccess(
impl.logger,
"pattern applied successfully"));
2856 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2864 for (
Block *block : insertedBlocks) {
2865 if (
impl.erasedBlocks.contains(block))
2875 if (
auto *converter =
impl.regionToConverter.lookup(block->
getParent())) {
2876 std::optional<TypeConverter::SignatureConversion> conversion =
2877 converter->convertBlockSignature(block);
2879 LLVM_DEBUG(
logFailure(
impl.logger,
"failed to convert types of moved "
2883 impl.applySignatureConversion(block, converter, *conversion);
2891 if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
2892 if (
failed(legalize(parentOp))) {
2894 impl.logger,
"operation '{0}'({1}) became illegal after rewrite",
2895 parentOp->
getName(), parentOp));
2903 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2906 if (
failed(legalize(op))) {
2907 LLVM_DEBUG(
logFailure(rewriter.getImpl().logger,
2908 "failed to legalize generated operation '{0}'({1})",
2916 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2919 if (
failed(legalize(op))) {
2922 "failed to legalize operation updated in-place '{0}'",
2934 void OperationLegalizer::buildLegalizationGraph(
2935 LegalizationPatterns &anyOpLegalizerPatterns,
2946 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2947 std::optional<OperationName> root = pattern.
getRootKind();
2953 anyOpLegalizerPatterns.push_back(&pattern);
2958 if (target.getOpAction(*root) == LegalizationAction::Legal)
2963 invalidPatterns[*root].insert(&pattern);
2965 parentOps[op].insert(*root);
2968 patternWorklist.insert(&pattern);
2976 if (!anyOpLegalizerPatterns.empty()) {
2977 for (
const Pattern *pattern : patternWorklist)
2978 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2982 while (!patternWorklist.empty()) {
2983 auto *pattern = patternWorklist.pop_back_val();
2987 std::optional<LegalizationAction> action = target.getOpAction(op);
2988 return !legalizerPatterns.count(op) &&
2989 (!action || action == LegalizationAction::Illegal);
2995 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2996 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
3000 for (
auto op : parentOps[*pattern->
getRootKind()])
3001 patternWorklist.set_union(invalidPatterns[op]);
3005 void OperationLegalizer::computeLegalizationGraphBenefit(
3006 LegalizationPatterns &anyOpLegalizerPatterns,
3012 for (
auto &opIt : legalizerPatterns)
3013 if (!minOpPatternDepth.count(opIt.first))
3014 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
3020 if (!anyOpLegalizerPatterns.empty())
3021 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
3027 applicator.applyCostModel([&](
const Pattern &pattern) {
3029 if (std::optional<OperationName> rootName = pattern.
getRootKind())
3030 orderedPatternList = legalizerPatterns[*rootName];
3032 orderedPatternList = anyOpLegalizerPatterns;
3035 auto *it = llvm::find(orderedPatternList, &pattern);
3036 if (it == orderedPatternList.end())
3040 return PatternBenefit(std::distance(it, orderedPatternList.end()));
3044 unsigned OperationLegalizer::computeOpLegalizationDepth(
3048 auto depthIt = minOpPatternDepth.find(op);
3049 if (depthIt != minOpPatternDepth.end())
3050 return depthIt->second;
3054 auto opPatternsIt = legalizerPatterns.find(op);
3055 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3064 unsigned minDepth = applyCostModelToPatterns(
3065 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3066 minOpPatternDepth[op] = minDepth;
3070 unsigned OperationLegalizer::applyCostModelToPatterns(
3077 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3078 patternsByDepth.reserve(
patterns.size());
3082 unsigned generatedOpDepth = computeOpLegalizationDepth(
3083 generatedOp, minOpPatternDepth, legalizerPatterns);
3084 depth =
std::max(depth, generatedOpDepth + 1);
3086 patternsByDepth.emplace_back(pattern, depth);
3089 minDepth =
std::min(minDepth, depth);
3094 if (patternsByDepth.size() == 1)
3098 llvm::stable_sort(patternsByDepth,
3099 [](
const std::pair<const Pattern *, unsigned> &lhs,
3100 const std::pair<const Pattern *, unsigned> &rhs) {
3103 if (lhs.second != rhs.second)
3104 return lhs.second < rhs.second;
3107 auto lhsBenefit = lhs.first->getBenefit();
3108 auto rhsBenefit = rhs.first->getBenefit();
3109 return lhsBenefit > rhsBenefit;
3114 for (
auto &patternIt : patternsByDepth)
3115 patterns.push_back(patternIt.first);
3129 template <
typename RangeT>
3132 function_ref<
bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3141 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3142 if (castOp.getInputs().empty())
3145 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3148 if (inputCastOp.getOutputs() != castOp.getInputs())
3154 while (!worklist.empty()) {
3155 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3159 UnrealizedConversionCastOp nextCast = castOp;
3161 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3162 if (llvm::any_of(nextCast.getInputs(), [&](
Value v) {
3163 return v.getDefiningOp() == castOp;
3171 castOp.replaceAllUsesWith(nextCast.getInputs());
3174 nextCast = getInputCast(nextCast);
3184 auto markOpLive = [&](
Operation *rootOp) {
3185 SmallVector<Operation *> worklist;
3186 worklist.push_back(rootOp);
3187 while (!worklist.empty()) {
3188 Operation *op = worklist.pop_back_val();
3189 if (liveOps.insert(op).second) {
3192 if (
auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
3193 if (isCastOpOfInterestFn(castOp))
3194 worklist.push_back(castOp);
3200 for (UnrealizedConversionCastOp op : castOps) {
3203 if (liveOps.contains(op.getOperation()))
3207 if (llvm::any_of(op->getUsers(), [&](
Operation *user) {
3208 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3209 return !castOp || !isCastOpOfInterestFn(castOp);
3215 for (UnrealizedConversionCastOp op : castOps) {
3216 if (liveOps.contains(op)) {
3218 if (remainingCastOps)
3219 remainingCastOps->push_back(op);
3234 for (UnrealizedConversionCastOp op : castOps)
3235 castOpSet.insert(op);
3243 llvm::make_range(castOps.begin(), castOps.end()),
3244 [&](UnrealizedConversionCastOp castOp) {
3245 return castOps.contains(castOp);
3257 [&](UnrealizedConversionCastOp castOp) {
3258 return castOps.contains(castOp);
3269 enum OpConversionMode {
3292 OpConversionMode mode)
3293 : rewriter(ctx,
config), opLegalizer(rewriter, target,
patterns),
3307 OperationLegalizer opLegalizer;
3310 OpConversionMode mode;
3314 LogicalResult OperationConverter::convert(
Operation *op) {
3318 if (
failed(opLegalizer.legalize(op))) {
3321 if (mode == OpConversionMode::Full)
3323 <<
"failed to legalize operation '" << op->
getName() <<
"'";
3327 if (mode == OpConversionMode::Partial) {
3328 if (opLegalizer.isIllegal(op))
3330 <<
"failed to legalize operation '" << op->
getName()
3331 <<
"' that was explicitly marked illegal";
3332 if (
config.unlegalizedOps)
3333 config.unlegalizedOps->insert(op);
3335 }
else if (mode == OpConversionMode::Analysis) {
3339 if (
config.legalizableOps)
3340 config.legalizableOps->insert(op);
3345 static LogicalResult
3347 UnrealizedConversionCastOp op,
3348 const UnresolvedMaterializationInfo &info) {
3349 assert(!op.use_empty() &&
3350 "expected that dead materializations have already been DCE'd");
3356 SmallVector<Value> newMaterialization;
3357 switch (info.getMaterializationKind()) {
3358 case MaterializationKind::Target:
3359 newMaterialization = converter->materializeTargetConversion(
3360 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3361 info.getOriginalType());
3363 case MaterializationKind::Source:
3364 assert(op->getNumResults() == 1 &&
"expected single result");
3365 Value sourceMat = converter->materializeSourceConversion(
3366 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3368 newMaterialization.push_back(sourceMat);
3371 if (!newMaterialization.empty()) {
3373 ValueRange newMaterializationRange(newMaterialization);
3374 assert(
TypeRange(newMaterializationRange) == op.getResultTypes() &&
3375 "materialization callback produced value of incorrect type");
3377 rewriter.
replaceOp(op, newMaterialization);
3383 <<
"failed to legalize unresolved materialization "
3385 << inputOperands.
getTypes() <<
") to ("
3386 << op.getResultTypes()
3387 <<
") that remained live after conversion";
3388 diag.attachNote(op->getUsers().begin()->getLoc())
3389 <<
"see existing live user here: " << *op->getUsers().begin();
3398 for (
auto *op : ops) {
3401 toConvert.push_back(op);
3404 auto legalityInfo = target.
isLegal(op);
3405 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3414 for (
auto *op : toConvert) {
3415 if (
failed(convert(op))) {
3441 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3445 if (rewriter.getConfig().buildMaterializations) {
3450 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3451 auto it = materializations.find(castOp);
3452 assert(it != materializations.end() &&
"inconsistent state");
3468 assert(!types.empty() &&
"expected valid types");
3469 remapInput(origInputNo, argTypes.size(), types.size());
3474 assert(!types.empty() &&
3475 "1->0 type remappings don't need to be added explicitly");
3476 argTypes.append(types.begin(), types.end());
3480 unsigned newInputNo,
3481 unsigned newInputCount) {
3482 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3483 assert(newInputCount != 0 &&
"expected valid input count");
3484 remappedInputs[origInputNo] =
3485 InputMapping{newInputNo, newInputCount, {}};
3490 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3506 assert(typeOrValue &&
"expected non-null type");
3507 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3508 : cast<Type>(typeOrValue);
3510 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3513 cacheReadLock.lock();
3514 auto existingIt = cachedDirectConversions.find(t);
3515 if (existingIt != cachedDirectConversions.end()) {
3516 if (existingIt->second)
3517 results.push_back(existingIt->second);
3518 return success(existingIt->second !=
nullptr);
3520 auto multiIt = cachedMultiConversions.find(t);
3521 if (multiIt != cachedMultiConversions.end()) {
3522 results.append(multiIt->second.begin(), multiIt->second.end());
3528 size_t currentCount = results.size();
3532 auto isCacheable = [&](
int index) {
3533 int numberOfConversionsUntilContextAware =
3534 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3535 return index < numberOfConversionsUntilContextAware;
3538 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3541 for (
auto indexedConverter :
llvm::enumerate(llvm::reverse(conversions))) {
3542 const ConversionCallbackFn &converter = indexedConverter.value();
3543 std::optional<LogicalResult> result = converter(typeOrValue, results);
3545 assert(results.size() == currentCount &&
3546 "failed type conversion should not change results");
3549 if (!isCacheable(indexedConverter.index()))
3552 cacheWriteLock.lock();
3553 if (!succeeded(*result)) {
3554 assert(results.size() == currentCount &&
3555 "failed type conversion should not change results");
3556 cachedDirectConversions.try_emplace(t,
nullptr);
3559 auto newTypes =
ArrayRef<Type>(results).drop_front(currentCount);
3560 if (newTypes.size() == 1)
3561 cachedDirectConversions.try_emplace(t, newTypes.front());
3563 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3571 return convertTypeImpl(t, results);
3576 return convertTypeImpl(v, results);
3586 return results.size() == 1 ? results.front() :
nullptr;
3596 return results.size() == 1 ? results.front() :
nullptr;
3602 for (
Type type : types)
3611 for (
Value value : values)
3630 return llvm::all_of(
3637 if (!
isLegal(ty.getResults()))
3651 if (convertedTypes.empty())
3655 result.
addInputs(inputNo, convertedTypes);
3661 unsigned origInputOffset)
const {
3662 for (
unsigned i = 0, e = types.size(); i != e; ++i)
3676 if (convertedTypes.empty())
3680 result.
addInputs(inputNo, convertedTypes);
3686 unsigned origInputOffset)
const {
3687 for (
unsigned i = 0, e = values.size(); i != e; ++i)
3696 for (
const SourceMaterializationCallbackFn &fn :
3697 llvm::reverse(sourceMaterializations))
3698 if (
Value result = fn(builder, resultType, inputs, loc))
3706 Type originalType)
const {
3708 builder, loc,
TypeRange(resultType), inputs, originalType);
3711 assert(result.size() == 1 &&
"expected single result");
3712 return result.front();
3717 Type originalType)
const {
3718 for (
const TargetMaterializationCallbackFn &fn :
3719 llvm::reverse(targetMaterializations)) {
3721 fn(builder, resultTypes, inputs, loc, originalType);
3725 "callback produced incorrect number of values or values with "
3732 std::optional<TypeConverter::SignatureConversion>
3736 return std::nullopt;
3759 return impl.getInt() == resultTag;
3763 return impl.getInt() == naTag;
3767 return impl.getInt() == abortTag;
3771 assert(hasResult() &&
"Cannot get result from N/A or abort");
3772 return impl.getPointer();
3775 std::optional<Attribute>
3777 for (
const TypeAttributeConversionCallbackFn &fn :
3778 llvm::reverse(typeAttributeConversions)) {
3783 return std::nullopt;
3785 return std::nullopt;
3795 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3801 SmallVector<Type, 1> newResults;
3805 typeConverter, &result)))
3822 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3831 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3836 struct AnyFunctionOpInterfaceSignatureConversion
3848 FailureOr<Operation *>
3852 assert(op &&
"Invalid op");
3866 return rewriter.
create(newOp);
3872 patterns.add<FunctionOpInterfaceSignatureConversion>(
3873 functionLikeOpName,
patterns.getContext(), converter, benefit);
3879 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3880 converter,
patterns.getContext(), benefit);
3889 legalOperations[op].action = action;
3894 for (StringRef dialect : dialectNames)
3895 legalDialects[dialect] = action;
3899 -> std::optional<LegalizationAction> {
3900 std::optional<LegalizationInfo> info = getOpInfo(op);
3901 return info ? info->action : std::optional<LegalizationAction>();
3905 -> std::optional<LegalOpDetails> {
3906 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3908 return std::nullopt;
3911 auto isOpLegal = [&] {
3913 if (info->action == LegalizationAction::Dynamic) {
3914 std::optional<bool> result = info->legalityFn(op);
3920 return info->action == LegalizationAction::Legal;
3923 return std::nullopt;
3927 if (info->isRecursivelyLegal) {
3928 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
3929 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3931 legalityFnIt->second(op).value_or(
true);
3936 return legalityDetails;
3940 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3944 if (info->action == LegalizationAction::Dynamic) {
3945 std::optional<bool> result = info->legalityFn(op);
3952 return info->action == LegalizationAction::Illegal;
3961 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3963 if (std::optional<bool> result = newCl(op))
3971 void ConversionTarget::setLegalityCallback(
3972 OperationName name,
const DynamicLegalityCallbackFn &callback) {
3973 assert(callback &&
"expected valid legality callback");
3974 auto *infoIt = legalOperations.find(name);
3975 assert(infoIt != legalOperations.end() &&
3976 infoIt->second.action == LegalizationAction::Dynamic &&
3977 "expected operation to already be marked as dynamically legal");
3978 infoIt->second.legalityFn =
3984 auto *infoIt = legalOperations.find(name);
3985 assert(infoIt != legalOperations.end() &&
3986 infoIt->second.action != LegalizationAction::Illegal &&
3987 "expected operation to already be marked as legal");
3988 infoIt->second.isRecursivelyLegal =
true;
3991 std::move(opRecursiveLegalityFns[name]), callback);
3993 opRecursiveLegalityFns.erase(name);
3996 void ConversionTarget::setLegalityCallback(
3998 assert(callback &&
"expected valid legality callback");
3999 for (StringRef dialect : dialects)
4001 std::move(dialectLegalityFns[dialect]), callback);
4004 void ConversionTarget::setLegalityCallback(
4005 const DynamicLegalityCallbackFn &callback) {
4006 assert(callback &&
"expected valid legality callback");
4011 -> std::optional<LegalizationInfo> {
4013 const auto *it = legalOperations.find(op);
4014 if (it != legalOperations.end())
4017 auto dialectIt = legalDialects.find(op.getDialectNamespace());
4018 if (dialectIt != legalDialects.end()) {
4019 DynamicLegalityCallbackFn callback;
4020 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4021 if (dialectFn != dialectLegalityFns.end())
4022 callback = dialectFn->second;
4023 return LegalizationInfo{dialectIt->second,
false,
4027 if (unknownLegalityFn)
4028 return LegalizationInfo{LegalizationAction::Dynamic,
4029 false, unknownLegalityFn};
4030 return std::nullopt;
4033 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4039 auto &rewriterImpl =
4045 auto &rewriterImpl =
4052 static FailureOr<SmallVector<Value>>
4054 SmallVector<Value> mappedValues;
4057 return std::move(mappedValues);
4061 patterns.getPDLPatterns().registerRewriteFunction(
4068 return results->front();
4070 patterns.getPDLPatterns().registerRewriteFunction(
4075 patterns.getPDLPatterns().registerRewriteFunction(
4078 auto &rewriterImpl =
4082 if (
Type newType = converter->convertType(type))
4088 patterns.getPDLPatterns().registerRewriteFunction(
4091 TypeRange types) -> FailureOr<SmallVector<Type>> {
4092 auto &rewriterImpl =
4101 return std::move(remappedTypes);
4116 static constexpr StringLiteral tag =
"apply-conversion";
4117 static constexpr StringLiteral desc =
4118 "Encapsulate the application of a dialect conversion";
4120 void print(raw_ostream &os)
const override { os << tag; }
4127 OpConversionMode mode) {
4131 LogicalResult status = success();
4132 SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
4151 OpConversionMode::Partial);
4191 "expected top-level op to be isolated from above");
4194 "expected ops to have a common ancestor");
4203 for (
Operation *op : ops.drop_front()) {
4207 assert(commonAncestor &&
4208 "expected to find a common isolated from above ancestor");
4212 return commonAncestor;
4219 if (
config.legalizableOps)
4220 assert(
config.legalizableOps->empty() &&
"expected empty set");
4230 inverseOperationMap[it.second] = it.first;
4236 OpConversionMode::Analysis);
4240 if (
config.legalizableOps) {
4243 originalLegalizableOps.insert(inverseOperationMap[op]);
4244 *
config.legalizableOps = std::move(originalLegalizableOps);
4248 clonedAncestor->
erase();
static void setInsertionPointAfter(OpBuilder &b, Value value)
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value >> &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps, const SetVector< Block * > &insertedBlocks)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1247::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
This is the type of Action that is dispatched when a conversion is applied.
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
OpListType & getOperations()
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
const ConversionConfig & getConfig() const
Return the configuration of the current dialect conversion.
void replaceAllUsesWith(Value from, ValueRange to)
Replace all the uses of from with to.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={}) override
PatternRewriter hook for inlining the ops of a block into another block.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
~ConversionPatternRewriter() override
Base class for the conversion patterns.
FailureOr< SmallVector< Value > > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true is operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
auto lookup(T from) const
Lookup a mapped value within the map.
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Block::iterator getPoint() const
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Listener * listener
The optional listener for events of this builder.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
type_range getTypes() const
A unique fingerprint for a specific operation, and all of it's internal operations (if includeNested ...
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
result_range getResults()
int getPropertiesStorageSize() const
Returns the properties storage size.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
The general result of a type attribute conversion callback, allowing for early termination.
Attribute getResult() const
static AttributeConversionResult abort()
static AttributeConversionResult na()
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, ArrayRef< Value > replacements)
Remap an input of the original signature to replacements values.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
CRTP Implementation of an action.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
@ AfterPatterns
Only attempt to fold not legal operations after applying patterns.
@ BeforePatterns
Only attempt to fold not legal operations before applying patterns.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
void reconcileUnrealizedCasts(const DenseSet< UnrealizedConversionCastOp > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
bool allowPatternRollback
If set to "true", pattern rollback is allowed.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
bool attachDebugMaterializationKind
If set to "true", the materialization kind ("source" or "target") will be attached to "builtin....
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
bool wasErased(void *ptr) const
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
ConversionValueMapping mapping
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config)
void replaceAllUsesWith(Value from, ValueRange to, const TypeConverter *converter)
Replace the uses of the given value with the given values.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
SetVector< Block * > patternInsertedBlocks
A set of blocks that were inserted (newly-created blocks or moved blocks) by the current pattern.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
void replaceOp(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the results of the given operation with the given values and erase the operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.