10#include "mlir/Config/mlir-config.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/ErrorHandling.h"
25#include "llvm/Support/FormatVariadic.h"
26#include "llvm/Support/SaveAndRestore.h"
27#include "llvm/Support/ScopedPrinter.h"
34#define DEBUG_TYPE "dialect-conversion"
37template <
typename... Args>
38static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
41 os.startLine() <<
"} -> SUCCESS";
43 os.getOStream() <<
" : "
44 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
45 os.getOStream() <<
"\n";
50template <
typename... Args>
51static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
54 os.startLine() <<
"} -> FAILURE : "
55 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
65 if (
OpResult inputRes = dyn_cast<OpResult>(value))
66 insertPt = ++inputRes.getOwner()->getIterator();
73 assert(!vals.empty() &&
"expected at least one value");
76 for (
Value v : vals.drop_front()) {
90 assert(dom &&
"unable to find valid insertion point");
98enum OpConversionMode {
125struct ValueVectorMapInfo {
128 static ::llvm::hash_code getHashValue(
const ValueVector &val) {
129 return ::llvm::hash_combine_range(val);
138struct ConversionValueMapping {
141 bool isMappedTo(
Value value)
const {
return mappedTo.contains(value); }
146 template <
typename T>
147 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
150 template <
typename OldVal,
typename NewVal>
151 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
152 map(OldVal &&oldVal, NewVal &&newVal) {
156 assert(next != oldVal &&
"inserting cyclic mapping");
157 auto it = mapping.find(next);
158 if (it == mapping.end())
163 mappedTo.insert_range(newVal);
165 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
169 template <
typename OldVal,
typename NewVal>
170 std::enable_if_t<!IsValueVector<OldVal>::value ||
171 !IsValueVector<NewVal>::value>
172 map(OldVal &&oldVal, NewVal &&newVal) {
173 if constexpr (IsValueVector<OldVal>{}) {
174 map(std::forward<OldVal>(oldVal),
ValueVector{newVal});
175 }
else if constexpr (IsValueVector<NewVal>{}) {
176 map(
ValueVector{oldVal}, std::forward<NewVal>(newVal));
187 void erase(
const ValueVector &value) { mapping.erase(value); }
207 assert(!values.empty() &&
"expected non-empty value vector");
208 Operation *op = values.front().getDefiningOp();
209 for (
Value v : llvm::drop_begin(values)) {
210 if (v.getDefiningOp() != op)
220 assert(!values.empty() &&
"expected non-empty value vector");
226 auto it = mapping.find(from);
227 if (it == mapping.end()) {
240struct RewriterState {
241 RewriterState(
unsigned numRewrites,
unsigned numIgnoredOperations,
242 unsigned numReplacedOps)
243 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
244 numReplacedOps(numReplacedOps) {}
247 unsigned numRewrites;
250 unsigned numIgnoredOperations;
253 unsigned numReplacedOps;
260static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
263static void notifyIRErased(RewriterBase::Listener *listener,
Block &
b) {
264 for (Operation &op :
b)
265 notifyIRErased(listener, op);
271static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
274 notifyIRErased(listener,
b);
304 UnresolvedMaterialization,
309 virtual ~IRRewrite() =
default;
312 virtual void rollback() = 0;
326 virtual void commit(RewriterBase &rewriter) {}
329 virtual void cleanup(RewriterBase &rewriter) {}
331 Kind getKind()
const {
return kind; }
333 static bool classof(
const IRRewrite *
rewrite) {
return true; }
336 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
337 : kind(kind), rewriterImpl(rewriterImpl) {}
339 const ConversionConfig &getConfig()
const;
342 ConversionPatternRewriterImpl &rewriterImpl;
346class BlockRewrite :
public IRRewrite {
349 Block *getBlock()
const {
return block; }
351 static bool classof(
const IRRewrite *
rewrite) {
352 return rewrite->getKind() >= Kind::CreateBlock &&
353 rewrite->getKind() <= Kind::BlockTypeConversion;
357 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
359 : IRRewrite(kind, rewriterImpl), block(block) {}
366class ValueRewrite :
public IRRewrite {
369 Value getValue()
const {
return value; }
371 static bool classof(
const IRRewrite *
rewrite) {
372 return rewrite->getKind() == Kind::ReplaceValue;
376 ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
378 : IRRewrite(kind, rewriterImpl), value(value) {}
387class CreateBlockRewrite :
public BlockRewrite {
389 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block)
390 : BlockRewrite(
Kind::CreateBlock, rewriterImpl, block) {}
392 static bool classof(
const IRRewrite *
rewrite) {
393 return rewrite->getKind() == Kind::CreateBlock;
396 void commit(RewriterBase &rewriter)
override {
402 void rollback()
override {
405 auto &blockOps = block->getOperations();
406 while (!blockOps.empty())
407 blockOps.remove(blockOps.begin());
408 block->dropAllUses();
409 if (block->getParent())
420class EraseBlockRewrite :
public BlockRewrite {
422 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block)
423 : BlockRewrite(
Kind::EraseBlock, rewriterImpl, block),
424 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
426 static bool classof(
const IRRewrite *
rewrite) {
427 return rewrite->getKind() == Kind::EraseBlock;
430 ~EraseBlockRewrite()
override {
432 "rewrite was neither rolled back nor committed/cleaned up");
435 void rollback()
override {
438 assert(block &&
"expected block");
443 blockList.insert(before, block);
447 void commit(RewriterBase &rewriter)
override {
448 assert(block &&
"expected block");
452 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
453 notifyIRErased(listener, *block);
456 void cleanup(RewriterBase &rewriter)
override {
458 for (
auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
460 assert(block->empty() &&
"expected empty block");
463 block->dropAllDefinedValueUses();
474 Block *insertBeforeBlock;
480class InlineBlockRewrite :
public BlockRewrite {
482 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block,
484 : BlockRewrite(
Kind::InlineBlock, rewriterImpl, block),
485 sourceBlock(sourceBlock),
486 firstInlinedInst(sourceBlock->empty() ?
nullptr
487 : &sourceBlock->front()),
488 lastInlinedInst(sourceBlock->empty() ?
nullptr : &sourceBlock->back()) {
494 assert(!getConfig().listener &&
495 "InlineBlockRewrite not supported if listener is attached");
498 static bool classof(
const IRRewrite *
rewrite) {
499 return rewrite->getKind() == Kind::InlineBlock;
502 void rollback()
override {
505 if (firstInlinedInst) {
506 assert(lastInlinedInst &&
"expected operation");
519 Operation *firstInlinedInst;
522 Operation *lastInlinedInst;
526class MoveBlockRewrite :
public BlockRewrite {
528 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block,
530 : BlockRewrite(
Kind::MoveBlock, rewriterImpl, block),
531 region(previousRegion),
532 insertBeforeBlock(previousIt == previousRegion->end() ?
nullptr
535 static bool classof(
const IRRewrite *
rewrite) {
536 return rewrite->getKind() == Kind::MoveBlock;
539 void commit(RewriterBase &rewriter)
override {
549 void rollback()
override {
553 if (Region *currentParent = block->
getParent()) {
555 region->getBlocks().splice(before, currentParent->getBlocks(), block);
559 region->
getBlocks().insert(before, block);
568 Block *insertBeforeBlock;
572class BlockTypeConversionRewrite :
public BlockRewrite {
574 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
576 : BlockRewrite(
Kind::BlockTypeConversion, rewriterImpl, origBlock),
577 newBlock(newBlock) {}
579 static bool classof(
const IRRewrite *
rewrite) {
580 return rewrite->getKind() == Kind::BlockTypeConversion;
583 Block *getOrigBlock()
const {
return block; }
585 Block *getNewBlock()
const {
return newBlock; }
587 void commit(RewriterBase &rewriter)
override;
589 void rollback()
override;
599class ReplaceValueRewrite :
public ValueRewrite {
601 ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
602 const TypeConverter *converter)
603 : ValueRewrite(
Kind::ReplaceValue, rewriterImpl, value),
604 converter(converter) {}
606 static bool classof(
const IRRewrite *
rewrite) {
607 return rewrite->getKind() == Kind::ReplaceValue;
610 void commit(RewriterBase &rewriter)
override;
612 void rollback()
override;
616 const TypeConverter *converter;
620class OperationRewrite :
public IRRewrite {
623 Operation *getOperation()
const {
return op; }
625 static bool classof(
const IRRewrite *
rewrite) {
626 return rewrite->getKind() >= Kind::MoveOperation &&
627 rewrite->getKind() <= Kind::UnresolvedMaterialization;
631 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
633 : IRRewrite(kind, rewriterImpl), op(op) {}
640class MoveOperationRewrite :
public OperationRewrite {
642 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
643 Operation *op, OpBuilder::InsertPoint previous)
644 : OperationRewrite(
Kind::MoveOperation, rewriterImpl, op),
645 block(previous.getBlock()),
646 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
648 : &*previous.getPoint()) {}
650 static bool classof(
const IRRewrite *
rewrite) {
651 return rewrite->getKind() == Kind::MoveOperation;
654 void commit(RewriterBase &rewriter)
override {
660 op, OpBuilder::InsertPoint(block,
665 void rollback()
override {
678 Operation *insertBeforeOp;
683class ModifyOperationRewrite :
public OperationRewrite {
685 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
687 : OperationRewrite(
Kind::ModifyOperation, rewriterImpl, op),
688 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
689 operands(op->operand_begin(), op->operand_end()),
690 successors(op->successor_begin(), op->successor_end()) {
693 propertiesStorage = operator new(op->getPropertiesStorageSize());
694 PropertyRef propCopy(name.getOpPropertiesTypeID(), propertiesStorage);
695 name.initOpProperties(propCopy, prop);
699 static bool classof(
const IRRewrite *
rewrite) {
700 return rewrite->getKind() == Kind::ModifyOperation;
703 ~ModifyOperationRewrite()
override {
704 assert(!propertiesStorage &&
705 "rewrite was neither committed nor rolled back");
708 void commit(RewriterBase &rewriter)
override {
711 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
714 if (propertiesStorage) {
719 operator delete(propertiesStorage);
720 propertiesStorage =
nullptr;
724 void rollback()
override {
728 for (
const auto &it : llvm::enumerate(successors))
730 if (propertiesStorage) {
734 operator delete(propertiesStorage);
735 propertiesStorage =
nullptr;
742 DictionaryAttr attrs;
743 SmallVector<Value, 8> operands;
744 SmallVector<Block *, 2> successors;
745 void *propertiesStorage =
nullptr;
752class ReplaceOperationRewrite :
public OperationRewrite {
754 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
755 Operation *op,
const TypeConverter *converter)
756 : OperationRewrite(
Kind::ReplaceOperation, rewriterImpl, op),
757 converter(converter) {}
759 static bool classof(
const IRRewrite *
rewrite) {
760 return rewrite->getKind() == Kind::ReplaceOperation;
763 void commit(RewriterBase &rewriter)
override;
765 void rollback()
override;
767 void cleanup(RewriterBase &rewriter)
override;
772 const TypeConverter *converter;
775class CreateOperationRewrite :
public OperationRewrite {
777 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
779 : OperationRewrite(
Kind::CreateOperation, rewriterImpl, op) {}
781 static bool classof(
const IRRewrite *
rewrite) {
782 return rewrite->getKind() == Kind::CreateOperation;
785 void commit(RewriterBase &rewriter)
override {
791 void rollback()
override;
795enum MaterializationKind {
806class UnresolvedMaterializationInfo {
808 UnresolvedMaterializationInfo() =
default;
809 UnresolvedMaterializationInfo(
const TypeConverter *converter,
810 MaterializationKind kind, Type originalType)
811 : converterAndKind(converter, kind), originalType(originalType) {}
814 const TypeConverter *getConverter()
const {
815 return converterAndKind.getPointer();
819 MaterializationKind getMaterializationKind()
const {
820 return converterAndKind.getInt();
824 Type getOriginalType()
const {
return originalType; }
829 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
840class UnresolvedMaterializationRewrite :
public OperationRewrite {
842 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
843 UnrealizedConversionCastOp op,
845 : OperationRewrite(
Kind::UnresolvedMaterialization, rewriterImpl, op),
846 mappedValues(std::move(mappedValues)) {}
848 static bool classof(
const IRRewrite *
rewrite) {
849 return rewrite->getKind() == Kind::UnresolvedMaterialization;
852 void rollback()
override;
854 UnrealizedConversionCastOp getOperation()
const {
855 return cast<UnrealizedConversionCastOp>(op);
865#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
868template <
typename RewriteTy,
typename R>
869static bool hasRewrite(R &&rewrites, Operation *op) {
870 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
871 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
872 return rewriteTy && rewriteTy->getOperation() == op;
878template <
typename RewriteTy,
typename R>
879static bool hasRewrite(R &&rewrites,
Block *block) {
880 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
881 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
882 return rewriteTy && rewriteTy->getBlock() == block;
894 const ConversionConfig &
config,
904 RewriterState getCurrentState();
908 void applyRewrites();
913 void resetState(RewriterState state, StringRef patternName =
"");
917 template <
typename RewriteTy,
typename... Args>
919 assert(
config.allowPatternRollback &&
"appending rewrites is not allowed");
921 std::make_unique<RewriteTy>(*
this, std::forward<Args>(args)...));
927 void undoRewrites(
unsigned numRewritesToKeep = 0, StringRef patternName =
"");
933 LogicalResult remapValues(StringRef valueDiagTag,
934 std::optional<Location> inputLoc,
ValueRange values,
951 bool skipPureTypeConversions =
false)
const;
965 TypeConverter::SignatureConversion *entryConversion);
973 Block *applySignatureConversion(
975 TypeConverter::SignatureConversion &signatureConversion);
995 void eraseBlock(
Block *block);
1033 Value findOrBuildReplacementValue(
Value value,
1041 void notifyOperationInserted(
Operation *op,
1045 void notifyBlockInserted(
Block *block,
Region *previous,
1064 std::function<
void(
Operation *)> opErasedCallback =
nullptr)
1066 opErasedCallback(std::move(opErasedCallback)) {}
1080 assert(block->empty() &&
"expected empty block");
1081 block->dropAllDefinedValueUses();
1089 if (opErasedCallback)
1090 opErasedCallback(op);
1198const ConversionConfig &IRRewrite::getConfig()
const {
1199 return rewriterImpl.
config;
1202void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1206 if (
auto *listener =
1207 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
1208 for (Operation *op : getNewBlock()->getUsers())
1212void BlockTypeConversionRewrite::rollback() {
1213 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1220 if (isa<BlockArgument>(repl)) {
1260 result &= functor(operand);
1265void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1272void ReplaceValueRewrite::rollback() {
1273 rewriterImpl.
mapping.erase({value});
1279void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1281 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener());
1284 SmallVector<Value> replacements =
1286 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1294 for (
auto [
result, newValue] :
1295 llvm::zip_equal(op->
getResults(), replacements))
1301 if (getConfig().unlegalizedOps)
1302 getConfig().unlegalizedOps->erase(op);
1306 notifyIRErased(listener, *op);
1311 llvm::reportFatalInternalError(
1312 "dialect conversion attempted to replace a root operation that has no "
1313 "parent block; the pass must ensure its target op is nested in a "
1318void ReplaceOperationRewrite::rollback() {
1323void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1327void CreateOperationRewrite::rollback() {
1329 while (!region.getBlocks().empty())
1330 region.getBlocks().remove(region.getBlocks().begin());
1336void UnresolvedMaterializationRewrite::rollback() {
1337 if (!mappedValues.empty())
1338 rewriterImpl.
mapping.erase(mappedValues);
1349 for (
size_t i = 0; i <
rewrites.size(); ++i)
1355 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1356 unresolvedMaterializations.erase(castOp);
1359 rewrite->cleanup(eraseRewriter);
1367 Value from,
TypeRange desiredTypes,
bool skipPureTypeConversions)
const {
1370 assert(!values.empty() &&
"expected non-empty value vector");
1374 if (
config.allowPatternRollback)
1375 return mapping.lookup(values);
1382 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1387 if (castOp.getOutputs() != values)
1389 return castOp.getInputs();
1398 for (
Value v : values) {
1401 llvm::append_range(next, r);
1406 if (next != values) {
1435 if (skipPureTypeConversions) {
1438 match &= !pureConversion;
1441 if (!pureConversion)
1442 lastNonMaterialization = current;
1445 desiredValue = current;
1451 current = std::move(next);
1456 if (!desiredTypes.empty())
1457 return desiredValue;
1458 if (skipPureTypeConversions)
1459 return lastNonMaterialization;
1478 StringRef patternName) {
1483 while (
ignoredOps.size() != state.numIgnoredOperations)
1486 while (
replacedOps.size() != state.numReplacedOps)
1491 StringRef patternName) {
1493 llvm::reverse(llvm::drop_begin(
rewrites, numRewritesToKeep)))
1495 rewrites.resize(numRewritesToKeep);
1499 StringRef valueDiagTag, std::optional<Location> inputLoc,
ValueRange values,
1501 remapped.reserve(llvm::size(values));
1503 for (
const auto &it : llvm::enumerate(values)) {
1504 Value operand = it.value();
1523 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1524 << it.index() <<
", type was " << origType;
1529 if (legalTypes.empty()) {
1530 remapped.push_back({});
1539 remapped.push_back(std::move(repl));
1548 repl, repl, legalTypes,
1550 remapped.push_back(castValues);
1571 TypeConverter::SignatureConversion *entryConversion) {
1573 if (region->
empty())
1578 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1580 std::optional<TypeConverter::SignatureConversion> conversion =
1581 converter.convertBlockSignature(&block);
1590 if (entryConversion)
1593 std::optional<TypeConverter::SignatureConversion> conversion =
1594 converter.convertBlockSignature(®ion->
front());
1602 TypeConverter::SignatureConversion &signatureConversion) {
1603#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1605 if (hasRewrite<BlockTypeConversionRewrite>(
rewrites, block))
1606 llvm::reportFatalInternalError(
"block was already converted");
1613 auto convertedTypes = signatureConversion.getConvertedTypes();
1620 for (
unsigned i = 0; i < origArgCount; ++i) {
1621 auto inputMap = signatureConversion.getInputMapping(i);
1622 if (!inputMap || inputMap->replacedWithValues())
1625 for (
unsigned j = 0;
j < inputMap->size; ++
j)
1626 newLocs[inputMap->inputNo +
j] = origLoc;
1633 convertedTypes, newLocs);
1641 bool fastPath = !
config.listener;
1643 if (
config.allowPatternRollback)
1647 while (!block->
empty())
1654 for (
unsigned i = 0; i != origArgCount; ++i) {
1658 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1659 signatureConversion.getInputMapping(i);
1667 MaterializationKind::Source,
1671 origArgType,
Type(), converter,
1678 if (inputMap->replacedWithValues()) {
1680 assert(inputMap->size == 0 &&
1681 "invalid to provide a replacement value when the argument isn't "
1689 newBlock->
getArguments().slice(inputMap->inputNo, inputMap->size);
1693 if (
config.allowPatternRollback)
1714 assert((!originalType || kind == MaterializationKind::Target) &&
1715 "original type is valid only for target materializations");
1716 assert(
TypeRange(inputs) != outputTypes &&
1717 "materialization is not necessary");
1721 OpBuilder builder(outputTypes.front().getContext());
1723 UnrealizedConversionCastOp convertOp =
1724 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1725 if (
config.attachDebugMaterializationKind) {
1727 kind == MaterializationKind::Source ?
"source" :
"target";
1728 convertOp->setAttr(
"__kind__", builder.
getStringAttr(kindStr));
1735 UnresolvedMaterializationInfo(converter, kind, originalType);
1736 if (
config.allowPatternRollback) {
1737 if (!valuesToMap.empty())
1738 mapping.map(valuesToMap, convertOp.getResults());
1740 std::move(valuesToMap));
1744 return convertOp.getResults();
1749 assert(
config.allowPatternRollback &&
1750 "this code path is valid only in rollback mode");
1757 return repl.front();
1764 [&](
Operation *op) { return replacedOps.contains(op); }) &&
1789 MaterializationKind::Source, ip, value.
getLoc(),
1805 bool wasDetached = !previous.
isSet();
1807 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"' (" << op
1810 logger.getOStream() <<
" (was detached)";
1811 logger.getOStream() <<
"\n";
1817 "attempting to insert into a block within a replaced/erased op");
1821 config.listener->notifyOperationInserted(op, previous);
1830 if (
config.allowPatternRollback) {
1844 if (
config.allowPatternRollback)
1854 assert(!
impl.config.allowPatternRollback &&
1855 "this code path is valid only in 'no rollback' mode");
1857 for (
auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1860 repls.push_back(
Value());
1867 Value srcMat =
impl.buildUnresolvedMaterialization(
1872 repls.push_back(srcMat);
1878 repls.push_back(to[0]);
1887 Value srcMat =
impl.buildUnresolvedMaterialization(
1890 Type(), converter)[0];
1891 repls.push_back(srcMat);
1900 "incorrect number of replacement values");
1902 logger.startLine() <<
"** Replace : '" << op->
getName() <<
"'(" << op
1910 for (
auto [
result, repls] :
1911 llvm::zip_equal(op->
getResults(), newValues)) {
1913 auto logProlog = [&, repls = repls]() {
1914 logger.startLine() <<
" Note: Replacing op result of type "
1915 << resultType <<
" with value(s) of type (";
1916 llvm::interleaveComma(repls,
logger.getOStream(), [&](
Value v) {
1917 logger.getOStream() << v.getType();
1919 logger.getOStream() <<
")";
1925 logger.getOStream() <<
", but the type converter failed to legalize "
1926 "the original type.\n";
1931 logger.getOStream() <<
", but the legalized type(s) is/are (";
1932 llvm::interleaveComma(convertedTypes,
logger.getOStream(),
1933 [&](
Type t) { logger.getOStream() << t; });
1934 logger.getOStream() <<
")\n";
1940 if (!
config.allowPatternRollback) {
1949 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1955 if (
config.unlegalizedOps)
1956 config.unlegalizedOps->erase(op);
1964 assert(!
ignoredOps.contains(op) &&
"operation was already replaced");
1968 "attempting to replace a value that was already replaced");
1973 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1978 "attempting to replace/erase an unresolved materialization");
1994 logger.startLine() <<
"** Replace Value : '" << from <<
"'";
1995 if (
auto blockArg = dyn_cast<BlockArgument>(from)) {
1997 logger.getOStream() <<
" (in region of '" << parentOp->getName()
1998 <<
"' (" << parentOp <<
")";
2000 logger.getOStream() <<
" (unlinked block)";
2004 logger.getOStream() <<
", conditional replacement";
2008 if (!
config.allowPatternRollback) {
2013 Value repl = repls.front();
2030 "attempting to replace a value that was already replaced");
2032 "attempting to replace a op result that was already replaced");
2037 llvm::reportFatalInternalError(
2038 "conditional value replacement is not supported in rollback mode");
2044 if (!
config.allowPatternRollback) {
2051 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
2057 if (
config.unlegalizedOps)
2058 config.unlegalizedOps->erase(op);
2067 "attempting to erase a block within a replaced/erased op");
2083 bool wasDetached = !previous;
2089 logger.startLine() <<
"** Insert Block into : '" << parent->
getName()
2090 <<
"' (" << parent <<
")";
2093 <<
"** Insert Block into detached Region (nullptr parent op)";
2096 logger.getOStream() <<
" (was detached)";
2097 logger.getOStream() <<
"\n";
2103 "attempting to insert into a region within a replaced/erased op");
2108 config.listener->notifyBlockInserted(block, previous, previousIt);
2112 if (
config.allowPatternRollback) {
2126 if (
config.allowPatternRollback)
2140 reasonCallback(
diag);
2141 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
2142 if (
config.notifyCallback)
2151ConversionPatternRewriter::ConversionPatternRewriter(
2155 *this, config, opConverter)) {
2156 setListener(
impl.get());
2159ConversionPatternRewriter::~ConversionPatternRewriter() =
default;
2161const ConversionConfig &ConversionPatternRewriter::getConfig()
const {
2162 return impl->config;
2166 assert(op && newOp &&
"expected non-null op");
2170void ConversionPatternRewriter::replaceOp(Operation *op,
ValueRange newValues) {
2172 "incorrect # of replacement values");
2176 if (getInsertionPoint() == op->getIterator())
2179 SmallVector<SmallVector<Value>> newVals =
2180 llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2181 return v ? SmallVector<Value>{v} : SmallVector<Value>();
2183 impl->replaceOp(op, std::move(newVals));
2186void ConversionPatternRewriter::replaceOpWithMultiple(
2187 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2189 "incorrect # of replacement values");
2193 if (getInsertionPoint() == op->getIterator())
2196 impl->replaceOp(op, std::move(newValues));
2199void ConversionPatternRewriter::eraseOp(Operation *op) {
2201 impl->logger.startLine()
2202 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
2207 if (getInsertionPoint() == op->getIterator())
2210 SmallVector<SmallVector<Value>> nullRepls(op->
getNumResults(), {});
2211 impl->replaceOp(op, std::move(nullRepls));
2214void ConversionPatternRewriter::eraseBlock(
Block *block) {
2215 impl->eraseBlock(block);
2218Block *ConversionPatternRewriter::applySignatureConversion(
2219 Block *block, TypeConverter::SignatureConversion &conversion,
2220 const TypeConverter *converter) {
2221 assert(!impl->wasOpReplaced(block->
getParentOp()) &&
2222 "attempting to apply a signature conversion to a block within a "
2223 "replaced/erased op");
2224 return impl->applySignatureConversion(block, converter, conversion);
2227FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
2228 Region *region,
const TypeConverter &converter,
2229 TypeConverter::SignatureConversion *entryConversion) {
2230 assert(!impl->wasOpReplaced(region->
getParentOp()) &&
2231 "attempting to apply a signature conversion to a block within a "
2232 "replaced/erased op");
2233 return impl->convertRegionTypes(region, converter, entryConversion);
2236void ConversionPatternRewriter::replaceAllUsesWith(Value from,
ValueRange to) {
2237 impl->replaceValueUses(from, to, impl->currentTypeConverter);
2240void ConversionPatternRewriter::replaceUsesWithIf(
2242 bool *allUsesReplaced) {
2243 assert(!allUsesReplaced &&
2244 "allUsesReplaced is not supported in a dialect conversion");
2245 impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
2248Value ConversionPatternRewriter::getRemappedValue(Value key) {
2249 SmallVector<ValueVector> remappedValues;
2250 if (
failed(impl->remapValues(
"value", std::nullopt, key,
2253 assert(remappedValues.front().size() == 1 &&
"1:N conversion not supported");
2254 return remappedValues.front().front();
2258ConversionPatternRewriter::getRemappedValues(
ValueRange keys,
2259 SmallVectorImpl<Value> &results) {
2262 SmallVector<ValueVector> remapped;
2263 if (
failed(impl->remapValues(
"value", std::nullopt, keys,
2266 for (
const auto &values : remapped) {
2267 assert(values.size() == 1 &&
"1:N conversion not supported");
2268 results.push_back(values.front());
2273void ConversionPatternRewriter::inlineBlockBefore(
Block *source,
Block *dest,
2278 "incorrect # of argument replacement values");
2279 assert(!impl->wasOpReplaced(source->
getParentOp()) &&
2280 "attempting to inline a block from a replaced/erased op");
2281 assert(!impl->wasOpReplaced(dest->
getParentOp()) &&
2282 "attempting to inline a block into a replaced/erased op");
2283 auto opIgnored = [&](Operation *op) {
return impl->isOpIgnored(op); };
2286 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
2287 "expected 'source' to have no predecessors");
2296 bool fastPath = !getConfig().listener;
2298 if (fastPath && impl->config.allowPatternRollback)
2299 impl->inlineBlockBefore(source, dest, before);
2302 for (
auto it : llvm::zip(source->
getArguments(), argValues))
2303 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2310 while (!source->
empty())
2311 moveOpBefore(&source->
front(), dest, before);
2316 if (getInsertionBlock() == source)
2317 setInsertionPoint(dest, getInsertionPoint());
2323void ConversionPatternRewriter::startOpModification(Operation *op) {
2324 if (!impl->config.allowPatternRollback) {
2329 assert(!impl->wasOpReplaced(op) &&
2330 "attempting to modify a replaced/erased op");
2332 impl->pendingRootUpdates.insert(op);
2334 impl->appendRewrite<ModifyOperationRewrite>(op);
2337void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
2338 impl->patternModifiedOps.insert(op);
2339 if (!impl->config.allowPatternRollback) {
2341 if (getConfig().listener)
2342 getConfig().listener->notifyOperationModified(op);
2349 assert(!impl->wasOpReplaced(op) &&
2350 "attempting to modify a replaced/erased op");
2351 assert(impl->pendingRootUpdates.erase(op) &&
2352 "operation did not have a pending in-place update");
2356void ConversionPatternRewriter::cancelOpModification(Operation *op) {
2357 if (!impl->config.allowPatternRollback) {
2362 assert(impl->pendingRootUpdates.erase(op) &&
2363 "operation did not have a pending in-place update");
2366 auto it = llvm::find_if(
2367 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &
rewrite) {
2368 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2369 return modifyRewrite && modifyRewrite->getOperation() == op;
2371 assert(it != impl->rewrites.rend() &&
"no root update started on op");
2373 int updateIdx = std::prev(impl->rewrites.rend()) - it;
2374 impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2377detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
2385FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
2386 ArrayRef<ValueRange> operands)
const {
2387 SmallVector<Value> oneToOneOperands;
2388 oneToOneOperands.reserve(operands.size());
2390 if (operand.size() != 1)
2393 oneToOneOperands.push_back(operand.front());
2395 return std::move(oneToOneOperands);
2399ConversionPattern::matchAndRewrite(Operation *op,
2400 PatternRewriter &rewriter)
const {
2401 auto &dialectRewriter =
static_cast<ConversionPatternRewriter &
>(rewriter);
2402 auto &rewriterImpl = dialectRewriter.getImpl();
2406 getTypeConverter());
2409 SmallVector<ValueVector> remapped;
2414 SmallVector<ValueRange> remappedAsRange =
2415 llvm::to_vector_of<ValueRange>(remapped);
2416 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2425using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2428class OperationLegalizer {
2430 using LegalizationAction = ConversionTarget::LegalizationAction;
2432 OperationLegalizer(ConversionPatternRewriter &rewriter,
2433 const ConversionTarget &targetInfo,
2434 const FrozenRewritePatternSet &patterns);
2437 bool isIllegal(Operation *op)
const;
2441 LogicalResult legalize(Operation *op);
2444 const ConversionTarget &getTarget() {
return target; }
2448 LogicalResult legalizeWithFold(Operation *op);
2452 LogicalResult legalizeWithPattern(Operation *op);
2456 bool canApplyPattern(Operation *op,
const Pattern &pattern);
2460 legalizePatternResult(Operation *op,
const Pattern &pattern,
2461 const RewriterState &curState,
2478 void buildLegalizationGraph(
2479 LegalizationPatterns &anyOpLegalizerPatterns,
2490 void computeLegalizationGraphBenefit(
2491 LegalizationPatterns &anyOpLegalizerPatterns,
2496 unsigned computeOpLegalizationDepth(
2503 unsigned applyCostModelToPatterns(
2504 LegalizationPatterns &patterns,
2509 SmallPtrSet<const Pattern *, 8> appliedPatterns;
2512 ConversionPatternRewriter &rewriter;
2515 const ConversionTarget &
target;
2518 PatternApplicator applicator;
2522OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2523 const ConversionTarget &targetInfo,
2524 const FrozenRewritePatternSet &patterns)
2525 : rewriter(rewriter),
target(targetInfo), applicator(patterns) {
2529 LegalizationPatterns anyOpLegalizerPatterns;
2531 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2532 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2535bool OperationLegalizer::isIllegal(Operation *op)
const {
2536 return target.isIllegal(op);
2539LogicalResult OperationLegalizer::legalize(Operation *op) {
2541 const char *logLineComment =
2542 "//===-------------------------------------------===//\n";
2544 auto &logger = rewriter.getImpl().logger;
2548 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2551 logger.getOStream() <<
"\n";
2552 logger.startLine() << logLineComment;
2553 logger.startLine() <<
"Legalizing operation : ";
2558 logger.getOStream() <<
"'" << op->
getName() <<
"' ";
2559 logger.getOStream() <<
"(" << op <<
") {\n";
2564 logger.startLine() << OpWithFlags(op,
2565 OpPrintingFlags().printGenericOpForm())
2572 logSuccess(logger,
"operation marked 'ignored' during conversion");
2573 logger.startLine() << logLineComment;
2579 if (
auto legalityInfo =
target.isLegal(op)) {
2582 logger,
"operation marked legal by the target{0}",
2583 legalityInfo->isRecursivelyLegal
2584 ?
"; NOTE: operation is recursively legal; skipping internals"
2586 logger.startLine() << logLineComment;
2591 if (legalityInfo->isRecursivelyLegal) {
2592 op->
walk([&](Operation *nested) {
2594 rewriter.getImpl().ignoredOps.
insert(nested);
2603 const ConversionConfig &config = rewriter.getConfig();
2604 if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2605 if (succeeded(legalizeWithFold(op))) {
2608 logger.startLine() << logLineComment;
2615 if (succeeded(legalizeWithPattern(op))) {
2618 logger.startLine() << logLineComment;
2625 if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2626 if (succeeded(legalizeWithFold(op))) {
2629 logger.startLine() << logLineComment;
2636 logFailure(logger,
"no matched legalization pattern");
2637 logger.startLine() << logLineComment;
2644template <
typename T>
2646 T
result = std::move(obj);
2651LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2652 auto &rewriterImpl = rewriter.getImpl();
2654 rewriterImpl.
logger.startLine() <<
"* Fold {\n";
2655 rewriterImpl.
logger.indent();
2660 llvm::scope_exit cleanup([&]() {
2670 SmallVector<Value, 2> replacementValues;
2671 SmallVector<Operation *, 2> newOps;
2674 if (
failed(rewriter.
tryFold(op, replacementValues, &newOps))) {
2683 if (replacementValues.empty())
2684 return legalize(op);
2687 rewriter.
replaceOp(op, replacementValues);
2690 for (Operation *newOp : newOps) {
2691 if (
failed(legalize(newOp))) {
2693 "failed to legalize generated constant '{0}'",
2695 if (!rewriter.getConfig().allowPatternRollback) {
2697 llvm::reportFatalInternalError(
2699 "' folder rollback of IR modifications requested");
2717 auto newOpNames = llvm::map_range(
2719 auto modifiedOpNames = llvm::map_range(
2721 llvm::reportFatalInternalError(
"pattern '" + pattern.
getDebugName() +
2722 "' produced IR that could not be legalized. " +
2723 "new ops: {" + llvm::join(newOpNames,
", ") +
2724 "}, " +
"modified ops: {" +
2725 llvm::join(modifiedOpNames,
", ") +
"}");
2728LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2729 auto &rewriterImpl = rewriter.getImpl();
2730 const ConversionConfig &config = rewriter.getConfig();
2732#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2734 std::optional<OperationFingerPrint> topLevelFingerPrint;
2735 if (!rewriterImpl.
config.allowPatternRollback) {
2742 topLevelFingerPrint = OperationFingerPrint(checkOp);
2748 rewriterImpl.
logger.startLine()
2749 <<
"WARNING: Multi-threadeding is enabled. Some dialect "
2750 "conversion expensive checks are skipped in multithreading "
2759 auto canApply = [&](
const Pattern &pattern) {
2760 bool canApply = canApplyPattern(op, pattern);
2761 if (canApply && config.listener)
2762 config.listener->notifyPatternBegin(pattern, op);
2768 auto onFailure = [&](
const Pattern &pattern) {
2770 if (!rewriterImpl.
config.allowPatternRollback) {
2777#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2779 if (checkOp && topLevelFingerPrint) {
2780 OperationFingerPrint fingerPrintAfterPattern(checkOp);
2781 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2782 llvm::reportFatalInternalError(
2783 "pattern '" + pattern.getDebugName() +
2784 "' returned failure but IR did change");
2792 if (rewriterImpl.
config.notifyCallback) {
2794 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
2800 if (config.listener)
2801 config.listener->notifyPatternEnd(pattern, failure());
2802 rewriterImpl.
resetState(curState, pattern.getDebugName());
2803 appliedPatterns.erase(&pattern);
2808 auto onSuccess = [&](
const Pattern &pattern) {
2810 if (!rewriterImpl.
config.allowPatternRollback) {
2824 legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
2825 appliedPatterns.erase(&pattern);
2827 if (!rewriterImpl.
config.allowPatternRollback)
2829 rewriterImpl.
resetState(curState, pattern.getDebugName());
2831 if (config.listener)
2832 config.listener->notifyPatternEnd(pattern,
result);
2837 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2841bool OperationLegalizer::canApplyPattern(Operation *op,
2842 const Pattern &pattern) {
2844 auto &os = rewriter.getImpl().logger;
2845 os.getOStream() <<
"\n";
2846 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
2848 os.getOStream() <<
")' {\n";
2855 !appliedPatterns.insert(&pattern).second) {
2857 logFailure(rewriter.getImpl().logger,
"pattern was already applied"));
2863LogicalResult OperationLegalizer::legalizePatternResult(
2864 Operation *op,
const Pattern &pattern,
const RewriterState &curState,
2867 [[maybe_unused]]
auto &impl = rewriter.getImpl();
2868 assert(impl.pendingRootUpdates.empty() &&
"dangling root updates");
2870#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2871 if (impl.config.allowPatternRollback) {
2873 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2874 auto replacedRoot = [&] {
2875 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2877 auto updatedRootInPlace = [&] {
2878 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2880 if (!replacedRoot() && !updatedRootInPlace())
2881 llvm::reportFatalInternalError(
2882 "expected pattern to replace the root operation "
2883 "or modify it in place");
2888 if (
failed(legalizePatternRootUpdates(modifiedOps)) ||
2889 failed(legalizePatternCreatedOperations(newOps))) {
2893 LLVM_DEBUG(
logSuccess(impl.logger,
"pattern applied successfully"));
2897LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2899 for (Operation *op : newOps) {
2900 if (
failed(legalize(op))) {
2901 LLVM_DEBUG(
logFailure(rewriter.getImpl().logger,
2902 "failed to legalize generated operation '{0}'({1})",
2910LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2912 for (Operation *op : modifiedOps) {
2913 if (
failed(legalize(op))) {
2916 "failed to legalize operation updated in-place '{0}'",
2928void OperationLegalizer::buildLegalizationGraph(
2929 LegalizationPatterns &anyOpLegalizerPatterns,
2940 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2941 std::optional<OperationName> root = pattern.
getRootKind();
2947 anyOpLegalizerPatterns.push_back(&pattern);
2952 if (
target.getOpAction(*root) == LegalizationAction::Legal)
2957 invalidPatterns[*root].insert(&pattern);
2959 parentOps[op].insert(*root);
2962 patternWorklist.insert(&pattern);
2970 if (!anyOpLegalizerPatterns.empty()) {
2971 for (
const Pattern *pattern : patternWorklist)
2972 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2976 while (!patternWorklist.empty()) {
2977 auto *pattern = patternWorklist.pop_back_val();
2981 std::optional<LegalizationAction> action = target.getOpAction(op);
2982 return !legalizerPatterns.count(op) &&
2983 (!action || action == LegalizationAction::Illegal);
2989 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2990 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
2994 for (
auto op : parentOps[*pattern->
getRootKind()])
2995 patternWorklist.set_union(invalidPatterns[op]);
2999void OperationLegalizer::computeLegalizationGraphBenefit(
3000 LegalizationPatterns &anyOpLegalizerPatterns,
3006 for (
auto &opIt : legalizerPatterns)
3007 if (!minOpPatternDepth.count(opIt.first))
3008 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
3014 if (!anyOpLegalizerPatterns.empty())
3015 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
3021 applicator.applyCostModel([&](
const Pattern &pattern) {
3022 ArrayRef<const Pattern *> orderedPatternList;
3023 if (std::optional<OperationName> rootName = pattern.
getRootKind())
3024 orderedPatternList = legalizerPatterns[*rootName];
3026 orderedPatternList = anyOpLegalizerPatterns;
3029 auto *it = llvm::find(orderedPatternList, &pattern);
3030 if (it == orderedPatternList.end())
3034 return PatternBenefit(std::distance(it, orderedPatternList.end()));
3038unsigned OperationLegalizer::computeOpLegalizationDepth(
3042 auto depthIt = minOpPatternDepth.find(op);
3043 if (depthIt != minOpPatternDepth.end())
3044 return depthIt->second;
3048 auto opPatternsIt = legalizerPatterns.find(op);
3049 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3054 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
3058 unsigned minDepth = applyCostModelToPatterns(
3059 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3060 minOpPatternDepth[op] = minDepth;
3064unsigned OperationLegalizer::applyCostModelToPatterns(
3065 LegalizationPatterns &patterns,
3068 unsigned minDepth = std::numeric_limits<unsigned>::max();
3071 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3072 patternsByDepth.reserve(patterns.size());
3073 for (
const Pattern *pattern : patterns) {
3076 unsigned generatedOpDepth = computeOpLegalizationDepth(
3077 generatedOp, minOpPatternDepth, legalizerPatterns);
3078 depth = std::max(depth, generatedOpDepth + 1);
3080 patternsByDepth.emplace_back(pattern, depth);
3083 minDepth = std::min(minDepth, depth);
3088 if (patternsByDepth.size() == 1)
3092 llvm::stable_sort(patternsByDepth,
3093 [](
const std::pair<const Pattern *, unsigned> &
lhs,
3094 const std::pair<const Pattern *, unsigned> &
rhs) {
3097 if (
lhs.second !=
rhs.second)
3098 return lhs.second <
rhs.second;
3101 auto lhsBenefit =
lhs.first->getBenefit();
3102 auto rhsBenefit =
rhs.first->getBenefit();
3103 return lhsBenefit > rhsBenefit;
3108 for (
auto &patternIt : patternsByDepth)
3109 patterns.push_back(patternIt.first);
3123template <
typename RangeT>
3126 function_ref<
bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3135 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3136 if (castOp.getInputs().empty())
3139 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3142 if (inputCastOp.getOutputs() != castOp.getInputs())
3148 while (!worklist.empty()) {
3149 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3153 UnrealizedConversionCastOp nextCast = castOp;
3155 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3156 if (llvm::any_of(nextCast.getInputs(), [&](
Value v) {
3157 return v.getDefiningOp() == castOp;
3165 castOp.replaceAllUsesWith(nextCast.getInputs());
3168 nextCast = getInputCast(nextCast);
3178 auto markOpLive = [&](
Operation *rootOp) {
3180 worklist.push_back(rootOp);
3181 while (!worklist.empty()) {
3182 Operation *op = worklist.pop_back_val();
3183 if (liveOps.insert(op).second) {
3186 if (
auto castOp = v.
getDefiningOp<UnrealizedConversionCastOp>())
3187 if (isCastOpOfInterestFn(castOp))
3188 worklist.push_back(castOp);
3194 for (UnrealizedConversionCastOp op : castOps) {
3197 if (liveOps.contains(op.getOperation()))
3201 if (llvm::any_of(op->getUsers(), [&](
Operation *user) {
3202 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3203 return !castOp || !isCastOpOfInterestFn(castOp);
3209 for (UnrealizedConversionCastOp op : castOps) {
3210 if (liveOps.contains(op)) {
3212 if (remainingCastOps)
3213 remainingCastOps->push_back(op);
3224 ArrayRef<UnrealizedConversionCastOp> castOps,
3225 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3227 DenseSet<UnrealizedConversionCastOp> castOpSet;
3228 for (UnrealizedConversionCastOp op : castOps)
3229 castOpSet.insert(op);
3234 const DenseSet<UnrealizedConversionCastOp> &castOps,
3235 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3237 llvm::make_range(castOps.begin(), castOps.end()),
3238 [&](UnrealizedConversionCastOp castOp) {
3239 return castOps.contains(castOp);
3251 [&](UnrealizedConversionCastOp castOp) {
3252 return castOps.contains(castOp);
3269 const ConversionConfig &config,
3270 OpConversionMode mode)
3271 : rewriter(ctx, config, *this), opLegalizer(rewriter,
target, patterns),
3280 template <
typename Fn>
3282 bool isRecursiveLegalization =
false);
3284 bool isRecursiveLegalization =
false) {
3286 ops, [&]() {}, isRecursiveLegalization);
3294 LogicalResult convert(
Operation *op,
bool isRecursiveLegalization =
false);
3300 ConversionPatternRewriter rewriter;
3303 OperationLegalizer opLegalizer;
3306 OpConversionMode mode;
3311 bool isRecursiveLegalization) {
3312 const ConversionConfig &config = rewriter.getConfig();
3313 auto emitFailedToLegalizeDiag = [&](
bool wasExplicitlyIllegal) {
3315 <<
"failed to legalize operation '"
3317 if (wasExplicitlyIllegal)
3318 diag <<
" that was explicitly marked illegal";
3323 if (failed(opLegalizer.legalize(op))) {
3326 if (mode == OpConversionMode::Full) {
3327 if (!isRecursiveLegalization)
3328 emitFailedToLegalizeDiag(
false);
3334 if (mode == OpConversionMode::Partial) {
3335 if (opLegalizer.isIllegal(op)) {
3336 if (!isRecursiveLegalization)
3337 emitFailedToLegalizeDiag(
true);
3340 if (config.unlegalizedOps && !isRecursiveLegalization)
3341 config.unlegalizedOps->insert(op);
3343 }
else if (mode == OpConversionMode::Analysis) {
3347 if (config.legalizableOps && !isRecursiveLegalization)
3348 config.legalizableOps->insert(op);
3355 UnrealizedConversionCastOp op,
3356 const UnresolvedMaterializationInfo &info) {
3357 assert(!op.use_empty() &&
3358 "expected that dead materializations have already been DCE'd");
3365 switch (info.getMaterializationKind()) {
3366 case MaterializationKind::Target:
3367 newMaterialization = converter->materializeTargetConversion(
3368 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3369 info.getOriginalType());
3371 case MaterializationKind::Source:
3372 assert(op->getNumResults() == 1 &&
"expected single result");
3373 Value sourceMat = converter->materializeSourceConversion(
3374 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3376 newMaterialization.push_back(sourceMat);
3379 if (!newMaterialization.empty()) {
3381 ValueRange newMaterializationRange(newMaterialization);
3382 assert(
TypeRange(newMaterializationRange) == op.getResultTypes() &&
3383 "materialization callback produced value of incorrect type");
3385 rewriter.
replaceOp(op, newMaterialization);
3391 <<
"failed to legalize unresolved materialization "
3393 << inputOperands.
getTypes() <<
") to ("
3394 << op.getResultTypes()
3395 <<
") that remained live after conversion";
3396 diag.attachNote(op->getUsers().begin()->getLoc())
3397 <<
"see existing live user here: " << *op->getUsers().begin();
3401template <
typename Fn>
3404 bool isRecursiveLegalization) {
3412 toConvert.push_back(op);
3415 auto legalityInfo =
target.isLegal(op);
3416 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3422 if (failed(
convert(op, isRecursiveLegalization))) {
3431LogicalResult ConversionPatternRewriter::legalize(
Operation *op) {
3432 return impl->opConverter.legalizeOperations(op,
3436LogicalResult ConversionPatternRewriter::legalize(
Region *r) {
3452 std::optional<TypeConverter::SignatureConversion> conversion =
3453 converter->convertBlockSignature(&r->front());
3456 applySignatureConversion(&r->front(), *conversion, converter);
3461 return impl->opConverter.legalizeOperations(ops,
3471 if (rewriterImpl.
config.allowPatternRollback) {
3495 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3499 if (rewriter.getConfig().buildMaterializations) {
3503 rewriter.getConfig().listener);
3504 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3505 auto it = materializations.find(castOp);
3506 assert(it != materializations.end() &&
"inconsistent state");
3520void TypeConverter::SignatureConversion::addInputs(
unsigned origInputNo,
3522 assert(!types.empty() &&
"expected valid types");
3523 remapInput(origInputNo, argTypes.size(), types.size());
3527void TypeConverter::SignatureConversion::addInputs(
ArrayRef<Type> types) {
3528 assert(!types.empty() &&
3529 "1->0 type remappings don't need to be added explicitly");
3530 argTypes.append(types.begin(), types.end());
3533void TypeConverter::SignatureConversion::remapInput(
unsigned origInputNo,
3534 unsigned newInputNo,
3535 unsigned newInputCount) {
3536 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3537 assert(newInputCount != 0 &&
"expected valid input count");
3538 remappedInputs[origInputNo] =
3539 InputMapping{newInputNo, newInputCount, {}};
3542void TypeConverter::SignatureConversion::remapInput(
3543 unsigned origInputNo, ArrayRef<Value> replacements) {
3544 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3545 remappedInputs[origInputNo] = InputMapping{
3547 SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3558TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3559 SmallVectorImpl<Type> &results)
const {
3560 assert(typeOrValue &&
"expected non-null type");
3561 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3562 : cast<Type>(typeOrValue);
3564 std::shared_lock<
decltype(cacheMutex)> cacheReadLock(cacheMutex,
3567 cacheReadLock.lock();
3568 auto existingIt = cachedDirectConversions.find(t);
3569 if (existingIt != cachedDirectConversions.end()) {
3570 if (existingIt->second)
3571 results.push_back(existingIt->second);
3572 return success(existingIt->second !=
nullptr);
3574 auto multiIt = cachedMultiConversions.find(t);
3575 if (multiIt != cachedMultiConversions.end()) {
3576 results.append(multiIt->second.begin(), multiIt->second.end());
3582 size_t currentCount = results.size();
3586 auto isCacheable = [&](
int index) {
3587 int numberOfConversionsUntilContextAware =
3588 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3589 return index < numberOfConversionsUntilContextAware;
3592 std::unique_lock<
decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3595 for (
auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3596 const ConversionCallbackFn &converter = indexedConverter.value();
3597 std::optional<LogicalResult>
result = converter(typeOrValue, results);
3599 assert(results.size() == currentCount &&
3600 "failed type conversion should not change results");
3603 if (!isCacheable(indexedConverter.index()))
3606 cacheWriteLock.lock();
3607 if (!succeeded(*
result)) {
3608 assert(results.size() == currentCount &&
3609 "failed type conversion should not change results");
3610 cachedDirectConversions.try_emplace(t,
nullptr);
3613 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3614 if (newTypes.size() == 1)
3615 cachedDirectConversions.try_emplace(t, newTypes.front());
3617 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3623LogicalResult TypeConverter::convertType(Type t,
3624 SmallVectorImpl<Type> &results)
const {
3625 return convertTypeImpl(t, results);
3628LogicalResult TypeConverter::convertType(Value v,
3629 SmallVectorImpl<Type> &results)
const {
3630 return convertTypeImpl(v, results);
3633Type TypeConverter::convertType(Type t)
const {
3635 SmallVector<Type, 1> results;
3636 if (
failed(convertType(t, results)))
3640 return results.size() == 1 ? results.front() :
nullptr;
3643Type TypeConverter::convertType(Value v)
const {
3645 SmallVector<Type, 1> results;
3646 if (
failed(convertType(v, results)))
3650 return results.size() == 1 ? results.front() :
nullptr;
3654TypeConverter::convertTypes(
TypeRange types,
3655 SmallVectorImpl<Type> &results)
const {
3656 for (Type type : types)
3657 if (
failed(convertType(type, results)))
3663TypeConverter::convertTypes(
ValueRange values,
3664 SmallVectorImpl<Type> &results)
const {
3665 for (Value value : values)
3666 if (
failed(convertType(value, results)))
3671bool TypeConverter::isLegal(Type type)
const {
3672 return convertType(type) == type;
3675bool TypeConverter::isLegal(Value value)
const {
3676 return convertType(value) == value.
getType();
3679bool TypeConverter::isLegal(Operation *op)
const {
3683bool TypeConverter::isLegal(Region *region)
const {
3684 return llvm::all_of(
3688bool TypeConverter::isSignatureLegal(FunctionType ty)
const {
3689 if (!isLegal(ty.getInputs()))
3691 if (!isLegal(ty.getResults()))
3697TypeConverter::convertSignatureArg(
unsigned inputNo, Type type,
3698 SignatureConversion &
result)
const {
3700 SmallVector<Type, 1> convertedTypes;
3701 if (
failed(convertType(type, convertedTypes)))
3705 if (convertedTypes.empty())
3709 result.addInputs(inputNo, convertedTypes);
3713TypeConverter::convertSignatureArgs(
TypeRange types,
3714 SignatureConversion &
result,
3715 unsigned origInputOffset)
const {
3716 for (
unsigned i = 0, e = types.size(); i != e; ++i)
3717 if (
failed(convertSignatureArg(origInputOffset + i, types[i],
result)))
3722TypeConverter::convertSignatureArg(
unsigned inputNo, Value value,
3723 SignatureConversion &
result)
const {
3725 SmallVector<Type, 1> convertedTypes;
3726 if (
failed(convertType(value, convertedTypes)))
3730 if (convertedTypes.empty())
3734 result.addInputs(inputNo, convertedTypes);
3738TypeConverter::convertSignatureArgs(
ValueRange values,
3739 SignatureConversion &
result,
3740 unsigned origInputOffset)
const {
3741 for (
unsigned i = 0, e = values.size(); i != e; ++i)
3742 if (
failed(convertSignatureArg(origInputOffset + i, values[i],
result)))
3747Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
3748 Location loc, Type resultType,
3750 for (
const SourceMaterializationCallbackFn &fn :
3751 llvm::reverse(sourceMaterializations))
3752 if (Value
result = fn(builder, resultType, inputs, loc))
3757Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
3758 Location loc, Type resultType,
3760 Type originalType)
const {
3761 SmallVector<Value>
result = materializeTargetConversion(
3762 builder, loc,
TypeRange(resultType), inputs, originalType);
3765 assert(
result.size() == 1 &&
"expected single result");
3769SmallVector<Value> TypeConverter::materializeTargetConversion(
3771 Type originalType)
const {
3772 for (
const TargetMaterializationCallbackFn &fn :
3773 llvm::reverse(targetMaterializations)) {
3774 SmallVector<Value>
result =
3775 fn(builder, resultTypes, inputs, loc, originalType);
3779 "callback produced incorrect number of values or values with "
3786std::optional<TypeConverter::SignatureConversion>
3787TypeConverter::convertBlockSignature(
Block *block)
const {
3790 return std::nullopt;
3797TypeConverter::AttributeConversionResult
3798TypeConverter::AttributeConversionResult::result(Attribute attr) {
3799 return AttributeConversionResult(attr, resultTag);
3802TypeConverter::AttributeConversionResult
3803TypeConverter::AttributeConversionResult::na() {
3804 return AttributeConversionResult(
nullptr, naTag);
3807TypeConverter::AttributeConversionResult
3808TypeConverter::AttributeConversionResult::abort() {
3809 return AttributeConversionResult(
nullptr, abortTag);
3812bool TypeConverter::AttributeConversionResult::hasResult()
const {
3813 return impl.getInt() == resultTag;
3816bool TypeConverter::AttributeConversionResult::isNa()
const {
3817 return impl.getInt() == naTag;
3820bool TypeConverter::AttributeConversionResult::isAbort()
const {
3821 return impl.getInt() == abortTag;
3824Attribute TypeConverter::AttributeConversionResult::getResult()
const {
3825 assert(hasResult() &&
"Cannot get result from N/A or abort");
3826 return impl.getPointer();
3829std::optional<Attribute>
3830TypeConverter::convertTypeAttribute(Type type, Attribute attr)
const {
3831 for (
const TypeAttributeConversionCallbackFn &fn :
3832 llvm::reverse(typeAttributeConversions)) {
3833 AttributeConversionResult res = fn(type, attr);
3834 if (res.hasResult())
3835 return res.getResult();
3837 return std::nullopt;
3839 return std::nullopt;
3848 ConversionPatternRewriter &rewriter) {
3849 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3854 TypeConverter::SignatureConversion funcConversion(type.getNumInputs());
3856 if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
3858 failed(typeConverter.convertTypes(type.getResults(), newResults)))
3867 if (!funcOp.getFunctionBody().empty()) {
3868 Block *entryBlock = &funcOp.getFunctionBody().
front();
3870 unsigned numFuncTypeInputs = type.getNumInputs();
3871 TypeConverter::SignatureConversion blockConversion(numEntryBlockArgs);
3873 if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
3878 for (
unsigned i = numFuncTypeInputs; i < numEntryBlockArgs; ++i)
3880 rewriter.applySignatureConversion(entryBlock, blockConversion,
3884 auto newType = FunctionType::get(
3885 rewriter.getContext(), funcConversion.getConvertedTypes(), newResults);
3887 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3896struct FunctionOpInterfaceSignatureConversion :
public ConversionPattern {
3897 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3899 const TypeConverter &converter,
3900 PatternBenefit benefit)
3901 : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3904 matchAndRewrite(Operation *op, ArrayRef<Value> ,
3905 ConversionPatternRewriter &rewriter)
const override {
3906 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3911struct AnyFunctionOpInterfaceSignatureConversion
3912 :
public OpInterfaceConversionPattern<FunctionOpInterface> {
3913 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3916 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> ,
3917 ConversionPatternRewriter &rewriter)
const override {
3923FailureOr<Operation *>
3924mlir::convertOpResultTypes(Operation *op,
ValueRange operands,
3925 const TypeConverter &converter,
3926 ConversionPatternRewriter &rewriter) {
3927 assert(op &&
"Invalid op");
3928 Location loc = op->
getLoc();
3929 if (converter.isLegal(op))
3930 return rewriter.notifyMatchFailure(loc,
"op already legal");
3932 OperationState newOp(loc, op->
getName());
3933 newOp.addOperands(operands);
3935 SmallVector<Type> newResultTypes;
3937 return rewriter.notifyMatchFailure(loc,
"couldn't convert return types");
3939 newOp.addTypes(newResultTypes);
3940 newOp.addAttributes(op->
getAttrs());
3941 return rewriter.create(newOp);
3944void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3945 StringRef functionLikeOpName, RewritePatternSet &patterns,
3946 const TypeConverter &converter, PatternBenefit benefit) {
3947 patterns.
add<FunctionOpInterfaceSignatureConversion>(
3948 functionLikeOpName, patterns.
getContext(), converter, benefit);
3951void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3952 RewritePatternSet &patterns,
const TypeConverter &converter,
3953 PatternBenefit benefit) {
3954 patterns.
add<AnyFunctionOpInterfaceSignatureConversion>(
3962void ConversionTarget::setOpAction(OperationName op,
3963 LegalizationAction action) {
3964 legalOperations[op].action = action;
3967void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3968 LegalizationAction action) {
3969 for (StringRef dialect : dialectNames)
3970 legalDialects[dialect] = action;
3973auto ConversionTarget::getOpAction(OperationName op)
const
3974 -> std::optional<LegalizationAction> {
3975 std::optional<LegalizationInfo> info = getOpInfo(op);
3976 return info ? info->action : std::optional<LegalizationAction>();
3979auto ConversionTarget::isLegal(Operation *op)
const
3980 -> std::optional<LegalOpDetails> {
3981 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3983 return std::nullopt;
3986 auto isOpLegal = [&] {
3988 if (info->action == LegalizationAction::Dynamic) {
3989 std::optional<bool>
result = info->legalityFn(op);
3995 return info->action == LegalizationAction::Legal;
3998 return std::nullopt;
4001 LegalOpDetails legalityDetails;
4002 if (info->isRecursivelyLegal) {
4003 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
4004 if (legalityFnIt != opRecursiveLegalityFns.end()) {
4005 legalityDetails.isRecursivelyLegal =
4006 legalityFnIt->second(op).value_or(
true);
4008 legalityDetails.isRecursivelyLegal =
true;
4011 return legalityDetails;
4014bool ConversionTarget::isIllegal(Operation *op)
const {
4015 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
4019 if (info->action == LegalizationAction::Dynamic) {
4020 std::optional<bool>
result = info->legalityFn(op);
4027 return info->action == LegalizationAction::Illegal;
4031 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
4032 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
4036 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
4038 if (std::optional<bool>
result = newCl(op))
4046void ConversionTarget::setLegalityCallback(
4047 OperationName name,
const DynamicLegalityCallbackFn &callback) {
4048 assert(callback &&
"expected valid legality callback");
4049 auto *infoIt = legalOperations.find(name);
4050 assert(infoIt != legalOperations.end() &&
4051 infoIt->second.action == LegalizationAction::Dynamic &&
4052 "expected operation to already be marked as dynamically legal");
4053 infoIt->second.legalityFn =
4057void ConversionTarget::markOpRecursivelyLegal(
4058 OperationName name,
const DynamicLegalityCallbackFn &callback) {
4059 auto *infoIt = legalOperations.find(name);
4060 assert(infoIt != legalOperations.end() &&
4061 infoIt->second.action != LegalizationAction::Illegal &&
4062 "expected operation to already be marked as legal");
4063 infoIt->second.isRecursivelyLegal =
true;
4066 std::move(opRecursiveLegalityFns[name]), callback);
4068 opRecursiveLegalityFns.erase(name);
4071void ConversionTarget::setLegalityCallback(
4072 ArrayRef<StringRef> dialects,
const DynamicLegalityCallbackFn &callback) {
4073 assert(callback &&
"expected valid legality callback");
4074 for (StringRef dialect : dialects)
4076 std::move(dialectLegalityFns[dialect]), callback);
4079void ConversionTarget::setLegalityCallback(
4080 const DynamicLegalityCallbackFn &callback) {
4081 assert(callback &&
"expected valid legality callback");
4085auto ConversionTarget::getOpInfo(OperationName op)
const
4086 -> std::optional<LegalizationInfo> {
4088 const auto *it = legalOperations.find(op);
4089 if (it != legalOperations.end())
4092 auto dialectIt = legalDialects.find(op.getDialectNamespace());
4093 if (dialectIt != legalDialects.end()) {
4094 DynamicLegalityCallbackFn callback;
4095 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4096 if (dialectFn != dialectLegalityFns.end())
4097 callback = dialectFn->second;
4098 return LegalizationInfo{dialectIt->second,
false,
4102 if (unknownLegalityFn)
4103 return LegalizationInfo{LegalizationAction::Dynamic,
4104 false, unknownLegalityFn};
4105 return std::nullopt;
4108#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4113void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
4114 auto &rewriterImpl =
4115 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4119void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
4120 auto &rewriterImpl =
4121 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4127static FailureOr<SmallVector<Value>>
4128pdllConvertValues(ConversionPatternRewriter &rewriter,
ValueRange values) {
4129 SmallVector<Value> mappedValues;
4130 if (
failed(rewriter.getRemappedValues(values, mappedValues)))
4132 return std::move(mappedValues);
4135void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
4138 [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
4139 auto results = pdllConvertValues(
4140 static_cast<ConversionPatternRewriter &
>(rewriter), value);
4143 return results->front();
4146 "convertValues", [](PatternRewriter &rewriter,
ValueRange values) {
4147 return pdllConvertValues(
4148 static_cast<ConversionPatternRewriter &
>(rewriter), values);
4152 [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
4153 auto &rewriterImpl =
4154 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4155 if (
const TypeConverter *converter =
4157 if (Type newType = converter->convertType(type))
4165 [](PatternRewriter &rewriter,
4166 TypeRange types) -> FailureOr<SmallVector<Type>> {
4167 auto &rewriterImpl =
4168 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4171 return SmallVector<Type>(types);
4173 SmallVector<Type> remappedTypes;
4174 if (
failed(converter->convertTypes(types, remappedTypes)))
4176 return std::move(remappedTypes);
4191 static constexpr StringLiteral
tag =
"apply-conversion";
4192 static constexpr StringLiteral
desc =
4193 "Encapsulate the application of a dialect conversion";
4201 ConversionConfig config,
4202 OpConversionMode mode) {
4206 LogicalResult status =
success();
4211 patterns, config, mode);
4222LogicalResult mlir::applyPartialConversion(
4223 ArrayRef<Operation *> ops,
const ConversionTarget &
target,
4224 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4226 OpConversionMode::Partial);
4229mlir::applyPartialConversion(Operation *op,
const ConversionTarget &
target,
4230 const FrozenRewritePatternSet &patterns,
4231 ConversionConfig config) {
4232 return applyPartialConversion(llvm::ArrayRef(op),
target, patterns, config);
4239LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
4240 const ConversionTarget &
target,
4241 const FrozenRewritePatternSet &patterns,
4242 ConversionConfig config) {
4245LogicalResult mlir::applyFullConversion(Operation *op,
4246 const ConversionTarget &
target,
4247 const FrozenRewritePatternSet &patterns,
4248 ConversionConfig config) {
4249 return applyFullConversion(llvm::ArrayRef(op),
target, patterns, config);
4266 "expected top-level op to be isolated from above");
4269 "expected ops to have a common ancestor");
4278 for (
Operation *op : ops.drop_front()) {
4282 assert(commonAncestor &&
4283 "expected to find a common isolated from above ancestor");
4287 return commonAncestor;
4290LogicalResult mlir::applyAnalysisConversion(
4291 ArrayRef<Operation *> ops, ConversionTarget &
target,
4292 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4294 if (config.legalizableOps)
4295 assert(config.legalizableOps->empty() &&
"expected empty set");
4301 Operation *clonedAncestor = commonAncestor->
clone(mapping);
4305 inverseOperationMap[it.second] = it.first;
4308 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4309 ops, [&](Operation *op) {
return mapping.
lookup(op); });
4311 OpConversionMode::Analysis);
4315 if (config.legalizableOps) {
4317 for (Operation *op : *config.legalizableOps)
4318 originalLegalizableOps.insert(inverseOperationMap[op]);
4319 *config.legalizableOps = std::move(originalLegalizableOps);
4323 clonedAncestor->
erase();
4328mlir::applyAnalysisConversion(Operation *op, ConversionTarget &
target,
4329 const FrozenRewritePatternSet &patterns,
4330 ConversionConfig config) {
4331 return applyAnalysisConversion(llvm::ArrayRef(op),
target, patterns, config);
static void setInsertionPointAfter(OpBuilder &b, Value value)
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
SmallVector< Value, 2 > ValueVector
A vector of SSA values, optimized for the most common case of one or two values.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl, function_ref< bool(OpOperand &)> functor=nullptr)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value > > &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
static std::string diag(const llvm::Value &value)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
This is the type of Action that is dispatched when a conversion is applied.
tracing::ActionImpl< ApplyConversionAction > Base
static constexpr StringLiteral desc
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
static constexpr StringLiteral tag
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
StringAttr getStringAttr(const Twine &bytes)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
auto lookup(T from) const
Lookup a mapped value within the map.
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Block::iterator getPoint() const
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
type_range getTypes() const
void destroyOpProperties(PropertyRef properties) const
This hooks destroy the op properties.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
TypeID getOpPropertiesTypeID() const
Return the TypeID of the op properties.
Operation is the basic unit of execution within MLIR.
PropertyRef getPropertiesStorage()
Return a generic (but typed) reference to the property type storage.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
void copyProperties(PropertyRef rhs)
Copy properties from an existing other properties object.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
CRTP Implementation of an action.
ArrayRef< IRUnit > irUnits
Set of IR units (operations, regions, blocks, values) that are associated with this action.
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
static void reconcileUnrealizedCasts(const DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
llvm::SetVector< T, Vector, Set, N > SetVector
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
This iterator enumerates elements according to their dominance relationship.
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
const ConversionTarget & getTarget()
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, bool isRecursiveLegalization=false)
LogicalResult convert(Operation *op, bool isRecursiveLegalization=false)
Converts a single operation.
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, Fn onFailure, bool isRecursiveLegalization=false)
Legalizes the given operations (and their nested operations) to the conversion target.
LogicalResult applyConversion(ArrayRef< Operation * > ops)
Applies the conversion to the given operations (and their nested operations).
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
bool wasErased(void *ptr) const
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
llvm::impl::raw_ldbg_ostream os
A raw output stream used to prefix the debug log.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
void replaceValueUses(Value from, ValueRange to, const TypeConverter *converter, function_ref< bool(OpOperand &)> functor=nullptr)
Replace the uses of the given value with the given values.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config, OperationConverter &opConverter)
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
ConversionValueMapping mapping
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
void replaceOp(Operation *op, SmallVector< SmallVector< Value > > &&newValues)
Replace the results of the given operation with the given values and erase the operation.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
OperationConverter & opConverter
The operation converter to use for recursive legalization.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.