10#include "mlir/Config/mlir-config.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/ErrorHandling.h"
25#include "llvm/Support/FormatVariadic.h"
26#include "llvm/Support/SaveAndRestore.h"
27#include "llvm/Support/ScopedPrinter.h"
34#define DEBUG_TYPE "dialect-conversion"
37template <
typename... Args>
38static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
41 os.startLine() <<
"} -> SUCCESS";
43 os.getOStream() <<
" : "
44 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
45 os.getOStream() <<
"\n";
50template <
typename... Args>
51static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
54 os.startLine() <<
"} -> FAILURE : "
55 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
65 if (
OpResult inputRes = dyn_cast<OpResult>(value))
66 insertPt = ++inputRes.getOwner()->getIterator();
73 assert(!vals.empty() &&
"expected at least one value");
76 for (
Value v : vals.drop_front()) {
90 assert(dom &&
"unable to find valid insertion point");
98enum OpConversionMode {
124struct ValueVectorMapInfo {
127 static ::llvm::hash_code getHashValue(
const ValueVector &val) {
128 return ::llvm::hash_combine_range(val);
137struct ConversionValueMapping {
140 bool isMappedTo(
Value value)
const {
return mappedTo.contains(value); }
145 template <
typename T>
146 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
149 template <
typename OldVal,
typename NewVal>
150 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
151 map(OldVal &&oldVal, NewVal &&newVal) {
155 assert(next != oldVal &&
"inserting cyclic mapping");
156 auto it = mapping.find(next);
157 if (it == mapping.end())
162 mappedTo.insert_range(newVal);
164 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
168 template <
typename OldVal,
typename NewVal>
169 std::enable_if_t<!IsValueVector<OldVal>::value ||
170 !IsValueVector<NewVal>::value>
171 map(OldVal &&oldVal, NewVal &&newVal) {
172 if constexpr (IsValueVector<OldVal>{}) {
173 map(std::forward<OldVal>(oldVal),
ValueVector{newVal});
174 }
else if constexpr (IsValueVector<NewVal>{}) {
175 map(
ValueVector{oldVal}, std::forward<NewVal>(newVal));
186 void erase(
const ValueVector &value) { mapping.erase(value); }
206 assert(!values.empty() &&
"expected non-empty value vector");
207 Operation *op = values.front().getDefiningOp();
208 for (
Value v : llvm::drop_begin(values)) {
209 if (v.getDefiningOp() != op)
219 assert(!values.empty() &&
"expected non-empty value vector");
225 auto it = mapping.find(from);
226 if (it == mapping.end()) {
239struct RewriterState {
240 RewriterState(
unsigned numRewrites,
unsigned numIgnoredOperations,
241 unsigned numReplacedOps)
242 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
243 numReplacedOps(numReplacedOps) {}
246 unsigned numRewrites;
249 unsigned numIgnoredOperations;
252 unsigned numReplacedOps;
// Forward declaration of the Operation overload. It is needed because the
// Block overload (which visits each op of the block) and the Operation overload
// (which visits blocks) call each other recursively.
259static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
262static void notifyIRErased(RewriterBase::Listener *listener,
Block &
b) {
263 for (Operation &op :
b)
264 notifyIRErased(listener, op);
270static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
273 notifyIRErased(listener,
b);
303 UnresolvedMaterialization,
308 virtual ~IRRewrite() =
default;
311 virtual void rollback() = 0;
325 virtual void commit(RewriterBase &rewriter) {}
328 virtual void cleanup(RewriterBase &rewriter) {}
330 Kind getKind()
const {
return kind; }
332 static bool classof(
const IRRewrite *
rewrite) {
return true; }
335 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
336 : kind(kind), rewriterImpl(rewriterImpl) {}
338 const ConversionConfig &getConfig()
const;
341 ConversionPatternRewriterImpl &rewriterImpl;
345class BlockRewrite :
public IRRewrite {
348 Block *getBlock()
const {
return block; }
350 static bool classof(
const IRRewrite *
rewrite) {
351 return rewrite->getKind() >= Kind::CreateBlock &&
352 rewrite->getKind() <= Kind::BlockTypeConversion;
356 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
358 : IRRewrite(kind, rewriterImpl), block(block) {}
365class ValueRewrite :
public IRRewrite {
368 Value getValue()
const {
return value; }
370 static bool classof(
const IRRewrite *
rewrite) {
371 return rewrite->getKind() == Kind::ReplaceValue;
375 ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
377 : IRRewrite(kind, rewriterImpl), value(value) {}
386class CreateBlockRewrite :
public BlockRewrite {
388 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block)
389 : BlockRewrite(
Kind::CreateBlock, rewriterImpl, block) {}
391 static bool classof(
const IRRewrite *
rewrite) {
392 return rewrite->getKind() == Kind::CreateBlock;
395 void commit(RewriterBase &rewriter)
override {
401 void rollback()
override {
404 auto &blockOps = block->getOperations();
405 while (!blockOps.empty())
406 blockOps.remove(blockOps.begin());
407 block->dropAllUses();
408 if (block->getParent())
419class EraseBlockRewrite :
public BlockRewrite {
421 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block)
422 : BlockRewrite(
Kind::EraseBlock, rewriterImpl, block),
423 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
425 static bool classof(
const IRRewrite *
rewrite) {
426 return rewrite->getKind() == Kind::EraseBlock;
429 ~EraseBlockRewrite()
override {
431 "rewrite was neither rolled back nor committed/cleaned up");
434 void rollback()
override {
437 assert(block &&
"expected block");
442 blockList.insert(before, block);
446 void commit(RewriterBase &rewriter)
override {
447 assert(block &&
"expected block");
451 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
452 notifyIRErased(listener, *block);
455 void cleanup(RewriterBase &rewriter)
override {
457 for (
auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
459 assert(block->empty() &&
"expected empty block");
462 block->dropAllDefinedValueUses();
473 Block *insertBeforeBlock;
479class InlineBlockRewrite :
public BlockRewrite {
481 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block,
483 : BlockRewrite(
Kind::InlineBlock, rewriterImpl, block),
484 sourceBlock(sourceBlock),
485 firstInlinedInst(sourceBlock->empty() ?
nullptr
486 : &sourceBlock->front()),
487 lastInlinedInst(sourceBlock->empty() ?
nullptr : &sourceBlock->back()) {
493 assert(!getConfig().listener &&
494 "InlineBlockRewrite not supported if listener is attached");
497 static bool classof(
const IRRewrite *
rewrite) {
498 return rewrite->getKind() == Kind::InlineBlock;
501 void rollback()
override {
504 if (firstInlinedInst) {
505 assert(lastInlinedInst &&
"expected operation");
518 Operation *firstInlinedInst;
521 Operation *lastInlinedInst;
525class MoveBlockRewrite :
public BlockRewrite {
527 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block,
529 : BlockRewrite(
Kind::MoveBlock, rewriterImpl, block),
530 region(previousRegion),
531 insertBeforeBlock(previousIt == previousRegion->end() ?
nullptr
534 static bool classof(
const IRRewrite *
rewrite) {
535 return rewrite->getKind() == Kind::MoveBlock;
538 void commit(RewriterBase &rewriter)
override {
548 void rollback()
override {
552 if (Region *currentParent = block->
getParent()) {
554 region->getBlocks().splice(before, currentParent->getBlocks(), block);
558 region->
getBlocks().insert(before, block);
567 Block *insertBeforeBlock;
571class BlockTypeConversionRewrite :
public BlockRewrite {
573 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
575 : BlockRewrite(
Kind::BlockTypeConversion, rewriterImpl, origBlock),
576 newBlock(newBlock) {}
578 static bool classof(
const IRRewrite *
rewrite) {
579 return rewrite->getKind() == Kind::BlockTypeConversion;
582 Block *getOrigBlock()
const {
return block; }
584 Block *getNewBlock()
const {
return newBlock; }
586 void commit(RewriterBase &rewriter)
override;
588 void rollback()
override;
598class ReplaceValueRewrite :
public ValueRewrite {
600 ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
601 const TypeConverter *converter)
602 : ValueRewrite(
Kind::ReplaceValue, rewriterImpl, value),
603 converter(converter) {}
605 static bool classof(
const IRRewrite *
rewrite) {
606 return rewrite->getKind() == Kind::ReplaceValue;
609 void commit(RewriterBase &rewriter)
override;
611 void rollback()
override;
615 const TypeConverter *converter;
619class OperationRewrite :
public IRRewrite {
622 Operation *getOperation()
const {
return op; }
624 static bool classof(
const IRRewrite *
rewrite) {
625 return rewrite->getKind() >= Kind::MoveOperation &&
626 rewrite->getKind() <= Kind::UnresolvedMaterialization;
630 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
632 : IRRewrite(kind, rewriterImpl), op(op) {}
639class MoveOperationRewrite :
public OperationRewrite {
641 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
642 Operation *op, OpBuilder::InsertPoint previous)
643 : OperationRewrite(
Kind::MoveOperation, rewriterImpl, op),
644 block(previous.getBlock()),
645 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
647 : &*previous.getPoint()) {}
649 static bool classof(
const IRRewrite *
rewrite) {
650 return rewrite->getKind() == Kind::MoveOperation;
653 void commit(RewriterBase &rewriter)
override {
659 op, OpBuilder::InsertPoint(block,
664 void rollback()
override {
677 Operation *insertBeforeOp;
682class ModifyOperationRewrite :
public OperationRewrite {
684 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
686 : OperationRewrite(
Kind::ModifyOperation, rewriterImpl, op),
687 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
688 operands(op->operand_begin(), op->operand_end()),
689 successors(op->successor_begin(), op->successor_end()) {
692 propertiesStorage = operator new(op->getPropertiesStorageSize());
693 OpaqueProperties propCopy(propertiesStorage);
694 name.initOpProperties(propCopy, prop);
698 static bool classof(
const IRRewrite *
rewrite) {
699 return rewrite->getKind() == Kind::ModifyOperation;
702 ~ModifyOperationRewrite()
override {
703 assert(!propertiesStorage &&
704 "rewrite was neither committed nor rolled back");
707 void commit(RewriterBase &rewriter)
override {
710 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
713 if (propertiesStorage) {
714 OpaqueProperties propCopy(propertiesStorage);
718 operator delete(propertiesStorage);
719 propertiesStorage =
nullptr;
723 void rollback()
override {
727 for (
const auto &it : llvm::enumerate(successors))
729 if (propertiesStorage) {
730 OpaqueProperties propCopy(propertiesStorage);
733 operator delete(propertiesStorage);
734 propertiesStorage =
nullptr;
741 DictionaryAttr attrs;
742 SmallVector<Value, 8> operands;
743 SmallVector<Block *, 2> successors;
744 void *propertiesStorage =
nullptr;
751class ReplaceOperationRewrite :
public OperationRewrite {
753 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
754 Operation *op,
const TypeConverter *converter)
755 : OperationRewrite(
Kind::ReplaceOperation, rewriterImpl, op),
756 converter(converter) {}
758 static bool classof(
const IRRewrite *
rewrite) {
759 return rewrite->getKind() == Kind::ReplaceOperation;
762 void commit(RewriterBase &rewriter)
override;
764 void rollback()
override;
766 void cleanup(RewriterBase &rewriter)
override;
771 const TypeConverter *converter;
774class CreateOperationRewrite :
public OperationRewrite {
776 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
778 : OperationRewrite(
Kind::CreateOperation, rewriterImpl, op) {}
780 static bool classof(
const IRRewrite *
rewrite) {
781 return rewrite->getKind() == Kind::CreateOperation;
784 void commit(RewriterBase &rewriter)
override {
790 void rollback()
override;
794enum MaterializationKind {
805class UnresolvedMaterializationInfo {
807 UnresolvedMaterializationInfo() =
default;
808 UnresolvedMaterializationInfo(
const TypeConverter *converter,
809 MaterializationKind kind, Type originalType)
810 : converterAndKind(converter, kind), originalType(originalType) {}
813 const TypeConverter *getConverter()
const {
814 return converterAndKind.getPointer();
818 MaterializationKind getMaterializationKind()
const {
819 return converterAndKind.getInt();
823 Type getOriginalType()
const {
return originalType; }
828 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
839class UnresolvedMaterializationRewrite :
public OperationRewrite {
841 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
842 UnrealizedConversionCastOp op,
844 : OperationRewrite(
Kind::UnresolvedMaterialization, rewriterImpl, op),
845 mappedValues(std::move(mappedValues)) {}
847 static bool classof(
const IRRewrite *
rewrite) {
848 return rewrite->getKind() == Kind::UnresolvedMaterialization;
851 void rollback()
override;
853 UnrealizedConversionCastOp getOperation()
const {
854 return cast<UnrealizedConversionCastOp>(op);
864#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
867template <
typename RewriteTy,
typename R>
868static bool hasRewrite(R &&rewrites, Operation *op) {
869 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
870 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
871 return rewriteTy && rewriteTy->getOperation() == op;
877template <
typename RewriteTy,
typename R>
878static bool hasRewrite(R &&rewrites,
Block *block) {
879 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
880 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
881 return rewriteTy && rewriteTy->getBlock() == block;
893 const ConversionConfig &
config,
903 RewriterState getCurrentState();
907 void applyRewrites();
912 void resetState(RewriterState state, StringRef patternName =
"");
916 template <
typename RewriteTy,
typename... Args>
918 assert(
config.allowPatternRollback &&
"appending rewrites is not allowed");
920 std::make_unique<RewriteTy>(*
this, std::forward<Args>(args)...));
926 void undoRewrites(
unsigned numRewritesToKeep = 0, StringRef patternName =
"");
932 LogicalResult remapValues(StringRef valueDiagTag,
933 std::optional<Location> inputLoc,
ValueRange values,
950 bool skipPureTypeConversions =
false)
const;
964 TypeConverter::SignatureConversion *entryConversion);
972 Block *applySignatureConversion(
974 TypeConverter::SignatureConversion &signatureConversion);
994 void eraseBlock(
Block *block);
1032 Value findOrBuildReplacementValue(
Value value,
1040 void notifyOperationInserted(
Operation *op,
1044 void notifyBlockInserted(
Block *block,
Region *previous,
1063 std::function<
void(
Operation *)> opErasedCallback =
nullptr)
1065 opErasedCallback(std::move(opErasedCallback)) {}
1079 assert(block->empty() &&
"expected empty block");
1080 block->dropAllDefinedValueUses();
1088 if (opErasedCallback)
1089 opErasedCallback(op);
1197const ConversionConfig &IRRewrite::getConfig()
const {
1198 return rewriterImpl.
config;
1201void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1205 if (
auto *listener =
1206 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
1207 for (Operation *op : getNewBlock()->getUsers())
1211void BlockTypeConversionRewrite::rollback() {
1212 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1219 if (isa<BlockArgument>(repl)) {
1259 result &= functor(operand);
1264void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1271void ReplaceValueRewrite::rollback() {
1272 rewriterImpl.
mapping.erase({value});
1278void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1280 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener());
1283 SmallVector<Value> replacements =
1285 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1293 for (
auto [
result, newValue] :
1294 llvm::zip_equal(op->
getResults(), replacements))
1300 if (getConfig().unlegalizedOps)
1301 getConfig().unlegalizedOps->erase(op);
1305 notifyIRErased(listener, *op);
1312void ReplaceOperationRewrite::rollback() {
1317void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1321void CreateOperationRewrite::rollback() {
1323 while (!region.getBlocks().empty())
1324 region.getBlocks().remove(region.getBlocks().begin());
1330void UnresolvedMaterializationRewrite::rollback() {
1331 if (!mappedValues.empty())
1332 rewriterImpl.
mapping.erase(mappedValues);
1343 for (
size_t i = 0; i <
rewrites.size(); ++i)
1349 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1350 unresolvedMaterializations.erase(castOp);
1353 rewrite->cleanup(eraseRewriter);
1361 Value from,
TypeRange desiredTypes,
bool skipPureTypeConversions)
const {
1364 assert(!values.empty() &&
"expected non-empty value vector");
1368 if (
config.allowPatternRollback)
1369 return mapping.lookup(values);
1376 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1381 if (castOp.getOutputs() != values)
1383 return castOp.getInputs();
1392 for (
Value v : values) {
1395 llvm::append_range(next, r);
1400 if (next != values) {
1429 if (skipPureTypeConversions) {
1432 match &= !pureConversion;
1435 if (!pureConversion)
1436 lastNonMaterialization = current;
1439 desiredValue = current;
1445 current = std::move(next);
1450 if (!desiredTypes.empty())
1451 return desiredValue;
1452 if (skipPureTypeConversions)
1453 return lastNonMaterialization;
1472 StringRef patternName) {
1477 while (
ignoredOps.size() != state.numIgnoredOperations)
1480 while (
replacedOps.size() != state.numReplacedOps)
1485 StringRef patternName) {
1487 llvm::reverse(llvm::drop_begin(
rewrites, numRewritesToKeep)))
1489 rewrites.resize(numRewritesToKeep);
1493 StringRef valueDiagTag, std::optional<Location> inputLoc,
ValueRange values,
1495 remapped.reserve(llvm::size(values));
1497 for (
const auto &it : llvm::enumerate(values)) {
1498 Value operand = it.value();
1517 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1518 << it.index() <<
", type was " << origType;
1523 if (legalTypes.empty()) {
1524 remapped.push_back({});
1533 remapped.push_back(std::move(repl));
1542 repl, repl, legalTypes,
1544 remapped.push_back(castValues);
1565 TypeConverter::SignatureConversion *entryConversion) {
1567 if (region->
empty())
1572 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1574 std::optional<TypeConverter::SignatureConversion> conversion =
1575 converter.convertBlockSignature(&block);
1584 if (entryConversion)
1587 std::optional<TypeConverter::SignatureConversion> conversion =
1588 converter.convertBlockSignature(®ion->
front());
1596 TypeConverter::SignatureConversion &signatureConversion) {
1597#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1599 if (hasRewrite<BlockTypeConversionRewrite>(
rewrites, block))
1600 llvm::reportFatalInternalError(
"block was already converted");
1607 auto convertedTypes = signatureConversion.getConvertedTypes();
1614 for (
unsigned i = 0; i < origArgCount; ++i) {
1615 auto inputMap = signatureConversion.getInputMapping(i);
1616 if (!inputMap || inputMap->replacedWithValues())
1619 for (
unsigned j = 0;
j < inputMap->size; ++
j)
1620 newLocs[inputMap->inputNo +
j] = origLoc;
1627 convertedTypes, newLocs);
1635 bool fastPath = !
config.listener;
1637 if (
config.allowPatternRollback)
1641 while (!block->
empty())
1648 for (
unsigned i = 0; i != origArgCount; ++i) {
1652 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1653 signatureConversion.getInputMapping(i);
1661 MaterializationKind::Source,
1665 origArgType,
Type(), converter,
1672 if (inputMap->replacedWithValues()) {
1674 assert(inputMap->size == 0 &&
1675 "invalid to provide a replacement value when the argument isn't "
1683 newBlock->
getArguments().slice(inputMap->inputNo, inputMap->size);
1687 if (
config.allowPatternRollback)
1708 assert((!originalType || kind == MaterializationKind::Target) &&
1709 "original type is valid only for target materializations");
1710 assert(
TypeRange(inputs) != outputTypes &&
1711 "materialization is not necessary");
1715 OpBuilder builder(outputTypes.front().getContext());
1717 UnrealizedConversionCastOp convertOp =
1718 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1719 if (
config.attachDebugMaterializationKind) {
1721 kind == MaterializationKind::Source ?
"source" :
"target";
1722 convertOp->setAttr(
"__kind__", builder.
getStringAttr(kindStr));
1729 UnresolvedMaterializationInfo(converter, kind, originalType);
1730 if (
config.allowPatternRollback) {
1731 if (!valuesToMap.empty())
1732 mapping.map(valuesToMap, convertOp.getResults());
1734 std::move(valuesToMap));
1738 return convertOp.getResults();
1743 assert(
config.allowPatternRollback &&
1744 "this code path is valid only in rollback mode");
1751 return repl.front();
1758 [&](
Operation *op) { return replacedOps.contains(op); }) &&
1783 MaterializationKind::Source, ip, value.
getLoc(),
1799 bool wasDetached = !previous.
isSet();
1801 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"' (" << op
1804 logger.getOStream() <<
" (was detached)";
1805 logger.getOStream() <<
"\n";
1811 "attempting to insert into a block within a replaced/erased op");
1815 config.listener->notifyOperationInserted(op, previous);
1824 if (
config.allowPatternRollback) {
1838 if (
config.allowPatternRollback)
1848 assert(!
impl.config.allowPatternRollback &&
1849 "this code path is valid only in 'no rollback' mode");
1851 for (
auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1854 repls.push_back(
Value());
1861 Value srcMat =
impl.buildUnresolvedMaterialization(
1866 repls.push_back(srcMat);
1872 repls.push_back(to[0]);
1881 Value srcMat =
impl.buildUnresolvedMaterialization(
1884 Type(), converter)[0];
1885 repls.push_back(srcMat);
1894 "incorrect number of replacement values");
1896 logger.startLine() <<
"** Replace : '" << op->
getName() <<
"'(" << op
1904 for (
auto [
result, repls] :
1905 llvm::zip_equal(op->
getResults(), newValues)) {
1907 auto logProlog = [&, repls = repls]() {
1908 logger.startLine() <<
" Note: Replacing op result of type "
1909 << resultType <<
" with value(s) of type (";
1910 llvm::interleaveComma(repls,
logger.getOStream(), [&](
Value v) {
1911 logger.getOStream() << v.getType();
1913 logger.getOStream() <<
")";
1919 logger.getOStream() <<
", but the type converter failed to legalize "
1920 "the original type.\n";
1925 logger.getOStream() <<
", but the legalized type(s) is/are (";
1926 llvm::interleaveComma(convertedTypes,
logger.getOStream(),
1927 [&](
Type t) { logger.getOStream() << t; });
1928 logger.getOStream() <<
")\n";
1934 if (!
config.allowPatternRollback) {
1943 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1949 if (
config.unlegalizedOps)
1950 config.unlegalizedOps->erase(op);
1958 assert(!
ignoredOps.contains(op) &&
"operation was already replaced");
1962 "attempting to replace a value that was already replaced");
1967 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1972 "attempting to replace/erase an unresolved materialization");
1988 logger.startLine() <<
"** Replace Value : '" << from <<
"'";
1989 if (
auto blockArg = dyn_cast<BlockArgument>(from)) {
1991 logger.getOStream() <<
" (in region of '" << parentOp->getName()
1992 <<
"' (" << parentOp <<
")";
1994 logger.getOStream() <<
" (unlinked block)";
1998 logger.getOStream() <<
", conditional replacement";
2002 if (!
config.allowPatternRollback) {
2007 Value repl = repls.front();
2024 "attempting to replace a value that was already replaced");
2026 "attempting to replace a op result that was already replaced");
2031 llvm::reportFatalInternalError(
2032 "conditional value replacement is not supported in rollback mode");
2038 if (!
config.allowPatternRollback) {
2045 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
2051 if (
config.unlegalizedOps)
2052 config.unlegalizedOps->erase(op);
2061 "attempting to erase a block within a replaced/erased op");
2077 bool wasDetached = !previous;
2083 logger.startLine() <<
"** Insert Block into : '" << parent->
getName()
2084 <<
"' (" << parent <<
")";
2087 <<
"** Insert Block into detached Region (nullptr parent op)";
2090 logger.getOStream() <<
" (was detached)";
2091 logger.getOStream() <<
"\n";
2097 "attempting to insert into a region within a replaced/erased op");
2102 config.listener->notifyBlockInserted(block, previous, previousIt);
2106 if (
config.allowPatternRollback) {
2120 if (
config.allowPatternRollback)
2134 reasonCallback(
diag);
2135 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
2136 if (
config.notifyCallback)
2145ConversionPatternRewriter::ConversionPatternRewriter(
2149 *this,
config, opConverter)) {
2150 setListener(
impl.get());
// Destructor defined out-of-line, after the implementation class is visible.
// NOTE(review): `impl` is used via `impl.get()` / `impl->config` in this file.
// It is presumably a std::unique_ptr<detail::ConversionPatternRewriterImpl>,
// which needs the complete Impl type at the point of destruction -- confirm
// against the header.
2153ConversionPatternRewriter::~ConversionPatternRewriter() =
default;
2155const ConversionConfig &ConversionPatternRewriter::getConfig()
const {
2156 return impl->config;
2160 assert(op && newOp &&
"expected non-null op");
2164void ConversionPatternRewriter::replaceOp(Operation *op,
ValueRange newValues) {
2166 "incorrect # of replacement values");
2170 if (getInsertionPoint() == op->getIterator())
2173 SmallVector<SmallVector<Value>> newVals =
2174 llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2175 return v ? SmallVector<Value>{v} : SmallVector<Value>();
2177 impl->replaceOp(op, std::move(newVals));
2180void ConversionPatternRewriter::replaceOpWithMultiple(
2181 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2183 "incorrect # of replacement values");
2187 if (getInsertionPoint() == op->getIterator())
2190 impl->replaceOp(op, std::move(newValues));
2193void ConversionPatternRewriter::eraseOp(Operation *op) {
2195 impl->logger.startLine()
2196 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
2201 if (getInsertionPoint() == op->getIterator())
2204 SmallVector<SmallVector<Value>> nullRepls(op->
getNumResults(), {});
2205 impl->replaceOp(op, std::move(nullRepls));
2208void ConversionPatternRewriter::eraseBlock(
Block *block) {
2209 impl->eraseBlock(block);
2212Block *ConversionPatternRewriter::applySignatureConversion(
2213 Block *block, TypeConverter::SignatureConversion &conversion,
2214 const TypeConverter *converter) {
2215 assert(!impl->wasOpReplaced(block->
getParentOp()) &&
2216 "attempting to apply a signature conversion to a block within a "
2217 "replaced/erased op");
2218 return impl->applySignatureConversion(block, converter, conversion);
2221FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
2222 Region *region,
const TypeConverter &converter,
2223 TypeConverter::SignatureConversion *entryConversion) {
2224 assert(!impl->wasOpReplaced(region->
getParentOp()) &&
2225 "attempting to apply a signature conversion to a block within a "
2226 "replaced/erased op");
2227 return impl->convertRegionTypes(region, converter, entryConversion);
2230void ConversionPatternRewriter::replaceAllUsesWith(Value from,
ValueRange to) {
2231 impl->replaceValueUses(from, to, impl->currentTypeConverter);
2234void ConversionPatternRewriter::replaceUsesWithIf(
2236 bool *allUsesReplaced) {
2237 assert(!allUsesReplaced &&
2238 "allUsesReplaced is not supported in a dialect conversion");
2239 impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
2242Value ConversionPatternRewriter::getRemappedValue(Value key) {
2243 SmallVector<ValueVector> remappedValues;
2244 if (
failed(impl->remapValues(
"value", std::nullopt, key,
2247 assert(remappedValues.front().size() == 1 &&
"1:N conversion not supported");
2248 return remappedValues.front().front();
2252ConversionPatternRewriter::getRemappedValues(
ValueRange keys,
2253 SmallVectorImpl<Value> &results) {
2256 SmallVector<ValueVector> remapped;
2257 if (
failed(impl->remapValues(
"value", std::nullopt, keys,
2260 for (
const auto &values : remapped) {
2261 assert(values.size() == 1 &&
"1:N conversion not supported");
2262 results.push_back(values.front());
2267void ConversionPatternRewriter::inlineBlockBefore(
Block *source,
Block *dest,
2272 "incorrect # of argument replacement values");
2273 assert(!impl->wasOpReplaced(source->
getParentOp()) &&
2274 "attempting to inline a block from a replaced/erased op");
2275 assert(!impl->wasOpReplaced(dest->
getParentOp()) &&
2276 "attempting to inline a block into a replaced/erased op");
2277 auto opIgnored = [&](Operation *op) {
return impl->isOpIgnored(op); };
2280 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
2281 "expected 'source' to have no predecessors");
2290 bool fastPath = !getConfig().listener;
2292 if (fastPath && impl->config.allowPatternRollback)
2293 impl->inlineBlockBefore(source, dest, before);
2296 for (
auto it : llvm::zip(source->
getArguments(), argValues))
2297 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2304 while (!source->
empty())
2305 moveOpBefore(&source->
front(), dest, before);
2310 if (getInsertionBlock() == source)
2311 setInsertionPoint(dest, getInsertionPoint());
2317void ConversionPatternRewriter::startOpModification(Operation *op) {
2318 if (!impl->config.allowPatternRollback) {
2323 assert(!impl->wasOpReplaced(op) &&
2324 "attempting to modify a replaced/erased op");
2326 impl->pendingRootUpdates.insert(op);
2328 impl->appendRewrite<ModifyOperationRewrite>(op);
2331void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
2332 impl->patternModifiedOps.insert(op);
2333 if (!impl->config.allowPatternRollback) {
2335 if (getConfig().listener)
2336 getConfig().listener->notifyOperationModified(op);
2343 assert(!impl->wasOpReplaced(op) &&
2344 "attempting to modify a replaced/erased op");
2345 assert(impl->pendingRootUpdates.erase(op) &&
2346 "operation did not have a pending in-place update");
2350void ConversionPatternRewriter::cancelOpModification(Operation *op) {
2351 if (!impl->config.allowPatternRollback) {
2356 assert(impl->pendingRootUpdates.erase(op) &&
2357 "operation did not have a pending in-place update");
2360 auto it = llvm::find_if(
2361 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &
rewrite) {
2362 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2363 return modifyRewrite && modifyRewrite->getOperation() == op;
2365 assert(it != impl->rewrites.rend() &&
"no root update started on op");
2367 int updateIdx = std::prev(impl->rewrites.rend()) - it;
2368 impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2371detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
2379FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
2380 ArrayRef<ValueRange> operands)
const {
2381 SmallVector<Value> oneToOneOperands;
2382 oneToOneOperands.reserve(operands.size());
2384 if (operand.size() != 1)
2387 oneToOneOperands.push_back(operand.front());
2389 return std::move(oneToOneOperands);
2393ConversionPattern::matchAndRewrite(Operation *op,
2394 PatternRewriter &rewriter)
const {
2395 auto &dialectRewriter =
static_cast<ConversionPatternRewriter &
>(rewriter);
2396 auto &rewriterImpl = dialectRewriter.getImpl();
2400 getTypeConverter());
2403 SmallVector<ValueVector> remapped;
2408 SmallVector<ValueRange> remappedAsRange =
2409 llvm::to_vector_of<ValueRange>(remapped);
2410 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
// List of pattern pointers used by OperationLegalizer when building and
// costing the legalization graph (buildLegalizationGraph /
// computeLegalizationGraphBenefit). The inline capacity is one element.
// NOTE(review): the pointers are presumably non-owning, with the patterns
// owned by the FrozenRewritePatternSet -- confirm.
2419using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2422class OperationLegalizer {
2424 using LegalizationAction = ConversionTarget::LegalizationAction;
2426 OperationLegalizer(ConversionPatternRewriter &rewriter,
2427 const ConversionTarget &targetInfo,
2428 const FrozenRewritePatternSet &
patterns);
2431 bool isIllegal(Operation *op)
const;
2435 LogicalResult legalize(Operation *op);
2438 const ConversionTarget &getTarget() {
return target; }
2442 LogicalResult legalizeWithFold(Operation *op);
2446 LogicalResult legalizeWithPattern(Operation *op);
2450 bool canApplyPattern(Operation *op,
const Pattern &pattern);
2454 legalizePatternResult(Operation *op,
const Pattern &pattern,
2455 const RewriterState &curState,
2472 void buildLegalizationGraph(
2473 LegalizationPatterns &anyOpLegalizerPatterns,
2484 void computeLegalizationGraphBenefit(
2485 LegalizationPatterns &anyOpLegalizerPatterns,
2490 unsigned computeOpLegalizationDepth(
2497 unsigned applyCostModelToPatterns(
2503 SmallPtrSet<const Pattern *, 8> appliedPatterns;
2506 ConversionPatternRewriter &rewriter;
2509 const ConversionTarget &
target;
2512 PatternApplicator applicator;
2516OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2517 const ConversionTarget &targetInfo,
2518 const FrozenRewritePatternSet &
patterns)
2523 LegalizationPatterns anyOpLegalizerPatterns;
2525 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2526 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2529bool OperationLegalizer::isIllegal(Operation *op)
const {
2530 return target.isIllegal(op);
2533LogicalResult OperationLegalizer::legalize(Operation *op) {
2535 const char *logLineComment =
2536 "//===-------------------------------------------===//\n";
2538 auto &logger = rewriter.getImpl().logger;
2542 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2545 logger.getOStream() <<
"\n";
2546 logger.startLine() << logLineComment;
2547 logger.startLine() <<
"Legalizing operation : ";
2552 logger.getOStream() <<
"'" << op->
getName() <<
"' ";
2553 logger.getOStream() <<
"(" << op <<
") {\n";
2558 logger.startLine() << OpWithFlags(op,
2559 OpPrintingFlags().printGenericOpForm())
2566 logSuccess(logger,
"operation marked 'ignored' during conversion");
2567 logger.startLine() << logLineComment;
2573 if (
auto legalityInfo =
target.isLegal(op)) {
2576 logger,
"operation marked legal by the target{0}",
2577 legalityInfo->isRecursivelyLegal
2578 ?
"; NOTE: operation is recursively legal; skipping internals"
2580 logger.startLine() << logLineComment;
2585 if (legalityInfo->isRecursivelyLegal) {
2586 op->
walk([&](Operation *nested) {
2588 rewriter.getImpl().ignoredOps.
insert(nested);
2597 const ConversionConfig &
config = rewriter.getConfig();
2598 if (
config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2599 if (succeeded(legalizeWithFold(op))) {
2602 logger.startLine() << logLineComment;
2609 if (succeeded(legalizeWithPattern(op))) {
2612 logger.startLine() << logLineComment;
2619 if (
config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2620 if (succeeded(legalizeWithFold(op))) {
2623 logger.startLine() << logLineComment;
2630 logFailure(logger,
"no matched legalization pattern");
2631 logger.startLine() << logLineComment;
2638template <
typename T>
2640 T
result = std::move(obj);
2645LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2646 auto &rewriterImpl = rewriter.getImpl();
2648 rewriterImpl.
logger.startLine() <<
"* Fold {\n";
2649 rewriterImpl.
logger.indent();
2654 llvm::scope_exit cleanup([&]() {
2664 SmallVector<Value, 2> replacementValues;
2665 SmallVector<Operation *, 2> newOps;
2668 if (
failed(rewriter.
tryFold(op, replacementValues, &newOps))) {
2677 if (replacementValues.empty())
2678 return legalize(op);
2681 rewriter.
replaceOp(op, replacementValues);
2684 for (Operation *newOp : newOps) {
2685 if (
failed(legalize(newOp))) {
2687 "failed to legalize generated constant '{0}'",
2689 if (!rewriter.getConfig().allowPatternRollback) {
2691 llvm::reportFatalInternalError(
2693 "' folder rollback of IR modifications requested");
2711 auto newOpNames = llvm::map_range(
2713 auto modifiedOpNames = llvm::map_range(
2715 llvm::reportFatalInternalError(
"pattern '" + pattern.
getDebugName() +
2716 "' produced IR that could not be legalized. " +
2717 "new ops: {" + llvm::join(newOpNames,
", ") +
2718 "}, " +
"modified ops: {" +
2719 llvm::join(modifiedOpNames,
", ") +
"}");
2722LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2723 auto &rewriterImpl = rewriter.getImpl();
2724 const ConversionConfig &
config = rewriter.getConfig();
2726#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2728 std::optional<OperationFingerPrint> topLevelFingerPrint;
2729 if (!rewriterImpl.
config.allowPatternRollback) {
2736 topLevelFingerPrint = OperationFingerPrint(checkOp);
2742 rewriterImpl.
logger.startLine()
2743 <<
"WARNING: Multi-threadeding is enabled. Some dialect "
2744 "conversion expensive checks are skipped in multithreading "
2753 auto canApply = [&](
const Pattern &pattern) {
2754 bool canApply = canApplyPattern(op, pattern);
2755 if (canApply &&
config.listener)
2756 config.listener->notifyPatternBegin(pattern, op);
2762 auto onFailure = [&](
const Pattern &pattern) {
2764 if (!rewriterImpl.
config.allowPatternRollback) {
2771#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2773 if (checkOp && topLevelFingerPrint) {
2774 OperationFingerPrint fingerPrintAfterPattern(checkOp);
2775 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2776 llvm::reportFatalInternalError(
2777 "pattern '" + pattern.getDebugName() +
2778 "' returned failure but IR did change");
2786 if (rewriterImpl.
config.notifyCallback) {
2788 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
2795 config.listener->notifyPatternEnd(pattern, failure());
2796 rewriterImpl.
resetState(curState, pattern.getDebugName());
2797 appliedPatterns.erase(&pattern);
2802 auto onSuccess = [&](
const Pattern &pattern) {
2804 if (!rewriterImpl.
config.allowPatternRollback) {
2818 legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
2819 appliedPatterns.erase(&pattern);
2821 if (!rewriterImpl.
config.allowPatternRollback)
2823 rewriterImpl.
resetState(curState, pattern.getDebugName());
2831 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2835bool OperationLegalizer::canApplyPattern(Operation *op,
2836 const Pattern &pattern) {
2838 auto &os = rewriter.getImpl().logger;
2839 os.getOStream() <<
"\n";
2840 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
2842 os.getOStream() <<
")' {\n";
2849 !appliedPatterns.insert(&pattern).second) {
2851 logFailure(rewriter.getImpl().logger,
"pattern was already applied"));
2857LogicalResult OperationLegalizer::legalizePatternResult(
2858 Operation *op,
const Pattern &pattern,
const RewriterState &curState,
2861 [[maybe_unused]]
auto &impl = rewriter.getImpl();
2862 assert(impl.pendingRootUpdates.empty() &&
"dangling root updates");
2864#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2865 if (impl.config.allowPatternRollback) {
2867 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2868 auto replacedRoot = [&] {
2869 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2871 auto updatedRootInPlace = [&] {
2872 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2874 if (!replacedRoot() && !updatedRootInPlace())
2875 llvm::reportFatalInternalError(
2876 "expected pattern to replace the root operation "
2877 "or modify it in place");
2882 if (
failed(legalizePatternRootUpdates(modifiedOps)) ||
2883 failed(legalizePatternCreatedOperations(newOps))) {
2887 LLVM_DEBUG(
logSuccess(impl.logger,
"pattern applied successfully"));
2891LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2893 for (Operation *op : newOps) {
2894 if (
failed(legalize(op))) {
2895 LLVM_DEBUG(
logFailure(rewriter.getImpl().logger,
2896 "failed to legalize generated operation '{0}'({1})",
2904LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2906 for (Operation *op : modifiedOps) {
2907 if (
failed(legalize(op))) {
2910 "failed to legalize operation updated in-place '{0}'",
2922void OperationLegalizer::buildLegalizationGraph(
2923 LegalizationPatterns &anyOpLegalizerPatterns,
2934 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2935 std::optional<OperationName> root = pattern.
getRootKind();
2941 anyOpLegalizerPatterns.push_back(&pattern);
2946 if (
target.getOpAction(*root) == LegalizationAction::Legal)
2951 invalidPatterns[*root].insert(&pattern);
2953 parentOps[op].insert(*root);
2956 patternWorklist.insert(&pattern);
2964 if (!anyOpLegalizerPatterns.empty()) {
2965 for (
const Pattern *pattern : patternWorklist)
2966 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2970 while (!patternWorklist.empty()) {
2971 auto *pattern = patternWorklist.pop_back_val();
2975 std::optional<LegalizationAction> action = target.getOpAction(op);
2976 return !legalizerPatterns.count(op) &&
2977 (!action || action == LegalizationAction::Illegal);
2983 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2984 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
2988 for (
auto op : parentOps[*pattern->
getRootKind()])
2989 patternWorklist.set_union(invalidPatterns[op]);
2993void OperationLegalizer::computeLegalizationGraphBenefit(
2994 LegalizationPatterns &anyOpLegalizerPatterns,
3000 for (
auto &opIt : legalizerPatterns)
3001 if (!minOpPatternDepth.count(opIt.first))
3002 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
3008 if (!anyOpLegalizerPatterns.empty())
3009 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
3015 applicator.applyCostModel([&](
const Pattern &pattern) {
3016 ArrayRef<const Pattern *> orderedPatternList;
3017 if (std::optional<OperationName> rootName = pattern.
getRootKind())
3018 orderedPatternList = legalizerPatterns[*rootName];
3020 orderedPatternList = anyOpLegalizerPatterns;
3023 auto *it = llvm::find(orderedPatternList, &pattern);
3024 if (it == orderedPatternList.end())
3028 return PatternBenefit(std::distance(it, orderedPatternList.end()));
3032unsigned OperationLegalizer::computeOpLegalizationDepth(
3036 auto depthIt = minOpPatternDepth.find(op);
3037 if (depthIt != minOpPatternDepth.end())
3038 return depthIt->second;
3042 auto opPatternsIt = legalizerPatterns.find(op);
3043 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3048 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
3052 unsigned minDepth = applyCostModelToPatterns(
3053 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3054 minOpPatternDepth[op] = minDepth;
3058unsigned OperationLegalizer::applyCostModelToPatterns(
3062 unsigned minDepth = std::numeric_limits<unsigned>::max();
3065 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3066 patternsByDepth.reserve(
patterns.size());
3067 for (
const Pattern *pattern :
patterns) {
3070 unsigned generatedOpDepth = computeOpLegalizationDepth(
3071 generatedOp, minOpPatternDepth, legalizerPatterns);
3072 depth = std::max(depth, generatedOpDepth + 1);
3074 patternsByDepth.emplace_back(pattern, depth);
3077 minDepth = std::min(minDepth, depth);
3082 if (patternsByDepth.size() == 1)
3086 llvm::stable_sort(patternsByDepth,
3087 [](
const std::pair<const Pattern *, unsigned> &
lhs,
3088 const std::pair<const Pattern *, unsigned> &
rhs) {
3091 if (
lhs.second !=
rhs.second)
3092 return lhs.second <
rhs.second;
3095 auto lhsBenefit =
lhs.first->getBenefit();
3096 auto rhsBenefit =
rhs.first->getBenefit();
3097 return lhsBenefit > rhsBenefit;
3102 for (
auto &patternIt : patternsByDepth)
3103 patterns.push_back(patternIt.first);
3117template <
typename RangeT>
3120 function_ref<
bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3129 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3130 if (castOp.getInputs().empty())
3133 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3136 if (inputCastOp.getOutputs() != castOp.getInputs())
3142 while (!worklist.empty()) {
3143 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3147 UnrealizedConversionCastOp nextCast = castOp;
3149 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3150 if (llvm::any_of(nextCast.getInputs(), [&](
Value v) {
3151 return v.getDefiningOp() == castOp;
3159 castOp.replaceAllUsesWith(nextCast.getInputs());
3162 nextCast = getInputCast(nextCast);
3172 auto markOpLive = [&](
Operation *rootOp) {
3174 worklist.push_back(rootOp);
3175 while (!worklist.empty()) {
3176 Operation *op = worklist.pop_back_val();
3177 if (liveOps.insert(op).second) {
3180 if (
auto castOp = v.
getDefiningOp<UnrealizedConversionCastOp>())
3181 if (isCastOpOfInterestFn(castOp))
3182 worklist.push_back(castOp);
3188 for (UnrealizedConversionCastOp op : castOps) {
3191 if (liveOps.contains(op.getOperation()))
3195 if (llvm::any_of(op->getUsers(), [&](
Operation *user) {
3196 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3197 return !castOp || !isCastOpOfInterestFn(castOp);
3203 for (UnrealizedConversionCastOp op : castOps) {
3204 if (liveOps.contains(op)) {
3206 if (remainingCastOps)
3207 remainingCastOps->push_back(op);
3218 ArrayRef<UnrealizedConversionCastOp> castOps,
3219 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3221 DenseSet<UnrealizedConversionCastOp> castOpSet;
3222 for (UnrealizedConversionCastOp op : castOps)
3223 castOpSet.insert(op);
3228 const DenseSet<UnrealizedConversionCastOp> &castOps,
3229 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3231 llvm::make_range(castOps.begin(), castOps.end()),
3232 [&](UnrealizedConversionCastOp castOp) {
3233 return castOps.contains(castOp);
3245 [&](UnrealizedConversionCastOp castOp) {
3246 return castOps.contains(castOp);
3263 const ConversionConfig &
config,
3264 OpConversionMode mode)
3274 template <
typename Fn>
3276 bool isRecursiveLegalization =
false);
3278 bool isRecursiveLegalization =
false) {
3280 ops, [&]() {}, isRecursiveLegalization);
3288 LogicalResult convert(
Operation *op,
bool isRecursiveLegalization =
false);
3294 ConversionPatternRewriter rewriter;
3297 OperationLegalizer opLegalizer;
3300 OpConversionMode mode;
3305 bool isRecursiveLegalization) {
3306 const ConversionConfig &
config = rewriter.getConfig();
3309 if (failed(opLegalizer.legalize(op))) {
3312 if (mode == OpConversionMode::Full) {
3313 if (!isRecursiveLegalization)
3321 if (mode == OpConversionMode::Partial) {
3322 if (opLegalizer.isIllegal(op)) {
3323 if (!isRecursiveLegalization)
3325 <<
"' that was explicitly marked illegal";
3328 if (
config.unlegalizedOps && !isRecursiveLegalization)
3329 config.unlegalizedOps->insert(op);
3331 }
else if (mode == OpConversionMode::Analysis) {
3335 if (
config.legalizableOps && !isRecursiveLegalization)
3336 config.legalizableOps->insert(op);
3343 UnrealizedConversionCastOp op,
3344 const UnresolvedMaterializationInfo &info) {
3345 assert(!op.use_empty() &&
3346 "expected that dead materializations have already been DCE'd");
3353 switch (info.getMaterializationKind()) {
3354 case MaterializationKind::Target:
3355 newMaterialization = converter->materializeTargetConversion(
3356 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3357 info.getOriginalType());
3359 case MaterializationKind::Source:
3360 assert(op->getNumResults() == 1 &&
"expected single result");
3361 Value sourceMat = converter->materializeSourceConversion(
3362 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3364 newMaterialization.push_back(sourceMat);
3367 if (!newMaterialization.empty()) {
3369 ValueRange newMaterializationRange(newMaterialization);
3370 assert(
TypeRange(newMaterializationRange) == op.getResultTypes() &&
3371 "materialization callback produced value of incorrect type");
3373 rewriter.
replaceOp(op, newMaterialization);
3379 <<
"failed to legalize unresolved materialization "
3381 << inputOperands.
getTypes() <<
") to ("
3382 << op.getResultTypes()
3383 <<
") that remained live after conversion";
3384 diag.attachNote(op->getUsers().begin()->getLoc())
3385 <<
"see existing live user here: " << *op->getUsers().begin();
3389template <
typename Fn>
3392 bool isRecursiveLegalization) {
3400 toConvert.push_back(op);
3403 auto legalityInfo =
target.isLegal(op);
3404 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3410 if (failed(
convert(op, isRecursiveLegalization))) {
3419LogicalResult ConversionPatternRewriter::legalize(
Operation *op) {
3420 return impl->opConverter.legalizeOperations(op,
3424LogicalResult ConversionPatternRewriter::legalize(
Region *r) {
3440 std::optional<TypeConverter::SignatureConversion> conversion =
3441 converter->convertBlockSignature(&r->front());
3444 applySignatureConversion(&r->front(), *conversion, converter);
3449 return impl->opConverter.legalizeOperations(ops,
3458 if (rewriterImpl.
config.allowPatternRollback) {
3482 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3486 if (rewriter.getConfig().buildMaterializations) {
3490 rewriter.getConfig().listener);
3491 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3492 auto it = materializations.find(castOp);
3493 assert(it != materializations.end() &&
"inconsistent state");
3507void TypeConverter::SignatureConversion::addInputs(
unsigned origInputNo,
3509 assert(!types.empty() &&
"expected valid types");
3510 remapInput(origInputNo, argTypes.size(), types.size());
3514void TypeConverter::SignatureConversion::addInputs(
ArrayRef<Type> types) {
3515 assert(!types.empty() &&
3516 "1->0 type remappings don't need to be added explicitly");
3517 argTypes.append(types.begin(), types.end());
3520void TypeConverter::SignatureConversion::remapInput(
unsigned origInputNo,
3521 unsigned newInputNo,
3522 unsigned newInputCount) {
3523 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3524 assert(newInputCount != 0 &&
"expected valid input count");
3525 remappedInputs[origInputNo] =
3526 InputMapping{newInputNo, newInputCount, {}};
3529void TypeConverter::SignatureConversion::remapInput(
3530 unsigned origInputNo, ArrayRef<Value> replacements) {
3531 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3532 remappedInputs[origInputNo] = InputMapping{
3534 SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3545TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3546 SmallVectorImpl<Type> &results)
const {
3547 assert(typeOrValue &&
"expected non-null type");
3548 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3549 : cast<Type>(typeOrValue);
3551 std::shared_lock<
decltype(cacheMutex)> cacheReadLock(cacheMutex,
3554 cacheReadLock.lock();
3555 auto existingIt = cachedDirectConversions.find(t);
3556 if (existingIt != cachedDirectConversions.end()) {
3557 if (existingIt->second)
3558 results.push_back(existingIt->second);
3559 return success(existingIt->second !=
nullptr);
3561 auto multiIt = cachedMultiConversions.find(t);
3562 if (multiIt != cachedMultiConversions.end()) {
3563 results.append(multiIt->second.begin(), multiIt->second.end());
3569 size_t currentCount = results.size();
3573 auto isCacheable = [&](
int index) {
3574 int numberOfConversionsUntilContextAware =
3575 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3576 return index < numberOfConversionsUntilContextAware;
3579 std::unique_lock<
decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3582 for (
auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3583 const ConversionCallbackFn &converter = indexedConverter.value();
3584 std::optional<LogicalResult>
result = converter(typeOrValue, results);
3586 assert(results.size() == currentCount &&
3587 "failed type conversion should not change results");
3590 if (!isCacheable(indexedConverter.index()))
3593 cacheWriteLock.lock();
3594 if (!succeeded(*
result)) {
3595 assert(results.size() == currentCount &&
3596 "failed type conversion should not change results");
3597 cachedDirectConversions.try_emplace(t,
nullptr);
3600 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3601 if (newTypes.size() == 1)
3602 cachedDirectConversions.try_emplace(t, newTypes.front());
3604 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3610LogicalResult TypeConverter::convertType(Type t,
3611 SmallVectorImpl<Type> &results)
const {
3612 return convertTypeImpl(t, results);
3615LogicalResult TypeConverter::convertType(Value v,
3616 SmallVectorImpl<Type> &results)
const {
3617 return convertTypeImpl(v, results);
3620Type TypeConverter::convertType(Type t)
const {
3622 SmallVector<Type, 1> results;
3623 if (
failed(convertType(t, results)))
3627 return results.size() == 1 ? results.front() :
nullptr;
3630Type TypeConverter::convertType(Value v)
const {
3632 SmallVector<Type, 1> results;
3633 if (
failed(convertType(v, results)))
3637 return results.size() == 1 ? results.front() :
nullptr;
3641TypeConverter::convertTypes(
TypeRange types,
3642 SmallVectorImpl<Type> &results)
const {
3643 for (Type type : types)
3644 if (
failed(convertType(type, results)))
3650TypeConverter::convertTypes(
ValueRange values,
3651 SmallVectorImpl<Type> &results)
const {
3652 for (Value value : values)
3653 if (
failed(convertType(value, results)))
3658bool TypeConverter::isLegal(Type type)
const {
3659 return convertType(type) == type;
3662bool TypeConverter::isLegal(Value value)
const {
3663 return convertType(value) == value.
getType();
3666bool TypeConverter::isLegal(Operation *op)
const {
3670bool TypeConverter::isLegal(Region *region)
const {
3671 return llvm::all_of(
3675bool TypeConverter::isSignatureLegal(FunctionType ty)
const {
3676 if (!isLegal(ty.getInputs()))
3678 if (!isLegal(ty.getResults()))
3684TypeConverter::convertSignatureArg(
unsigned inputNo, Type type,
3685 SignatureConversion &
result)
const {
3687 SmallVector<Type, 1> convertedTypes;
3688 if (
failed(convertType(type, convertedTypes)))
3692 if (convertedTypes.empty())
3696 result.addInputs(inputNo, convertedTypes);
3700TypeConverter::convertSignatureArgs(
TypeRange types,
3701 SignatureConversion &
result,
3702 unsigned origInputOffset)
const {
3703 for (
unsigned i = 0, e = types.size(); i != e; ++i)
3704 if (
failed(convertSignatureArg(origInputOffset + i, types[i],
result)))
3709TypeConverter::convertSignatureArg(
unsigned inputNo, Value value,
3710 SignatureConversion &
result)
const {
3712 SmallVector<Type, 1> convertedTypes;
3713 if (
failed(convertType(value, convertedTypes)))
3717 if (convertedTypes.empty())
3721 result.addInputs(inputNo, convertedTypes);
3725TypeConverter::convertSignatureArgs(
ValueRange values,
3726 SignatureConversion &
result,
3727 unsigned origInputOffset)
const {
3728 for (
unsigned i = 0, e = values.size(); i != e; ++i)
3729 if (
failed(convertSignatureArg(origInputOffset + i, values[i],
result)))
3734Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
3735 Location loc, Type resultType,
3737 for (
const SourceMaterializationCallbackFn &fn :
3738 llvm::reverse(sourceMaterializations))
3739 if (Value
result = fn(builder, resultType, inputs, loc))
3744Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
3745 Location loc, Type resultType,
3747 Type originalType)
const {
3748 SmallVector<Value>
result = materializeTargetConversion(
3749 builder, loc,
TypeRange(resultType), inputs, originalType);
3752 assert(
result.size() == 1 &&
"expected single result");
3756SmallVector<Value> TypeConverter::materializeTargetConversion(
3758 Type originalType)
const {
3759 for (
const TargetMaterializationCallbackFn &fn :
3760 llvm::reverse(targetMaterializations)) {
3761 SmallVector<Value>
result =
3762 fn(builder, resultTypes, inputs, loc, originalType);
3766 "callback produced incorrect number of values or values with "
3773std::optional<TypeConverter::SignatureConversion>
3774TypeConverter::convertBlockSignature(
Block *block)
const {
3777 return std::nullopt;
3784TypeConverter::AttributeConversionResult
3785TypeConverter::AttributeConversionResult::result(Attribute attr) {
3786 return AttributeConversionResult(attr, resultTag);
3789TypeConverter::AttributeConversionResult
3790TypeConverter::AttributeConversionResult::na() {
3791 return AttributeConversionResult(
nullptr, naTag);
3794TypeConverter::AttributeConversionResult
3795TypeConverter::AttributeConversionResult::abort() {
3796 return AttributeConversionResult(
nullptr, abortTag);
3799bool TypeConverter::AttributeConversionResult::hasResult()
const {
3800 return impl.getInt() == resultTag;
3803bool TypeConverter::AttributeConversionResult::isNa()
const {
3804 return impl.getInt() == naTag;
3807bool TypeConverter::AttributeConversionResult::isAbort()
const {
3808 return impl.getInt() == abortTag;
3811Attribute TypeConverter::AttributeConversionResult::getResult()
const {
3812 assert(hasResult() &&
"Cannot get result from N/A or abort");
3813 return impl.getPointer();
3816std::optional<Attribute>
3817TypeConverter::convertTypeAttribute(Type type, Attribute attr)
const {
3818 for (
const TypeAttributeConversionCallbackFn &fn :
3819 llvm::reverse(typeAttributeConversions)) {
3820 AttributeConversionResult res = fn(type, attr);
3821 if (res.hasResult())
3822 return res.getResult();
3824 return std::nullopt;
3826 return std::nullopt;
3835 ConversionPatternRewriter &rewriter) {
3836 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3841 TypeConverter::SignatureConversion
result(type.getNumInputs());
3843 if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
result)) ||
3844 failed(typeConverter.convertTypes(type.getResults(), newResults)))
3846 if (!funcOp.getFunctionBody().empty())
3847 rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(),
result,
3851 auto newType = FunctionType::get(rewriter.getContext(),
3852 result.getConvertedTypes(), newResults);
3854 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3863struct FunctionOpInterfaceSignatureConversion :
public ConversionPattern {
3864 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3866 const TypeConverter &converter,
3867 PatternBenefit benefit)
3868 : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3871 matchAndRewrite(Operation *op, ArrayRef<Value> ,
3872 ConversionPatternRewriter &rewriter)
const override {
3873 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3878struct AnyFunctionOpInterfaceSignatureConversion
3879 :
public OpInterfaceConversionPattern<FunctionOpInterface> {
3880 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3883 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> ,
3884 ConversionPatternRewriter &rewriter)
const override {
3890FailureOr<Operation *>
3891mlir::convertOpResultTypes(Operation *op,
ValueRange operands,
3892 const TypeConverter &converter,
3893 ConversionPatternRewriter &rewriter) {
3894 assert(op &&
"Invalid op");
3895 Location loc = op->
getLoc();
3896 if (converter.isLegal(op))
3897 return rewriter.notifyMatchFailure(loc,
"op already legal");
3899 OperationState newOp(loc, op->
getName());
3900 newOp.addOperands(operands);
3902 SmallVector<Type> newResultTypes;
3904 return rewriter.notifyMatchFailure(loc,
"couldn't convert return types");
3906 newOp.addTypes(newResultTypes);
3907 newOp.addAttributes(op->
getAttrs());
3908 return rewriter.create(newOp);
3911void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3912 StringRef functionLikeOpName, RewritePatternSet &
patterns,
3913 const TypeConverter &converter, PatternBenefit benefit) {
3914 patterns.add<FunctionOpInterfaceSignatureConversion>(
3915 functionLikeOpName,
patterns.getContext(), converter, benefit);
3918void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3919 RewritePatternSet &
patterns,
const TypeConverter &converter,
3920 PatternBenefit benefit) {
3921 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3922 converter,
patterns.getContext(), benefit);
3929void ConversionTarget::setOpAction(OperationName op,
3930 LegalizationAction action) {
3931 legalOperations[op].action = action;
3934void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3935 LegalizationAction action) {
3936 for (StringRef dialect : dialectNames)
3937 legalDialects[dialect] = action;
3940auto ConversionTarget::getOpAction(OperationName op)
const
3941 -> std::optional<LegalizationAction> {
3942 std::optional<LegalizationInfo> info = getOpInfo(op);
3943 return info ? info->action : std::optional<LegalizationAction>();
3946auto ConversionTarget::isLegal(Operation *op)
const
3947 -> std::optional<LegalOpDetails> {
3948 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3950 return std::nullopt;
3953 auto isOpLegal = [&] {
3955 if (info->action == LegalizationAction::Dynamic) {
3956 std::optional<bool>
result = info->legalityFn(op);
3962 return info->action == LegalizationAction::Legal;
3965 return std::nullopt;
3968 LegalOpDetails legalityDetails;
3969 if (info->isRecursivelyLegal) {
3970 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
3971 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3972 legalityDetails.isRecursivelyLegal =
3973 legalityFnIt->second(op).value_or(
true);
3975 legalityDetails.isRecursivelyLegal =
true;
3978 return legalityDetails;
3981bool ConversionTarget::isIllegal(Operation *op)
const {
3982 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3986 if (info->action == LegalizationAction::Dynamic) {
3987 std::optional<bool>
result = info->legalityFn(op);
3994 return info->action == LegalizationAction::Illegal;
3998 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
3999 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
4003 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
4005 if (std::optional<bool>
result = newCl(op))
4013void ConversionTarget::setLegalityCallback(
4014 OperationName name,
const DynamicLegalityCallbackFn &callback) {
4015 assert(callback &&
"expected valid legality callback");
4016 auto *infoIt = legalOperations.find(name);
4017 assert(infoIt != legalOperations.end() &&
4018 infoIt->second.action == LegalizationAction::Dynamic &&
4019 "expected operation to already be marked as dynamically legal");
4020 infoIt->second.legalityFn =
4024void ConversionTarget::markOpRecursivelyLegal(
4025 OperationName name,
const DynamicLegalityCallbackFn &callback) {
4026 auto *infoIt = legalOperations.find(name);
4027 assert(infoIt != legalOperations.end() &&
4028 infoIt->second.action != LegalizationAction::Illegal &&
4029 "expected operation to already be marked as legal");
4030 infoIt->second.isRecursivelyLegal =
true;
4033 std::move(opRecursiveLegalityFns[name]), callback);
4035 opRecursiveLegalityFns.erase(name);
4038void ConversionTarget::setLegalityCallback(
4039 ArrayRef<StringRef> dialects,
const DynamicLegalityCallbackFn &callback) {
4040 assert(callback &&
"expected valid legality callback");
4041 for (StringRef dialect : dialects)
4043 std::move(dialectLegalityFns[dialect]), callback);
4046void ConversionTarget::setLegalityCallback(
4047 const DynamicLegalityCallbackFn &callback) {
4048 assert(callback &&
"expected valid legality callback");
4052auto ConversionTarget::getOpInfo(OperationName op)
const
4053 -> std::optional<LegalizationInfo> {
4055 const auto *it = legalOperations.find(op);
4056 if (it != legalOperations.end())
4059 auto dialectIt = legalDialects.find(op.getDialectNamespace());
4060 if (dialectIt != legalDialects.end()) {
4061 DynamicLegalityCallbackFn callback;
4062 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4063 if (dialectFn != dialectLegalityFns.end())
4064 callback = dialectFn->second;
4065 return LegalizationInfo{dialectIt->second,
false,
4069 if (unknownLegalityFn)
4070 return LegalizationInfo{LegalizationAction::Dynamic,
4071 false, unknownLegalityFn};
4072 return std::nullopt;
4075#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4080void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
4081 auto &rewriterImpl =
4082 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4086void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
4087 auto &rewriterImpl =
4088 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4094static FailureOr<SmallVector<Value>>
4095pdllConvertValues(ConversionPatternRewriter &rewriter,
ValueRange values) {
4096 SmallVector<Value> mappedValues;
4097 if (
failed(rewriter.getRemappedValues(values, mappedValues)))
4099 return std::move(mappedValues);
4102void mlir::registerConversionPDLFunctions(RewritePatternSet &
patterns) {
4103 patterns.getPDLPatterns().registerRewriteFunction(
4105 [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
4106 auto results = pdllConvertValues(
4107 static_cast<ConversionPatternRewriter &
>(rewriter), value);
4110 return results->front();
4112 patterns.getPDLPatterns().registerRewriteFunction(
4113 "convertValues", [](PatternRewriter &rewriter,
ValueRange values) {
4114 return pdllConvertValues(
4115 static_cast<ConversionPatternRewriter &
>(rewriter), values);
4117 patterns.getPDLPatterns().registerRewriteFunction(
4119 [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
4120 auto &rewriterImpl =
4121 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4122 if (
const TypeConverter *converter =
4124 if (Type newType = converter->convertType(type))
4130 patterns.getPDLPatterns().registerRewriteFunction(
4132 [](PatternRewriter &rewriter,
4133 TypeRange types) -> FailureOr<SmallVector<Type>> {
4134 auto &rewriterImpl =
4135 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4138 return SmallVector<Type>(types);
4140 SmallVector<Type> remappedTypes;
4141 if (
failed(converter->convertTypes(types, remappedTypes)))
4143 return std::move(remappedTypes);
4158 static constexpr StringLiteral
tag =
"apply-conversion";
4159 static constexpr StringLiteral
desc =
4160 "Encapsulate the application of a dialect conversion";
4169 OpConversionMode mode) {
4173 LogicalResult status =
success();
4189LogicalResult mlir::applyPartialConversion(
4190 ArrayRef<Operation *> ops,
const ConversionTarget &
target,
4191 const FrozenRewritePatternSet &
patterns, ConversionConfig
config) {
4193 OpConversionMode::Partial);
4196mlir::applyPartialConversion(Operation *op,
const ConversionTarget &
target,
4197 const FrozenRewritePatternSet &
patterns,
4198 ConversionConfig
config) {
4206LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
4207 const ConversionTarget &
target,
4208 const FrozenRewritePatternSet &
patterns,
4209 ConversionConfig
config) {
4212LogicalResult mlir::applyFullConversion(Operation *op,
4213 const ConversionTarget &
target,
4214 const FrozenRewritePatternSet &
patterns,
4215 ConversionConfig
config) {
4233 "expected top-level op to be isolated from above");
4236 "expected ops to have a common ancestor");
4245 for (
Operation *op : ops.drop_front()) {
4249 assert(commonAncestor &&
4250 "expected to find a common isolated from above ancestor");
4254 return commonAncestor;
4257LogicalResult mlir::applyAnalysisConversion(
4258 ArrayRef<Operation *> ops, ConversionTarget &
target,
4259 const FrozenRewritePatternSet &
patterns, ConversionConfig
config) {
4261 if (
config.legalizableOps)
4262 assert(
config.legalizableOps->empty() &&
"expected empty set");
4268 Operation *clonedAncestor = commonAncestor->
clone(mapping);
4272 inverseOperationMap[it.second] = it.first;
4275 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4276 ops, [&](Operation *op) {
return mapping.
lookup(op); });
4278 OpConversionMode::Analysis);
4282 if (
config.legalizableOps) {
4284 for (Operation *op : *
config.legalizableOps)
4285 originalLegalizableOps.insert(inverseOperationMap[op]);
4286 *
config.legalizableOps = std::move(originalLegalizableOps);
4290 clonedAncestor->
erase();
4295mlir::applyAnalysisConversion(Operation *op, ConversionTarget &
target,
4296 const FrozenRewritePatternSet &
patterns,
4297 ConversionConfig
config) {
static void setInsertionPointAfter(OpBuilder &b, Value value)
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl, function_ref< bool(OpOperand &)> functor=nullptr)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value > > &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
static std::string diag(const llvm::Value &value)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
This is the type of Action that is dispatched when a conversion is applied.
tracing::ActionImpl< ApplyConversionAction > Base
static constexpr StringLiteral desc
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
static constexpr StringLiteral tag
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
StringAttr getStringAttr(const Twine &bytes)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
auto lookup(T from) const
Lookup a mapped value within the map.
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Block::iterator getPoint() const
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
void destroyOpProperties(OpaqueProperties properties) const
This hooks destroy the op properties.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
CRTP Implementation of an action.
ArrayRef< IRUnit > irUnits
Set of IR units (operations, regions, blocks, values) that are associated with this action.
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
static void reconcileUnrealizedCasts(const DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
llvm::SetVector< T, Vector, Set, N > SetVector
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
This iterator enumerates elements according to their dominance relationship.
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
const ConversionTarget & getTarget()
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, bool isRecursiveLegalization=false)
LogicalResult convert(Operation *op, bool isRecursiveLegalization=false)
Converts a single operation.
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, Fn onFailure, bool isRecursiveLegalization=false)
Legalizes the given operations (and their nested operations) to the conversion target.
LogicalResult applyConversion(ArrayRef< Operation * > ops)
Applies the conversion to the given operations (and their nested operations).
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
bool wasErased(void *ptr) const
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
llvm::impl::raw_ldbg_ostream os
A raw output stream used to prefix the debug log.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
void replaceValueUses(Value from, ValueRange to, const TypeConverter *converter, function_ref< bool(OpOperand &)> functor=nullptr)
Replace the uses of the given value with the given values.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config, OperationConverter &opConverter)
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
ConversionValueMapping mapping
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
void replaceOp(Operation *op, SmallVector< SmallVector< Value > > &&newValues)
Replace the results of the given operation with the given values and erase the operation.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
OperationConverter & opConverter
The operation converter to use for recursive legalization.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.