10#include "mlir/Config/mlir-config.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/ErrorHandling.h"
25#include "llvm/Support/FormatVariadic.h"
26#include "llvm/Support/SaveAndRestore.h"
27#include "llvm/Support/ScopedPrinter.h"
34#define DEBUG_TYPE "dialect-conversion"
37template <
typename... Args>
38static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
41 os.startLine() <<
"} -> SUCCESS";
43 os.getOStream() <<
" : "
44 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
45 os.getOStream() <<
"\n";
50template <
typename... Args>
51static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
54 os.startLine() <<
"} -> FAILURE : "
55 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
65 if (
OpResult inputRes = dyn_cast<OpResult>(value))
66 insertPt = ++inputRes.getOwner()->getIterator();
73 assert(!vals.empty() &&
"expected at least one value");
76 for (
Value v : vals.drop_front()) {
90 assert(dom &&
"unable to find valid insertion point");
98enum OpConversionMode {
125struct ValueVectorMapInfo {
126 static ::llvm::hash_code getHashValue(
const ValueVector &val) {
127 return ::llvm::hash_combine_range(val);
136struct ConversionValueMapping {
139 bool isMappedTo(
Value value)
const {
return mappedTo.contains(value); }
144 template <
typename T>
145 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
148 template <
typename OldVal,
typename NewVal>
149 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
150 map(OldVal &&oldVal, NewVal &&newVal) {
154 assert(next != oldVal &&
"inserting cyclic mapping");
155 auto it = mapping.find(next);
156 if (it == mapping.end())
161 mappedTo.insert_range(newVal);
163 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
167 template <
typename OldVal,
typename NewVal>
168 std::enable_if_t<!IsValueVector<OldVal>::value ||
169 !IsValueVector<NewVal>::value>
170 map(OldVal &&oldVal, NewVal &&newVal) {
171 if constexpr (IsValueVector<OldVal>{}) {
172 map(std::forward<OldVal>(oldVal),
ValueVector{newVal});
173 }
else if constexpr (IsValueVector<NewVal>{}) {
174 map(
ValueVector{oldVal}, std::forward<NewVal>(newVal));
185 void erase(
const ValueVector &value) { mapping.erase(value); }
205 assert(!values.empty() &&
"expected non-empty value vector");
206 Operation *op = values.front().getDefiningOp();
207 for (
Value v : llvm::drop_begin(values)) {
208 if (v.getDefiningOp() != op)
218 assert(!values.empty() &&
"expected non-empty value vector");
224 auto it = mapping.find(from);
225 if (it == mapping.end()) {
/// Snapshot of the conversion rewriter's bookkeeping sizes. Produced by
/// getCurrentState() and handed to resetState() to restore the rewriter to the
/// state captured here.
238struct RewriterState {
239 RewriterState(
unsigned numRewrites,
unsigned numIgnoredOperations,
240 unsigned numReplacedOps)
241 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
242 numReplacedOps(numReplacedOps) {}
/// Number of recorded rewrites at snapshot time (presumably the count that
/// undoRewrites() truncates `rewrites` back to -- TODO confirm).
245 unsigned numRewrites;
/// Number of ignored operations at snapshot time; resetState() loops until
/// `ignoredOps.size()` equals this value.
248 unsigned numIgnoredOperations;
/// Number of replaced operations at snapshot time; resetState() loops until
/// `replacedOps.size()` equals this value.
251 unsigned numReplacedOps;
/// Forward declaration. The Block overload below iterates over a block's
/// operations and calls this Operation overload. The Operation overload in
/// turn calls back into the Block overload, so one of the two must be declared
/// ahead of both definitions.
258static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
261static void notifyIRErased(RewriterBase::Listener *listener,
Block &
b) {
262 for (Operation &op :
b)
263 notifyIRErased(listener, op);
269static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
272 notifyIRErased(listener,
b);
302 UnresolvedMaterialization,
307 virtual ~IRRewrite() =
default;
310 virtual void rollback() = 0;
324 virtual void commit(RewriterBase &rewriter) {}
327 virtual void cleanup(RewriterBase &rewriter) {}
329 Kind getKind()
const {
return kind; }
331 static bool classof(
const IRRewrite *
rewrite) {
return true; }
334 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
335 : kind(kind), rewriterImpl(rewriterImpl) {}
337 const ConversionConfig &getConfig()
const;
340 ConversionPatternRewriterImpl &rewriterImpl;
344class BlockRewrite :
public IRRewrite {
347 Block *getBlock()
const {
return block; }
349 static bool classof(
const IRRewrite *
rewrite) {
350 return rewrite->getKind() >= Kind::CreateBlock &&
351 rewrite->getKind() <= Kind::BlockTypeConversion;
355 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
357 : IRRewrite(kind, rewriterImpl), block(block) {}
364class ValueRewrite :
public IRRewrite {
367 Value getValue()
const {
return value; }
369 static bool classof(
const IRRewrite *
rewrite) {
370 return rewrite->getKind() == Kind::ReplaceValue;
374 ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
376 : IRRewrite(kind, rewriterImpl), value(value) {}
385class CreateBlockRewrite :
public BlockRewrite {
387 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block)
388 : BlockRewrite(
Kind::CreateBlock, rewriterImpl, block) {}
390 static bool classof(
const IRRewrite *
rewrite) {
391 return rewrite->getKind() == Kind::CreateBlock;
394 void commit(RewriterBase &rewriter)
override {
400 void rollback()
override {
403 auto &blockOps = block->getOperations();
404 while (!blockOps.empty())
405 blockOps.remove(blockOps.begin());
406 block->dropAllUses();
407 if (block->getParent())
418class EraseBlockRewrite :
public BlockRewrite {
420 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block)
421 : BlockRewrite(
Kind::EraseBlock, rewriterImpl, block),
422 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
424 static bool classof(
const IRRewrite *
rewrite) {
425 return rewrite->getKind() == Kind::EraseBlock;
428 ~EraseBlockRewrite()
override {
430 "rewrite was neither rolled back nor committed/cleaned up");
433 void rollback()
override {
436 assert(block &&
"expected block");
441 blockList.insert(before, block);
445 void commit(RewriterBase &rewriter)
override {
446 assert(block &&
"expected block");
450 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
451 notifyIRErased(listener, *block);
454 void cleanup(RewriterBase &rewriter)
override {
456 for (
auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
458 assert(block->empty() &&
"expected empty block");
461 block->dropAllDefinedValueUses();
472 Block *insertBeforeBlock;
478class InlineBlockRewrite :
public BlockRewrite {
480 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block,
482 : BlockRewrite(
Kind::InlineBlock, rewriterImpl, block),
483 sourceBlock(sourceBlock),
484 firstInlinedInst(sourceBlock->empty() ?
nullptr
485 : &sourceBlock->front()),
486 lastInlinedInst(sourceBlock->empty() ?
nullptr : &sourceBlock->back()) {
492 assert(!getConfig().listener &&
493 "InlineBlockRewrite not supported if listener is attached");
496 static bool classof(
const IRRewrite *
rewrite) {
497 return rewrite->getKind() == Kind::InlineBlock;
500 void rollback()
override {
503 if (firstInlinedInst) {
504 assert(lastInlinedInst &&
"expected operation");
517 Operation *firstInlinedInst;
520 Operation *lastInlinedInst;
524class MoveBlockRewrite :
public BlockRewrite {
526 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block,
528 : BlockRewrite(
Kind::MoveBlock, rewriterImpl, block),
529 region(previousRegion),
530 insertBeforeBlock(previousIt == previousRegion->end() ?
nullptr
533 static bool classof(
const IRRewrite *
rewrite) {
534 return rewrite->getKind() == Kind::MoveBlock;
537 void commit(RewriterBase &rewriter)
override {
547 void rollback()
override {
551 if (Region *currentParent = block->
getParent()) {
553 region->getBlocks().splice(before, currentParent->getBlocks(), block);
557 region->
getBlocks().insert(before, block);
566 Block *insertBeforeBlock;
570class BlockTypeConversionRewrite :
public BlockRewrite {
572 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
574 : BlockRewrite(
Kind::BlockTypeConversion, rewriterImpl, origBlock),
575 newBlock(newBlock) {}
577 static bool classof(
const IRRewrite *
rewrite) {
578 return rewrite->getKind() == Kind::BlockTypeConversion;
581 Block *getOrigBlock()
const {
return block; }
583 Block *getNewBlock()
const {
return newBlock; }
585 void commit(RewriterBase &rewriter)
override;
587 void rollback()
override;
597class ReplaceValueRewrite :
public ValueRewrite {
599 ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
600 const TypeConverter *converter)
601 : ValueRewrite(
Kind::ReplaceValue, rewriterImpl, value),
602 converter(converter) {}
604 static bool classof(
const IRRewrite *
rewrite) {
605 return rewrite->getKind() == Kind::ReplaceValue;
608 void commit(RewriterBase &rewriter)
override;
610 void rollback()
override;
614 const TypeConverter *converter;
618class OperationRewrite :
public IRRewrite {
621 Operation *getOperation()
const {
return op; }
623 static bool classof(
const IRRewrite *
rewrite) {
624 return rewrite->getKind() >= Kind::MoveOperation &&
625 rewrite->getKind() <= Kind::UnresolvedMaterialization;
629 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
631 : IRRewrite(kind, rewriterImpl), op(op) {}
638class MoveOperationRewrite :
public OperationRewrite {
640 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
641 Operation *op, OpBuilder::InsertPoint previous)
642 : OperationRewrite(
Kind::MoveOperation, rewriterImpl, op),
643 block(previous.getBlock()),
644 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
646 : &*previous.getPoint()) {}
648 static bool classof(
const IRRewrite *
rewrite) {
649 return rewrite->getKind() == Kind::MoveOperation;
652 void commit(RewriterBase &rewriter)
override {
658 op, OpBuilder::InsertPoint(block,
663 void rollback()
override {
676 Operation *insertBeforeOp;
681class ModifyOperationRewrite :
public OperationRewrite {
683 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
685 : OperationRewrite(
Kind::ModifyOperation, rewriterImpl, op),
686 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
687 operands(op->operand_begin(), op->operand_end()),
688 successors(op->successor_begin(), op->successor_end()) {
691 propertiesStorage = operator new(op->getPropertiesStorageSize());
692 PropertyRef propCopy(name.getOpPropertiesTypeID(), propertiesStorage);
693 name.initOpProperties(propCopy, prop);
697 static bool classof(
const IRRewrite *
rewrite) {
698 return rewrite->getKind() == Kind::ModifyOperation;
701 ~ModifyOperationRewrite()
override {
702 assert(!propertiesStorage &&
703 "rewrite was neither committed nor rolled back");
706 void commit(RewriterBase &rewriter)
override {
709 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
712 if (propertiesStorage) {
717 operator delete(propertiesStorage);
718 propertiesStorage =
nullptr;
722 void rollback()
override {
726 for (
const auto &it : llvm::enumerate(successors))
728 if (propertiesStorage) {
732 operator delete(propertiesStorage);
733 propertiesStorage =
nullptr;
740 DictionaryAttr attrs;
741 SmallVector<Value, 8> operands;
742 SmallVector<Block *, 2> successors;
743 void *propertiesStorage =
nullptr;
750class ReplaceOperationRewrite :
public OperationRewrite {
752 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
753 Operation *op,
const TypeConverter *converter)
754 : OperationRewrite(
Kind::ReplaceOperation, rewriterImpl, op),
755 converter(converter) {}
757 static bool classof(
const IRRewrite *
rewrite) {
758 return rewrite->getKind() == Kind::ReplaceOperation;
761 void commit(RewriterBase &rewriter)
override;
763 void rollback()
override;
765 void cleanup(RewriterBase &rewriter)
override;
770 const TypeConverter *converter;
773class CreateOperationRewrite :
public OperationRewrite {
775 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
777 : OperationRewrite(
Kind::CreateOperation, rewriterImpl, op) {}
779 static bool classof(
const IRRewrite *
rewrite) {
780 return rewrite->getKind() == Kind::CreateOperation;
783 void commit(RewriterBase &rewriter)
override {
789 void rollback()
override;
793enum MaterializationKind {
/// Bookkeeping for an unresolved materialization: the type converter that
/// requested it, the materialization kind, and an optional original type.
/// Stored as the mapped value in a MapVector keyed by
/// UnrealizedConversionCastOp.
804class UnresolvedMaterializationInfo {
806 UnresolvedMaterializationInfo() =
default;
807 UnresolvedMaterializationInfo(
const TypeConverter *converter,
808 MaterializationKind kind, Type originalType)
809 : converterAndKind(converter, kind), originalType(originalType) {}
/// Return the type converter associated with this materialization.
812 const TypeConverter *getConverter()
const {
813 return converterAndKind.getPointer();
/// Return the kind of this materialization.
817 MaterializationKind getMaterializationKind()
const {
818 return converterAndKind.getInt();
/// Return the original type, if any. The materialization builder asserts
/// that a non-null original type is only supplied for target
/// materializations.
822 Type getOriginalType()
const {
return originalType; }
/// Converter pointer and materialization kind packed into a single
/// PointerIntPair (the kind occupies 2 bits of the pointer).
827 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
838class UnresolvedMaterializationRewrite :
public OperationRewrite {
840 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
841 UnrealizedConversionCastOp op,
843 : OperationRewrite(
Kind::UnresolvedMaterialization, rewriterImpl, op),
844 mappedValues(std::move(mappedValues)) {}
846 static bool classof(
const IRRewrite *
rewrite) {
847 return rewrite->getKind() == Kind::UnresolvedMaterialization;
850 void rollback()
override;
852 UnrealizedConversionCastOp getOperation()
const {
853 return cast<UnrealizedConversionCastOp>(op);
863#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
866template <
typename RewriteTy,
typename R>
867static bool hasRewrite(R &&rewrites, Operation *op) {
868 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
869 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
870 return rewriteTy && rewriteTy->getOperation() == op;
876template <
typename RewriteTy,
typename R>
877static bool hasRewrite(R &&rewrites,
Block *block) {
878 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
879 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
880 return rewriteTy && rewriteTy->getBlock() == block;
892 const ConversionConfig &
config,
902 RewriterState getCurrentState();
906 void applyRewrites();
911 void resetState(RewriterState state, StringRef patternName =
"");
915 template <
typename RewriteTy,
typename... Args>
917 assert(
config.allowPatternRollback &&
"appending rewrites is not allowed");
919 std::make_unique<RewriteTy>(*
this, std::forward<Args>(args)...));
925 void undoRewrites(
unsigned numRewritesToKeep = 0, StringRef patternName =
"");
931 LogicalResult remapValues(StringRef valueDiagTag,
932 std::optional<Location> inputLoc,
ValueRange values,
949 bool skipPureTypeConversions =
false)
const;
963 TypeConverter::SignatureConversion *entryConversion);
971 Block *applySignatureConversion(
973 TypeConverter::SignatureConversion &signatureConversion);
993 void eraseBlock(
Block *block);
1031 Value findOrBuildReplacementValue(
Value value,
1039 void notifyOperationInserted(
Operation *op,
1043 void notifyBlockInserted(
Block *block,
Region *previous,
1062 std::function<
void(
Operation *)> opErasedCallback =
nullptr)
1064 opErasedCallback(std::move(opErasedCallback)) {}
1078 assert(block->empty() &&
"expected empty block");
1079 block->dropAllDefinedValueUses();
1087 if (opErasedCallback)
1088 opErasedCallback(op);
1140 llvm::MapVector<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
/// Return the ConversionConfig of the rewriter implementation this rewrite
/// was constructed with.
1196const ConversionConfig &IRRewrite::getConfig()
const {
1197 return rewriterImpl.
config;
1200void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1204 if (
auto *listener =
1205 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
1206 for (Operation *op : getNewBlock()->getUsers())
1210void BlockTypeConversionRewrite::rollback() {
1211 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1218 if (isa<BlockArgument>(repl)) {
1258 result &= functor(operand);
1263void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1270void ReplaceValueRewrite::rollback() {
1271 rewriterImpl.
mapping.erase({value});
1277void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1279 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener());
1282 SmallVector<Value> replacements =
1284 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1292 for (
auto [
result, newValue] :
1293 llvm::zip_equal(op->
getResults(), replacements))
1299 if (getConfig().unlegalizedOps)
1300 getConfig().unlegalizedOps->erase(op);
1304 notifyIRErased(listener, *op);
1309 llvm::reportFatalInternalError(
1310 "dialect conversion attempted to replace a root operation that has no "
1311 "parent block; the pass must ensure its target op is nested in a "
1316void ReplaceOperationRewrite::rollback() {
1321void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1325void CreateOperationRewrite::rollback() {
1327 while (!region.getBlocks().empty())
1328 region.getBlocks().remove(region.getBlocks().begin());
1334void UnresolvedMaterializationRewrite::rollback() {
1335 if (!mappedValues.empty())
1336 rewriterImpl.
mapping.erase(mappedValues);
1347 for (
size_t i = 0; i <
rewrites.size(); ++i)
1353 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1354 unresolvedMaterializations.erase(castOp);
1357 rewrite->cleanup(eraseRewriter);
1365 Value from,
TypeRange desiredTypes,
bool skipPureTypeConversions)
const {
1368 assert(!values.empty() &&
"expected non-empty value vector");
1372 if (
config.allowPatternRollback)
1373 return mapping.lookup(values);
1380 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1385 if (castOp.getOutputs() != values)
1387 return castOp.getInputs();
1396 for (
Value v : values) {
1399 llvm::append_range(next, r);
1404 if (next != values) {
1433 if (skipPureTypeConversions) {
1436 match &= !pureConversion;
1439 if (!pureConversion)
1440 lastNonMaterialization = current;
1443 desiredValue = current;
1449 current = std::move(next);
1454 if (!desiredTypes.empty())
1455 return desiredValue;
1456 if (skipPureTypeConversions)
1457 return lastNonMaterialization;
1476 StringRef patternName) {
1481 while (
ignoredOps.size() != state.numIgnoredOperations)
1484 while (
replacedOps.size() != state.numReplacedOps)
1489 StringRef patternName) {
1491 llvm::reverse(llvm::drop_begin(
rewrites, numRewritesToKeep)))
1493 rewrites.resize(numRewritesToKeep);
1497 StringRef valueDiagTag, std::optional<Location> inputLoc,
ValueRange values,
1499 remapped.reserve(llvm::size(values));
1501 for (
const auto &it : llvm::enumerate(values)) {
1502 Value operand = it.value();
1521 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1522 << it.index() <<
", type was " << origType;
1527 if (legalTypes.empty()) {
1528 remapped.push_back({});
1537 remapped.push_back(std::move(repl));
1546 repl, repl, legalTypes,
1548 remapped.push_back(castValues);
1569 TypeConverter::SignatureConversion *entryConversion) {
1571 if (region->
empty())
1576 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1578 std::optional<TypeConverter::SignatureConversion> conversion =
1579 converter.convertBlockSignature(&block);
1588 if (entryConversion)
1591 std::optional<TypeConverter::SignatureConversion> conversion =
1592 converter.convertBlockSignature(®ion->
front());
1600 TypeConverter::SignatureConversion &signatureConversion) {
1601#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1603 if (hasRewrite<BlockTypeConversionRewrite>(
rewrites, block))
1604 llvm::reportFatalInternalError(
"block was already converted");
1611 auto convertedTypes = signatureConversion.getConvertedTypes();
1618 for (
unsigned i = 0; i < origArgCount; ++i) {
1619 auto inputMap = signatureConversion.getInputMapping(i);
1620 if (!inputMap || inputMap->replacedWithValues())
1623 for (
unsigned j = 0;
j < inputMap->size; ++
j)
1624 newLocs[inputMap->inputNo +
j] = origLoc;
1631 convertedTypes, newLocs);
1639 bool fastPath = !
config.listener;
1641 if (
config.allowPatternRollback)
1645 while (!block->
empty())
1652 for (
unsigned i = 0; i != origArgCount; ++i) {
1656 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1657 signatureConversion.getInputMapping(i);
1665 MaterializationKind::Source,
1669 origArgType,
Type(), converter,
1676 if (inputMap->replacedWithValues()) {
1678 assert(inputMap->size == 0 &&
1679 "invalid to provide a replacement value when the argument isn't "
1687 newBlock->
getArguments().slice(inputMap->inputNo, inputMap->size);
1691 if (
config.allowPatternRollback)
1712 assert((!originalType || kind == MaterializationKind::Target) &&
1713 "original type is valid only for target materializations");
1714 assert(
TypeRange(inputs) != outputTypes &&
1715 "materialization is not necessary");
1719 OpBuilder builder(outputTypes.front().getContext());
1721 UnrealizedConversionCastOp convertOp =
1722 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1723 if (
config.attachDebugMaterializationKind) {
1725 kind == MaterializationKind::Source ?
"source" :
"target";
1726 convertOp->setAttr(
"__kind__", builder.
getStringAttr(kindStr));
1733 UnresolvedMaterializationInfo(converter, kind, originalType);
1734 if (
config.allowPatternRollback) {
1735 if (!valuesToMap.empty())
1736 mapping.map(valuesToMap, convertOp.getResults());
1738 std::move(valuesToMap));
1742 return convertOp.getResults();
1747 assert(
config.allowPatternRollback &&
1748 "this code path is valid only in rollback mode");
1755 return repl.front();
1762 [&](
Operation *op) { return replacedOps.contains(op); }) &&
1787 MaterializationKind::Source, ip, value.
getLoc(),
1803 bool wasDetached = !previous.
isSet();
1805 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"' (" << op
1808 logger.getOStream() <<
" (was detached)";
1809 logger.getOStream() <<
"\n";
1815 "attempting to insert into a block within a replaced/erased op");
1819 config.listener->notifyOperationInserted(op, previous);
1828 if (
config.allowPatternRollback) {
1842 if (
config.allowPatternRollback)
1852 assert(!
impl.config.allowPatternRollback &&
1853 "this code path is valid only in 'no rollback' mode");
1855 for (
auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1858 repls.push_back(
Value());
1865 Value srcMat =
impl.buildUnresolvedMaterialization(
1870 repls.push_back(srcMat);
1876 repls.push_back(to[0]);
1885 Value srcMat =
impl.buildUnresolvedMaterialization(
1888 Type(), converter)[0];
1889 repls.push_back(srcMat);
1898 "incorrect number of replacement values");
1900 logger.startLine() <<
"** Replace : '" << op->
getName() <<
"'(" << op
1908 for (
auto [
result, repls] :
1909 llvm::zip_equal(op->
getResults(), newValues)) {
1911 auto logProlog = [&, repls = repls]() {
1912 logger.startLine() <<
" Note: Replacing op result of type "
1913 << resultType <<
" with value(s) of type (";
1914 llvm::interleaveComma(repls,
logger.getOStream(), [&](
Value v) {
1915 logger.getOStream() << v.getType();
1917 logger.getOStream() <<
")";
1923 logger.getOStream() <<
", but the type converter failed to legalize "
1924 "the original type.\n";
1929 logger.getOStream() <<
", but the legalized type(s) is/are (";
1930 llvm::interleaveComma(convertedTypes,
logger.getOStream(),
1931 [&](
Type t) { logger.getOStream() << t; });
1932 logger.getOStream() <<
")\n";
1938 if (!
config.allowPatternRollback) {
1947 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1953 if (
config.unlegalizedOps)
1954 config.unlegalizedOps->erase(op);
1962 assert(!
ignoredOps.contains(op) &&
"operation was already replaced");
1966 "attempting to replace a value that was already replaced");
1971 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1976 "attempting to replace/erase an unresolved materialization");
1992 logger.startLine() <<
"** Replace Value : '" << from <<
"'";
1993 if (
auto blockArg = dyn_cast<BlockArgument>(from)) {
1995 logger.getOStream() <<
" (in region of '" << parentOp->getName()
1996 <<
"' (" << parentOp <<
")";
1998 logger.getOStream() <<
" (unlinked block)";
2002 logger.getOStream() <<
", conditional replacement";
2006 if (!
config.allowPatternRollback) {
2011 Value repl = repls.front();
2028 "attempting to replace a value that was already replaced");
2030 "attempting to replace a op result that was already replaced");
2035 llvm::reportFatalInternalError(
2036 "conditional value replacement is not supported in rollback mode");
2042 if (!
config.allowPatternRollback) {
2049 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
2055 if (
config.unlegalizedOps)
2056 config.unlegalizedOps->erase(op);
2065 "attempting to erase a block within a replaced/erased op");
2081 bool wasDetached = !previous;
2087 logger.startLine() <<
"** Insert Block into : '" << parent->
getName()
2088 <<
"' (" << parent <<
")";
2091 <<
"** Insert Block into detached Region (nullptr parent op)";
2094 logger.getOStream() <<
" (was detached)";
2095 logger.getOStream() <<
"\n";
2101 "attempting to insert into a region within a replaced/erased op");
2106 config.listener->notifyBlockInserted(block, previous, previousIt);
2110 if (
config.allowPatternRollback) {
2124 if (
config.allowPatternRollback)
2138 reasonCallback(
diag);
2139 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
2140 if (
config.notifyCallback)
2149ConversionPatternRewriter::ConversionPatternRewriter(
2153 *this, config, opConverter)) {
2154 setListener(
impl.get());
2157ConversionPatternRewriter::~ConversionPatternRewriter() =
default;
2159const ConversionConfig &ConversionPatternRewriter::getConfig()
const {
2160 return impl->config;
2164 assert(op && newOp &&
"expected non-null op");
2168void ConversionPatternRewriter::replaceOp(Operation *op,
ValueRange newValues) {
2170 "incorrect # of replacement values");
2174 if (getInsertionPoint() == op->getIterator())
2177 SmallVector<SmallVector<Value>> newVals =
2178 llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2179 return v ? SmallVector<Value>{v} : SmallVector<Value>();
2181 impl->replaceOp(op, std::move(newVals));
2184void ConversionPatternRewriter::replaceOpWithMultiple(
2185 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2187 "incorrect # of replacement values");
2191 if (getInsertionPoint() == op->getIterator())
2194 impl->replaceOp(op, std::move(newValues));
2197void ConversionPatternRewriter::eraseOp(Operation *op) {
2199 impl->logger.startLine()
2200 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
2205 if (getInsertionPoint() == op->getIterator())
2208 SmallVector<SmallVector<Value>> nullRepls(op->
getNumResults(), {});
2209 impl->replaceOp(op, std::move(nullRepls));
/// Erase `block` by delegating to the rewriter implementation.
2212void ConversionPatternRewriter::eraseBlock(
Block *block) {
2213 impl->eraseBlock(block);
2216Block *ConversionPatternRewriter::applySignatureConversion(
2217 Block *block, TypeConverter::SignatureConversion &conversion,
2218 const TypeConverter *converter) {
2219 assert(!impl->wasOpReplaced(block->
getParentOp()) &&
2220 "attempting to apply a signature conversion to a block within a "
2221 "replaced/erased op");
2222 return impl->applySignatureConversion(block, converter, conversion);
2225FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
2226 Region *region,
const TypeConverter &converter,
2227 TypeConverter::SignatureConversion *entryConversion) {
2228 assert(!impl->wasOpReplaced(region->
getParentOp()) &&
2229 "attempting to apply a signature conversion to a block within a "
2230 "replaced/erased op");
2231 return impl->convertRegionTypes(region, converter, entryConversion);
/// Replace all uses of `from` with the value(s) in `to`. Delegates to the
/// rewriter implementation, passing its currently active type converter.
2234void ConversionPatternRewriter::replaceAllUsesWith(Value from,
ValueRange to) {
2235 impl->replaceValueUses(from, to, impl->currentTypeConverter);
/// Conditional variant of replaceAllUsesWith. `functor` is forwarded to the
/// implementation's replaceValueUses together with the currently active type
/// converter (presumably a per-use filter -- confirm against replaceValueUses).
2238void ConversionPatternRewriter::replaceUsesWithIf(
2240 bool *allUsesReplaced) {
// Reporting whether every use was replaced is not supported here, so
// callers must pass a null `allUsesReplaced`.
2241 assert(!allUsesReplaced &&
2242 "allUsesReplaced is not supported in a dialect conversion");
2243 impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
2246Value ConversionPatternRewriter::getRemappedValue(Value key) {
2247 SmallVector<ValueVector> remappedValues;
2248 if (
failed(impl->remapValues(
"value", std::nullopt, key,
2251 assert(remappedValues.front().size() == 1 &&
"1:N conversion not supported");
2252 return remappedValues.front().front();
2256ConversionPatternRewriter::getRemappedValues(
ValueRange keys,
2257 SmallVectorImpl<Value> &results) {
2260 SmallVector<ValueVector> remapped;
2261 if (
failed(impl->remapValues(
"value", std::nullopt, keys,
2264 for (
const auto &values : remapped) {
2265 assert(values.size() == 1 &&
"1:N conversion not supported");
2266 results.push_back(values.front());
2271void ConversionPatternRewriter::inlineBlockBefore(
Block *source,
Block *dest,
2276 "incorrect # of argument replacement values");
2277 assert(!impl->wasOpReplaced(source->
getParentOp()) &&
2278 "attempting to inline a block from a replaced/erased op");
2279 assert(!impl->wasOpReplaced(dest->
getParentOp()) &&
2280 "attempting to inline a block into a replaced/erased op");
2281 auto opIgnored = [&](Operation *op) {
return impl->isOpIgnored(op); };
2284 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
2285 "expected 'source' to have no predecessors");
2294 bool fastPath = !getConfig().listener;
2296 if (fastPath && impl->config.allowPatternRollback)
2297 impl->inlineBlockBefore(source, dest, before);
2300 for (
auto it : llvm::zip(source->
getArguments(), argValues))
2301 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2308 while (!source->
empty())
2309 moveOpBefore(&source->
front(), dest, before);
2314 if (getInsertionBlock() == source)
2315 setInsertionPoint(dest, getInsertionPoint());
2321void ConversionPatternRewriter::startOpModification(Operation *op) {
2322 if (!impl->config.allowPatternRollback) {
2327 assert(!impl->wasOpReplaced(op) &&
2328 "attempting to modify a replaced/erased op");
2330 impl->pendingRootUpdates.insert(op);
2332 impl->appendRewrite<ModifyOperationRewrite>(op);
2335void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
2336 impl->patternModifiedOps.insert(op);
2337 if (!impl->config.allowPatternRollback) {
2339 if (getConfig().listener)
2340 getConfig().listener->notifyOperationModified(op);
2347 assert(!impl->wasOpReplaced(op) &&
2348 "attempting to modify a replaced/erased op");
2349 assert(impl->pendingRootUpdates.erase(op) &&
2350 "operation did not have a pending in-place update");
2354void ConversionPatternRewriter::cancelOpModification(Operation *op) {
2355 if (!impl->config.allowPatternRollback) {
2360 assert(impl->pendingRootUpdates.erase(op) &&
2361 "operation did not have a pending in-place update");
2364 auto it = llvm::find_if(
2365 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &
rewrite) {
2366 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2367 return modifyRewrite && modifyRewrite->getOperation() == op;
2369 assert(it != impl->rewrites.rend() &&
"no root update started on op");
2371 int updateIdx = std::prev(impl->rewrites.rend()) - it;
2372 impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2375detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
2383FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
2384 ArrayRef<ValueRange> operands)
const {
2385 SmallVector<Value> oneToOneOperands;
2386 oneToOneOperands.reserve(operands.size());
2388 if (operand.size() != 1)
2391 oneToOneOperands.push_back(operand.front());
2393 return std::move(oneToOneOperands);
2397ConversionPattern::matchAndRewrite(Operation *op,
2398 PatternRewriter &rewriter)
const {
2399 auto &dialectRewriter =
static_cast<ConversionPatternRewriter &
>(rewriter);
2400 auto &rewriterImpl = dialectRewriter.getImpl();
2404 getTypeConverter());
2407 SmallVector<ValueVector> remapped;
2412 SmallVector<ValueRange> remappedAsRange =
2413 llvm::to_vector_of<ValueRange>(remapped);
2414 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2423using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2426class OperationLegalizer {
2428 using LegalizationAction = ConversionTarget::LegalizationAction;
2430 OperationLegalizer(ConversionPatternRewriter &rewriter,
2431 const ConversionTarget &targetInfo,
2432 const FrozenRewritePatternSet &patterns);
2435 bool isIllegal(Operation *op)
const;
2439 LogicalResult legalize(Operation *op);
2442 const ConversionTarget &getTarget() {
return target; }
2446 LogicalResult legalizeWithFold(Operation *op);
2450 LogicalResult legalizeWithPattern(Operation *op);
2454 bool canApplyPattern(Operation *op,
const Pattern &pattern);
2458 legalizePatternResult(Operation *op,
const Pattern &pattern,
2459 const RewriterState &curState,
2476 void buildLegalizationGraph(
2477 LegalizationPatterns &anyOpLegalizerPatterns,
2488 void computeLegalizationGraphBenefit(
2489 LegalizationPatterns &anyOpLegalizerPatterns,
2494 unsigned computeOpLegalizationDepth(
2501 unsigned applyCostModelToPatterns(
2502 LegalizationPatterns &patterns,
2507 SmallPtrSet<const Pattern *, 8> appliedPatterns;
2510 ConversionPatternRewriter &rewriter;
2513 const ConversionTarget &
target;
2516 PatternApplicator applicator;
2520OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2521 const ConversionTarget &targetInfo,
2522 const FrozenRewritePatternSet &patterns)
2523 : rewriter(rewriter),
target(targetInfo), applicator(patterns) {
2527 LegalizationPatterns anyOpLegalizerPatterns;
2529 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2530 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2533bool OperationLegalizer::isIllegal(Operation *op)
const {
2534 return target.isIllegal(op);
2537LogicalResult OperationLegalizer::legalize(Operation *op) {
2539 const char *logLineComment =
2540 "//===-------------------------------------------===//\n";
2542 auto &logger = rewriter.getImpl().logger;
2546 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2549 logger.getOStream() <<
"\n";
2550 logger.startLine() << logLineComment;
2551 logger.startLine() <<
"Legalizing operation : ";
2556 logger.getOStream() <<
"'" << op->
getName() <<
"' ";
2557 logger.getOStream() <<
"(" << op <<
") {\n";
2562 logger.startLine() << OpWithFlags(op,
2563 OpPrintingFlags().printGenericOpForm())
2570 logSuccess(logger,
"operation marked 'ignored' during conversion");
2571 logger.startLine() << logLineComment;
2577 if (
auto legalityInfo =
target.isLegal(op)) {
2580 logger,
"operation marked legal by the target{0}",
2581 legalityInfo->isRecursivelyLegal
2582 ?
"; NOTE: operation is recursively legal; skipping internals"
2584 logger.startLine() << logLineComment;
2589 if (legalityInfo->isRecursivelyLegal) {
2590 op->
walk([&](Operation *nested) {
2592 rewriter.getImpl().ignoredOps.
insert(nested);
2601 const ConversionConfig &config = rewriter.getConfig();
2602 if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2603 if (succeeded(legalizeWithFold(op))) {
2606 logger.startLine() << logLineComment;
2613 if (succeeded(legalizeWithPattern(op))) {
2616 logger.startLine() << logLineComment;
2623 if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2624 if (succeeded(legalizeWithFold(op))) {
2627 logger.startLine() << logLineComment;
2634 logFailure(logger,
"no matched legalization pattern");
2635 logger.startLine() << logLineComment;
2642template <
typename T>
2644 T
result = std::move(obj);
2649LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2650 auto &rewriterImpl = rewriter.getImpl();
2652 rewriterImpl.
logger.startLine() <<
"* Fold {\n";
2653 rewriterImpl.
logger.indent();
2658 llvm::scope_exit cleanup([&]() {
2668 SmallVector<Value, 2> replacementValues;
2669 SmallVector<Operation *, 2> newOps;
2672 if (
failed(rewriter.
tryFold(op, replacementValues, &newOps))) {
2681 if (replacementValues.empty())
2682 return legalize(op);
2685 rewriter.
replaceOp(op, replacementValues);
2688 for (Operation *newOp : newOps) {
2689 if (
failed(legalize(newOp))) {
2691 "failed to legalize generated constant '{0}'",
2693 if (!rewriter.getConfig().allowPatternRollback) {
2695 llvm::reportFatalInternalError(
2697 "' folder rollback of IR modifications requested");
2715 auto newOpNames = llvm::map_range(
2717 auto modifiedOpNames = llvm::map_range(
2719 llvm::reportFatalInternalError(
"pattern '" + pattern.
getDebugName() +
2720 "' produced IR that could not be legalized. " +
2721 "new ops: {" + llvm::join(newOpNames,
", ") +
2722 "}, " +
"modified ops: {" +
2723 llvm::join(modifiedOpNames,
", ") +
"}");
2726LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2727 auto &rewriterImpl = rewriter.getImpl();
2728 const ConversionConfig &config = rewriter.getConfig();
2730#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2732 std::optional<OperationFingerPrint> topLevelFingerPrint;
2733 if (!rewriterImpl.
config.allowPatternRollback) {
2740 topLevelFingerPrint = OperationFingerPrint(checkOp);
2746 rewriterImpl.
logger.startLine()
2747 <<
"WARNING: Multi-threadeding is enabled. Some dialect "
2748 "conversion expensive checks are skipped in multithreading "
2757 auto canApply = [&](
const Pattern &pattern) {
2758 bool canApply = canApplyPattern(op, pattern);
2759 if (canApply && config.listener)
2760 config.listener->notifyPatternBegin(pattern, op);
2766 auto onFailure = [&](
const Pattern &pattern) {
2768 if (!rewriterImpl.
config.allowPatternRollback) {
2775#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2777 if (checkOp && topLevelFingerPrint) {
2778 OperationFingerPrint fingerPrintAfterPattern(checkOp);
2779 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2780 llvm::reportFatalInternalError(
2781 "pattern '" + pattern.getDebugName() +
2782 "' returned failure but IR did change");
2790 if (rewriterImpl.
config.notifyCallback) {
2792 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
2798 if (config.listener)
2799 config.listener->notifyPatternEnd(pattern, failure());
2800 rewriterImpl.
resetState(curState, pattern.getDebugName());
2801 appliedPatterns.erase(&pattern);
2806 auto onSuccess = [&](
const Pattern &pattern) {
2808 if (!rewriterImpl.
config.allowPatternRollback) {
2822 legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
2823 appliedPatterns.erase(&pattern);
2825 if (!rewriterImpl.
config.allowPatternRollback)
2827 rewriterImpl.
resetState(curState, pattern.getDebugName());
2829 if (config.listener)
2830 config.listener->notifyPatternEnd(pattern,
result);
2835 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2839bool OperationLegalizer::canApplyPattern(Operation *op,
2840 const Pattern &pattern) {
2842 auto &os = rewriter.getImpl().logger;
2843 os.getOStream() <<
"\n";
2844 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
2846 os.getOStream() <<
")' {\n";
2853 !appliedPatterns.insert(&pattern).second) {
2855 logFailure(rewriter.getImpl().logger,
"pattern was already applied"));
2861LogicalResult OperationLegalizer::legalizePatternResult(
2862 Operation *op,
const Pattern &pattern,
const RewriterState &curState,
2865 [[maybe_unused]]
auto &impl = rewriter.getImpl();
2866 assert(impl.pendingRootUpdates.empty() &&
"dangling root updates");
2868#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2869 if (impl.config.allowPatternRollback) {
2871 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2872 auto replacedRoot = [&] {
2873 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2875 auto updatedRootInPlace = [&] {
2876 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2878 if (!replacedRoot() && !updatedRootInPlace())
2879 llvm::reportFatalInternalError(
2880 "expected pattern to replace the root operation "
2881 "or modify it in place");
2886 if (
failed(legalizePatternRootUpdates(modifiedOps)) ||
2887 failed(legalizePatternCreatedOperations(newOps))) {
2891 LLVM_DEBUG(
logSuccess(impl.logger,
"pattern applied successfully"));
2895LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2897 for (Operation *op : newOps) {
2898 if (
failed(legalize(op))) {
2899 LLVM_DEBUG(
logFailure(rewriter.getImpl().logger,
2900 "failed to legalize generated operation '{0}'({1})",
2908LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2910 for (Operation *op : modifiedOps) {
2911 if (
failed(legalize(op))) {
2914 "failed to legalize operation updated in-place '{0}'",
2926void OperationLegalizer::buildLegalizationGraph(
2927 LegalizationPatterns &anyOpLegalizerPatterns,
2938 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2939 std::optional<OperationName> root = pattern.
getRootKind();
2945 anyOpLegalizerPatterns.push_back(&pattern);
2950 if (
target.getOpAction(*root) == LegalizationAction::Legal)
2955 invalidPatterns[*root].insert(&pattern);
2957 parentOps[op].insert(*root);
2960 patternWorklist.insert(&pattern);
2968 if (!anyOpLegalizerPatterns.empty()) {
2969 for (
const Pattern *pattern : patternWorklist)
2970 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2974 while (!patternWorklist.empty()) {
2975 auto *pattern = patternWorklist.pop_back_val();
2979 std::optional<LegalizationAction> action = target.getOpAction(op);
2980 return !legalizerPatterns.count(op) &&
2981 (!action || action == LegalizationAction::Illegal);
2987 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2988 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
2992 for (
auto op : parentOps[*pattern->
getRootKind()])
2993 patternWorklist.set_union(invalidPatterns[op]);
2997void OperationLegalizer::computeLegalizationGraphBenefit(
2998 LegalizationPatterns &anyOpLegalizerPatterns,
3004 for (
auto &opIt : legalizerPatterns)
3005 if (!minOpPatternDepth.count(opIt.first))
3006 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
3012 if (!anyOpLegalizerPatterns.empty())
3013 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
3019 applicator.applyCostModel([&](
const Pattern &pattern) {
3020 ArrayRef<const Pattern *> orderedPatternList;
3021 if (std::optional<OperationName> rootName = pattern.
getRootKind())
3022 orderedPatternList = legalizerPatterns[*rootName];
3024 orderedPatternList = anyOpLegalizerPatterns;
3027 auto *it = llvm::find(orderedPatternList, &pattern);
3028 if (it == orderedPatternList.end())
3032 return PatternBenefit(std::distance(it, orderedPatternList.end()));
3036unsigned OperationLegalizer::computeOpLegalizationDepth(
3040 auto depthIt = minOpPatternDepth.find(op);
3041 if (depthIt != minOpPatternDepth.end())
3042 return depthIt->second;
3046 auto opPatternsIt = legalizerPatterns.find(op);
3047 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3052 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
3056 unsigned minDepth = applyCostModelToPatterns(
3057 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3058 minOpPatternDepth[op] = minDepth;
3062unsigned OperationLegalizer::applyCostModelToPatterns(
3063 LegalizationPatterns &patterns,
3066 unsigned minDepth = std::numeric_limits<unsigned>::max();
3069 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3070 patternsByDepth.reserve(patterns.size());
3071 for (
const Pattern *pattern : patterns) {
3074 unsigned generatedOpDepth = computeOpLegalizationDepth(
3075 generatedOp, minOpPatternDepth, legalizerPatterns);
3076 depth = std::max(depth, generatedOpDepth + 1);
3078 patternsByDepth.emplace_back(pattern, depth);
3081 minDepth = std::min(minDepth, depth);
3086 if (patternsByDepth.size() == 1)
3090 llvm::stable_sort(patternsByDepth,
3091 [](
const std::pair<const Pattern *, unsigned> &
lhs,
3092 const std::pair<const Pattern *, unsigned> &
rhs) {
3095 if (
lhs.second !=
rhs.second)
3096 return lhs.second <
rhs.second;
3099 auto lhsBenefit =
lhs.first->getBenefit();
3100 auto rhsBenefit =
rhs.first->getBenefit();
3101 return lhsBenefit > rhsBenefit;
3106 for (
auto &patternIt : patternsByDepth)
3107 patterns.push_back(patternIt.first);
3121template <
typename RangeT>
3124 function_ref<
bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3133 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3134 if (castOp.getInputs().empty())
3137 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3140 if (inputCastOp.getOutputs() != castOp.getInputs())
3146 while (!worklist.empty()) {
3147 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3151 UnrealizedConversionCastOp nextCast = castOp;
3153 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3154 if (llvm::any_of(nextCast.getInputs(), [&](
Value v) {
3155 return v.getDefiningOp() == castOp;
3163 castOp.replaceAllUsesWith(nextCast.getInputs());
3166 nextCast = getInputCast(nextCast);
3176 auto markOpLive = [&](
Operation *rootOp) {
3178 worklist.push_back(rootOp);
3179 while (!worklist.empty()) {
3180 Operation *op = worklist.pop_back_val();
3181 if (liveOps.insert(op).second) {
3184 if (
auto castOp = v.
getDefiningOp<UnrealizedConversionCastOp>())
3185 if (isCastOpOfInterestFn(castOp))
3186 worklist.push_back(castOp);
3192 for (UnrealizedConversionCastOp op : castOps) {
3195 if (liveOps.contains(op.getOperation()))
3199 if (llvm::any_of(op->getUsers(), [&](
Operation *user) {
3200 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3201 return !castOp || !isCastOpOfInterestFn(castOp);
3207 for (UnrealizedConversionCastOp op : castOps) {
3208 if (liveOps.contains(op)) {
3210 if (remainingCastOps)
3211 remainingCastOps->push_back(op);
3222 ArrayRef<UnrealizedConversionCastOp> castOps,
3223 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3225 DenseSet<UnrealizedConversionCastOp> castOpSet;
3226 for (UnrealizedConversionCastOp op : castOps)
3227 castOpSet.insert(op);
3232 const DenseSet<UnrealizedConversionCastOp> &castOps,
3233 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3235 llvm::make_range(castOps.begin(), castOps.end()),
3236 [&](UnrealizedConversionCastOp castOp) {
3237 return castOps.contains(castOp);
3244 const llvm::MapVector<UnrealizedConversionCastOp,
3245 UnresolvedMaterializationInfo> &castOps,
3249 [&](UnrealizedConversionCastOp castOp) {
3250 return castOps.contains(castOp);
3267 const ConversionConfig &config,
3268 OpConversionMode mode)
3269 : rewriter(ctx, config, *this), opLegalizer(rewriter,
target, patterns),
3278 template <
typename Fn>
3280 bool isRecursiveLegalization =
false);
3282 bool isRecursiveLegalization =
false) {
3284 ops, [&]() {}, isRecursiveLegalization);
3292 LogicalResult convert(
Operation *op,
bool isRecursiveLegalization =
false);
3298 ConversionPatternRewriter rewriter;
3301 OperationLegalizer opLegalizer;
3304 OpConversionMode mode;
3309 bool isRecursiveLegalization) {
3310 const ConversionConfig &config = rewriter.getConfig();
3311 auto emitFailedToLegalizeDiag = [&](
bool wasExplicitlyIllegal) {
3313 <<
"failed to legalize operation '"
3315 if (wasExplicitlyIllegal)
3316 diag <<
" that was explicitly marked illegal";
3321 if (failed(opLegalizer.legalize(op))) {
3324 if (mode == OpConversionMode::Full) {
3325 if (!isRecursiveLegalization)
3326 emitFailedToLegalizeDiag(
false);
3332 if (mode == OpConversionMode::Partial) {
3333 if (opLegalizer.isIllegal(op)) {
3334 if (!isRecursiveLegalization)
3335 emitFailedToLegalizeDiag(
true);
3338 if (config.unlegalizedOps && !isRecursiveLegalization)
3339 config.unlegalizedOps->insert(op);
3341 }
else if (mode == OpConversionMode::Analysis) {
3345 if (config.legalizableOps && !isRecursiveLegalization)
3346 config.legalizableOps->insert(op);
3353 UnrealizedConversionCastOp op,
3354 const UnresolvedMaterializationInfo &info) {
3355 assert(!op.use_empty() &&
3356 "expected that dead materializations have already been DCE'd");
3363 switch (info.getMaterializationKind()) {
3364 case MaterializationKind::Target:
3365 newMaterialization = converter->materializeTargetConversion(
3366 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3367 info.getOriginalType());
3369 case MaterializationKind::Source:
3370 assert(op->getNumResults() == 1 &&
"expected single result");
3371 Value sourceMat = converter->materializeSourceConversion(
3372 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3374 newMaterialization.push_back(sourceMat);
3377 if (!newMaterialization.empty()) {
3379 ValueRange newMaterializationRange(newMaterialization);
3380 assert(
TypeRange(newMaterializationRange) == op.getResultTypes() &&
3381 "materialization callback produced value of incorrect type");
3383 rewriter.
replaceOp(op, newMaterialization);
3389 <<
"failed to legalize unresolved materialization "
3391 << inputOperands.
getTypes() <<
") to ("
3392 << op.getResultTypes()
3393 <<
") that remained live after conversion";
3394 diag.attachNote(op->getUsers().begin()->getLoc())
3395 <<
"see existing live user here: " << *op->getUsers().begin();
3399template <
typename Fn>
3402 bool isRecursiveLegalization) {
3410 toConvert.push_back(op);
3413 auto legalityInfo =
target.isLegal(op);
3414 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3420 if (failed(
convert(op, isRecursiveLegalization))) {
3429LogicalResult ConversionPatternRewriter::legalize(
Operation *op) {
3430 return impl->opConverter.legalizeOperations(op,
3434LogicalResult ConversionPatternRewriter::legalize(
Region *r) {
3450 std::optional<TypeConverter::SignatureConversion> conversion =
3451 converter->convertBlockSignature(&r->front());
3454 applySignatureConversion(&r->front(), *conversion, converter);
3459 return impl->opConverter.legalizeOperations(ops,
3469 if (rewriterImpl.
config.allowPatternRollback) {
3487 const llvm::MapVector<UnrealizedConversionCastOp,
3488 UnresolvedMaterializationInfo> &materializations =
3494 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3498 if (rewriter.getConfig().buildMaterializations) {
3502 rewriter.getConfig().listener);
3503 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3504 auto it = materializations.find(castOp);
3505 assert(it != materializations.end() &&
"inconsistent state");
3519void TypeConverter::SignatureConversion::addInputs(
unsigned origInputNo,
3521 assert(!types.empty() &&
"expected valid types");
3522 remapInput(origInputNo, argTypes.size(), types.size());
3526void TypeConverter::SignatureConversion::addInputs(
ArrayRef<Type> types) {
3527 assert(!types.empty() &&
3528 "1->0 type remappings don't need to be added explicitly");
3529 argTypes.append(types.begin(), types.end());
3532void TypeConverter::SignatureConversion::remapInput(
unsigned origInputNo,
3533 unsigned newInputNo,
3534 unsigned newInputCount) {
3535 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3536 assert(newInputCount != 0 &&
"expected valid input count");
3537 remappedInputs[origInputNo] =
3538 InputMapping{newInputNo, newInputCount, {}};
3541void TypeConverter::SignatureConversion::remapInput(
3542 unsigned origInputNo, ArrayRef<Value> replacements) {
3543 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3544 remappedInputs[origInputNo] = InputMapping{
3546 SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3557TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3558 SmallVectorImpl<Type> &results)
const {
3559 assert(typeOrValue &&
"expected non-null type");
3560 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3561 : cast<Type>(typeOrValue);
3563 std::shared_lock<
decltype(cacheMutex)> cacheReadLock(cacheMutex,
3566 cacheReadLock.lock();
3567 auto existingIt = cachedDirectConversions.find(t);
3568 if (existingIt != cachedDirectConversions.end()) {
3569 if (existingIt->second)
3570 results.push_back(existingIt->second);
3571 return success(existingIt->second !=
nullptr);
3573 auto multiIt = cachedMultiConversions.find(t);
3574 if (multiIt != cachedMultiConversions.end()) {
3575 results.append(multiIt->second.begin(), multiIt->second.end());
3581 size_t currentCount = results.size();
3585 auto isCacheable = [&](
int index) {
3586 int numberOfConversionsUntilContextAware =
3587 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3588 return index < numberOfConversionsUntilContextAware;
3591 std::unique_lock<
decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3594 for (
auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3595 const ConversionCallbackFn &converter = indexedConverter.value();
3596 std::optional<LogicalResult>
result = converter(typeOrValue, results);
3598 assert(results.size() == currentCount &&
3599 "failed type conversion should not change results");
3602 if (!isCacheable(indexedConverter.index()))
3605 cacheWriteLock.lock();
3606 if (!succeeded(*
result)) {
3607 assert(results.size() == currentCount &&
3608 "failed type conversion should not change results");
3609 cachedDirectConversions.try_emplace(t,
nullptr);
3612 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3613 if (newTypes.size() == 1)
3614 cachedDirectConversions.try_emplace(t, newTypes.front());
3616 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3622LogicalResult TypeConverter::convertType(Type t,
3623 SmallVectorImpl<Type> &results)
const {
3624 return convertTypeImpl(t, results);
3627LogicalResult TypeConverter::convertType(Value v,
3628 SmallVectorImpl<Type> &results)
const {
3629 return convertTypeImpl(v, results);
3632Type TypeConverter::convertType(Type t)
const {
3634 SmallVector<Type, 1> results;
3635 if (
failed(convertType(t, results)))
3639 return results.size() == 1 ? results.front() :
nullptr;
3642Type TypeConverter::convertType(Value v)
const {
3644 SmallVector<Type, 1> results;
3645 if (
failed(convertType(v, results)))
3649 return results.size() == 1 ? results.front() :
nullptr;
3653TypeConverter::convertTypes(
TypeRange types,
3654 SmallVectorImpl<Type> &results)
const {
3655 for (Type type : types)
3656 if (
failed(convertType(type, results)))
3662TypeConverter::convertTypes(
ValueRange values,
3663 SmallVectorImpl<Type> &results)
const {
3664 for (Value value : values)
3665 if (
failed(convertType(value, results)))
3670bool TypeConverter::isLegal(Type type)
const {
3671 return convertType(type) == type;
3674bool TypeConverter::isLegal(Value value)
const {
3675 return convertType(value) == value.
getType();
3678bool TypeConverter::isLegal(Operation *op)
const {
3682bool TypeConverter::isLegal(Region *region)
const {
3683 return llvm::all_of(
3687bool TypeConverter::isSignatureLegal(FunctionType ty)
const {
3688 if (!isLegal(ty.getInputs()))
3690 if (!isLegal(ty.getResults()))
3696TypeConverter::convertSignatureArg(
unsigned inputNo, Type type,
3697 SignatureConversion &
result)
const {
3699 SmallVector<Type, 1> convertedTypes;
3700 if (
failed(convertType(type, convertedTypes)))
3704 if (convertedTypes.empty())
3708 result.addInputs(inputNo, convertedTypes);
3712TypeConverter::convertSignatureArgs(
TypeRange types,
3713 SignatureConversion &
result,
3714 unsigned origInputOffset)
const {
3715 for (
unsigned i = 0, e = types.size(); i != e; ++i)
3716 if (
failed(convertSignatureArg(origInputOffset + i, types[i],
result)))
3721TypeConverter::convertSignatureArg(
unsigned inputNo, Value value,
3722 SignatureConversion &
result)
const {
3724 SmallVector<Type, 1> convertedTypes;
3725 if (
failed(convertType(value, convertedTypes)))
3729 if (convertedTypes.empty())
3733 result.addInputs(inputNo, convertedTypes);
3737TypeConverter::convertSignatureArgs(
ValueRange values,
3738 SignatureConversion &
result,
3739 unsigned origInputOffset)
const {
3740 for (
unsigned i = 0, e = values.size(); i != e; ++i)
3741 if (
failed(convertSignatureArg(origInputOffset + i, values[i],
result)))
3746Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
3747 Location loc, Type resultType,
3749 for (
const SourceMaterializationCallbackFn &fn :
3750 llvm::reverse(sourceMaterializations))
3751 if (Value
result = fn(builder, resultType, inputs, loc))
3756Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
3757 Location loc, Type resultType,
3759 Type originalType)
const {
3760 SmallVector<Value>
result = materializeTargetConversion(
3761 builder, loc,
TypeRange(resultType), inputs, originalType);
3764 assert(
result.size() == 1 &&
"expected single result");
3768SmallVector<Value> TypeConverter::materializeTargetConversion(
3770 Type originalType)
const {
3771 for (
const TargetMaterializationCallbackFn &fn :
3772 llvm::reverse(targetMaterializations)) {
3773 SmallVector<Value>
result =
3774 fn(builder, resultTypes, inputs, loc, originalType);
3778 "callback produced incorrect number of values or values with "
3785std::optional<TypeConverter::SignatureConversion>
3786TypeConverter::convertBlockSignature(
Block *block)
const {
3789 return std::nullopt;
3796TypeConverter::AttributeConversionResult
3797TypeConverter::AttributeConversionResult::result(Attribute attr) {
3798 return AttributeConversionResult(attr, resultTag);
3801TypeConverter::AttributeConversionResult
3802TypeConverter::AttributeConversionResult::na() {
3803 return AttributeConversionResult(
nullptr, naTag);
3806TypeConverter::AttributeConversionResult
3807TypeConverter::AttributeConversionResult::abort() {
3808 return AttributeConversionResult(
nullptr, abortTag);
3811bool TypeConverter::AttributeConversionResult::hasResult()
const {
3812 return impl.getInt() == resultTag;
3815bool TypeConverter::AttributeConversionResult::isNa()
const {
3816 return impl.getInt() == naTag;
3819bool TypeConverter::AttributeConversionResult::isAbort()
const {
3820 return impl.getInt() == abortTag;
3823Attribute TypeConverter::AttributeConversionResult::getResult()
const {
3824 assert(hasResult() &&
"Cannot get result from N/A or abort");
3825 return impl.getPointer();
3828std::optional<Attribute>
3829TypeConverter::convertTypeAttribute(Type type, Attribute attr)
const {
3830 for (
const TypeAttributeConversionCallbackFn &fn :
3831 llvm::reverse(typeAttributeConversions)) {
3832 AttributeConversionResult res = fn(type, attr);
3833 if (res.hasResult())
3834 return res.getResult();
3836 return std::nullopt;
3838 return std::nullopt;
3847 ConversionPatternRewriter &rewriter) {
3848 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3853 TypeConverter::SignatureConversion funcConversion(type.getNumInputs());
3855 if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
3857 failed(typeConverter.convertTypes(type.getResults(), newResults)))
3866 if (!funcOp.getFunctionBody().empty()) {
3867 Block *entryBlock = &funcOp.getFunctionBody().
front();
3869 unsigned numFuncTypeInputs = type.getNumInputs();
3870 TypeConverter::SignatureConversion blockConversion(numEntryBlockArgs);
3872 if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
3877 for (
unsigned i = numFuncTypeInputs; i < numEntryBlockArgs; ++i)
3879 rewriter.applySignatureConversion(entryBlock, blockConversion,
3883 auto newType = FunctionType::get(
3884 rewriter.getContext(), funcConversion.getConvertedTypes(), newResults);
3886 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3895struct FunctionOpInterfaceSignatureConversion :
public ConversionPattern {
3896 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3898 const TypeConverter &converter,
3899 PatternBenefit benefit)
3900 : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3903 matchAndRewrite(Operation *op, ArrayRef<Value> ,
3904 ConversionPatternRewriter &rewriter)
const override {
3905 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3910struct AnyFunctionOpInterfaceSignatureConversion
3911 :
public OpInterfaceConversionPattern<FunctionOpInterface> {
3912 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3915 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> ,
3916 ConversionPatternRewriter &rewriter)
const override {
3922FailureOr<Operation *>
3923mlir::convertOpResultTypes(Operation *op,
ValueRange operands,
3924 const TypeConverter &converter,
3925 ConversionPatternRewriter &rewriter) {
3926 assert(op &&
"Invalid op");
3927 Location loc = op->
getLoc();
3928 if (converter.isLegal(op))
3929 return rewriter.notifyMatchFailure(loc,
"op already legal");
3931 OperationState newOp(loc, op->
getName());
3932 newOp.addOperands(operands);
3934 SmallVector<Type> newResultTypes;
3936 return rewriter.notifyMatchFailure(loc,
"couldn't convert return types");
3938 newOp.addTypes(newResultTypes);
3939 newOp.addAttributes(op->
getAttrs());
3940 return rewriter.create(newOp);
3943void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3944 StringRef functionLikeOpName, RewritePatternSet &patterns,
3945 const TypeConverter &converter, PatternBenefit benefit) {
3946 patterns.
add<FunctionOpInterfaceSignatureConversion>(
3947 functionLikeOpName, patterns.
getContext(), converter, benefit);
3950void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3951 RewritePatternSet &patterns,
const TypeConverter &converter,
3952 PatternBenefit benefit) {
3953 patterns.
add<AnyFunctionOpInterfaceSignatureConversion>(
3961void ConversionTarget::setOpAction(OperationName op,
3962 LegalizationAction action) {
3963 legalOperations[op].action = action;
3966void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3967 LegalizationAction action) {
3968 for (StringRef dialect : dialectNames)
3969 legalDialects[dialect] = action;
3972auto ConversionTarget::getOpAction(OperationName op)
const
3973 -> std::optional<LegalizationAction> {
3974 std::optional<LegalizationInfo> info = getOpInfo(op);
3975 return info ? info->action : std::optional<LegalizationAction>();
3978auto ConversionTarget::isLegal(Operation *op)
const
3979 -> std::optional<LegalOpDetails> {
3980 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3982 return std::nullopt;
3985 auto isOpLegal = [&] {
3987 if (info->action == LegalizationAction::Dynamic) {
3988 std::optional<bool>
result = info->legalityFn(op);
3994 return info->action == LegalizationAction::Legal;
3997 return std::nullopt;
4000 LegalOpDetails legalityDetails;
4001 if (info->isRecursivelyLegal) {
4002 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
4003 if (legalityFnIt != opRecursiveLegalityFns.end()) {
4004 legalityDetails.isRecursivelyLegal =
4005 legalityFnIt->second(op).value_or(
true);
4007 legalityDetails.isRecursivelyLegal =
true;
4010 return legalityDetails;
4013bool ConversionTarget::isIllegal(Operation *op)
const {
4014 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
4018 if (info->action == LegalizationAction::Dynamic) {
4019 std::optional<bool>
result = info->legalityFn(op);
4026 return info->action == LegalizationAction::Illegal;
4030 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
4031 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
4035 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
4037 if (std::optional<bool>
result = newCl(op))
4045void ConversionTarget::setLegalityCallback(
4046 OperationName name,
const DynamicLegalityCallbackFn &callback) {
4047 assert(callback &&
"expected valid legality callback");
4048 auto *infoIt = legalOperations.find(name);
4049 assert(infoIt != legalOperations.end() &&
4050 infoIt->second.action == LegalizationAction::Dynamic &&
4051 "expected operation to already be marked as dynamically legal");
4052 infoIt->second.legalityFn =
4056void ConversionTarget::markOpRecursivelyLegal(
4057 OperationName name,
const DynamicLegalityCallbackFn &callback) {
4058 auto *infoIt = legalOperations.find(name);
4059 assert(infoIt != legalOperations.end() &&
4060 infoIt->second.action != LegalizationAction::Illegal &&
4061 "expected operation to already be marked as legal");
4062 infoIt->second.isRecursivelyLegal =
true;
4065 std::move(opRecursiveLegalityFns[name]), callback);
4067 opRecursiveLegalityFns.erase(name);
4070void ConversionTarget::setLegalityCallback(
4071 ArrayRef<StringRef> dialects,
const DynamicLegalityCallbackFn &callback) {
4072 assert(callback &&
"expected valid legality callback");
4073 for (StringRef dialect : dialects)
4075 std::move(dialectLegalityFns[dialect]), callback);
4078void ConversionTarget::setLegalityCallback(
4079 const DynamicLegalityCallbackFn &callback) {
4080 assert(callback &&
"expected valid legality callback");
4084auto ConversionTarget::getOpInfo(OperationName op)
const
4085 -> std::optional<LegalizationInfo> {
4087 const auto *it = legalOperations.find(op);
4088 if (it != legalOperations.end())
4091 auto dialectIt = legalDialects.find(op.getDialectNamespace());
4092 if (dialectIt != legalDialects.end()) {
4093 DynamicLegalityCallbackFn callback;
4094 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4095 if (dialectFn != dialectLegalityFns.end())
4096 callback = dialectFn->second;
4097 return LegalizationInfo{dialectIt->second,
false,
4101 if (unknownLegalityFn)
4102 return LegalizationInfo{LegalizationAction::Dynamic,
4103 false, unknownLegalityFn};
4104 return std::nullopt;
4107#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4112void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
4113 auto &rewriterImpl =
4114 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4118void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
4119 auto &rewriterImpl =
4120 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4126static FailureOr<SmallVector<Value>>
4127pdllConvertValues(ConversionPatternRewriter &rewriter,
ValueRange values) {
4128 SmallVector<Value> mappedValues;
4129 if (
failed(rewriter.getRemappedValues(values, mappedValues)))
4131 return std::move(mappedValues);
4134void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
4137 [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
4138 auto results = pdllConvertValues(
4139 static_cast<ConversionPatternRewriter &
>(rewriter), value);
4142 return results->front();
4145 "convertValues", [](PatternRewriter &rewriter,
ValueRange values) {
4146 return pdllConvertValues(
4147 static_cast<ConversionPatternRewriter &
>(rewriter), values);
4151 [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
4152 auto &rewriterImpl =
4153 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4154 if (
const TypeConverter *converter =
4156 if (Type newType = converter->convertType(type))
4164 [](PatternRewriter &rewriter,
4165 TypeRange types) -> FailureOr<SmallVector<Type>> {
4166 auto &rewriterImpl =
4167 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4170 return SmallVector<Type>(types);
4172 SmallVector<Type> remappedTypes;
4173 if (
failed(converter->convertTypes(types, remappedTypes)))
4175 return std::move(remappedTypes);
4190 static constexpr StringLiteral
tag =
"apply-conversion";
4191 static constexpr StringLiteral
desc =
4192 "Encapsulate the application of a dialect conversion";
4200 ConversionConfig config,
4201 OpConversionMode mode) {
4205 LogicalResult status =
success();
4210 patterns, config, mode);
4221LogicalResult mlir::applyPartialConversion(
4222 ArrayRef<Operation *> ops,
const ConversionTarget &
target,
4223 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4225 OpConversionMode::Partial);
4228mlir::applyPartialConversion(Operation *op,
const ConversionTarget &
target,
4229 const FrozenRewritePatternSet &patterns,
4230 ConversionConfig config) {
4231 return applyPartialConversion(llvm::ArrayRef(op),
target, patterns, config);
4238LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
4239 const ConversionTarget &
target,
4240 const FrozenRewritePatternSet &patterns,
4241 ConversionConfig config) {
4244LogicalResult mlir::applyFullConversion(Operation *op,
4245 const ConversionTarget &
target,
4246 const FrozenRewritePatternSet &patterns,
4247 ConversionConfig config) {
4248 return applyFullConversion(llvm::ArrayRef(op),
target, patterns, config);
4265 "expected top-level op to be isolated from above");
4268 "expected ops to have a common ancestor");
4277 for (
Operation *op : ops.drop_front()) {
4281 assert(commonAncestor &&
4282 "expected to find a common isolated from above ancestor");
4286 return commonAncestor;
4289LogicalResult mlir::applyAnalysisConversion(
4290 ArrayRef<Operation *> ops, ConversionTarget &
target,
4291 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4293 if (config.legalizableOps)
4294 assert(config.legalizableOps->empty() &&
"expected empty set");
4300 Operation *clonedAncestor = commonAncestor->
clone(mapping);
4304 inverseOperationMap[it.second] = it.first;
4307 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4308 ops, [&](Operation *op) {
return mapping.
lookup(op); });
4310 OpConversionMode::Analysis);
4314 if (config.legalizableOps) {
4316 for (Operation *op : *config.legalizableOps)
4317 originalLegalizableOps.insert(inverseOperationMap[op]);
4318 *config.legalizableOps = std::move(originalLegalizableOps);
4322 clonedAncestor->
erase();
4327mlir::applyAnalysisConversion(Operation *op, ConversionTarget &
target,
4328 const FrozenRewritePatternSet &patterns,
4329 ConversionConfig config) {
4330 return applyAnalysisConversion(llvm::ArrayRef(op),
target, patterns, config);
static void setInsertionPointAfter(OpBuilder &b, Value value)
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
SmallVector< Value, 2 > ValueVector
A vector of SSA values, optimized for the most common case of one or two values.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl, function_ref< bool(OpOperand &)> functor=nullptr)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value > > &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
static std::string diag(const llvm::Value &value)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
This is the type of Action that is dispatched when a conversion is applied.
tracing::ActionImpl< ApplyConversionAction > Base
static constexpr StringLiteral desc
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
static constexpr StringLiteral tag
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
StringAttr getStringAttr(const Twine &bytes)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
auto lookup(T from) const
Lookup a mapped value within the map.
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Block::iterator getPoint() const
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
type_range getTypes() const
void destroyOpProperties(PropertyRef properties) const
This hooks destroy the op properties.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
TypeID getOpPropertiesTypeID() const
Return the TypeID of the op properties.
Operation is the basic unit of execution within MLIR.
PropertyRef getPropertiesStorage()
Return a generic (but typed) reference to the property type storage.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
void copyProperties(PropertyRef rhs)
Copy properties from an existing other properties object.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
CRTP Implementation of an action.
ArrayRef< IRUnit > irUnits
Set of IR units (operations, regions, blocks, values) that are associated with this action.
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
llvm::SetVector< T, Vector, Set, N > SetVector
static void reconcileUnrealizedCasts(const llvm::MapVector< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
This iterator enumerates elements according to their dominance relationship.
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
const ConversionTarget & getTarget()
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, bool isRecursiveLegalization=false)
LogicalResult convert(Operation *op, bool isRecursiveLegalization=false)
Converts a single operation.
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, Fn onFailure, bool isRecursiveLegalization=false)
Legalizes the given operations (and their nested operations) to the conversion target.
LogicalResult applyConversion(ArrayRef< Operation * > ops)
Applies the conversion to the given operations (and their nested operations).
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
bool wasErased(void *ptr) const
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
llvm::impl::raw_ldbg_ostream os
A raw output stream used to prefix the debug log.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
void replaceValueUses(Value from, ValueRange to, const TypeConverter *converter, function_ref< bool(OpOperand &)> functor=nullptr)
Replace the uses of the given value with the given values.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
llvm::MapVector< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config, OperationConverter &opConverter)
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
ConversionValueMapping mapping
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
void replaceOp(Operation *op, SmallVector< SmallVector< Value > > &&newValues)
Replace the results of the given operation with the given values and erase the operation.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
OperationConverter & opConverter
The operation converter to use for recursive legalization.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.