10#include "mlir/Config/mlir-config.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/ErrorHandling.h"
25#include "llvm/Support/FormatVariadic.h"
26#include "llvm/Support/SaveAndRestore.h"
27#include "llvm/Support/ScopedPrinter.h"
34#define DEBUG_TYPE "dialect-conversion"
37template <
typename... Args>
38static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
41 os.startLine() <<
"} -> SUCCESS";
43 os.getOStream() <<
" : "
44 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
45 os.getOStream() <<
"\n";
50template <
typename... Args>
51static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
54 os.startLine() <<
"} -> FAILURE : "
55 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
65 if (
OpResult inputRes = dyn_cast<OpResult>(value))
66 insertPt = ++inputRes.getOwner()->getIterator();
73 assert(!vals.empty() &&
"expected at least one value");
76 for (
Value v : vals.drop_front()) {
90 assert(dom &&
"unable to find valid insertion point");
98enum OpConversionMode {
124struct ValueVectorMapInfo {
127 static ::llvm::hash_code getHashValue(
const ValueVector &val) {
128 return ::llvm::hash_combine_range(val);
137struct ConversionValueMapping {
140 bool isMappedTo(
Value value)
const {
return mappedTo.contains(value); }
145 template <
typename T>
146 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
149 template <
typename OldVal,
typename NewVal>
150 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
151 map(OldVal &&oldVal, NewVal &&newVal) {
155 assert(next != oldVal &&
"inserting cyclic mapping");
156 auto it = mapping.find(next);
157 if (it == mapping.end())
162 mappedTo.insert_range(newVal);
164 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
168 template <
typename OldVal,
typename NewVal>
169 std::enable_if_t<!IsValueVector<OldVal>::value ||
170 !IsValueVector<NewVal>::value>
171 map(OldVal &&oldVal, NewVal &&newVal) {
172 if constexpr (IsValueVector<OldVal>{}) {
173 map(std::forward<OldVal>(oldVal),
ValueVector{newVal});
174 }
else if constexpr (IsValueVector<NewVal>{}) {
175 map(
ValueVector{oldVal}, std::forward<NewVal>(newVal));
186 void erase(
const ValueVector &value) { mapping.erase(value); }
206 assert(!values.empty() &&
"expected non-empty value vector");
207 Operation *op = values.front().getDefiningOp();
208 for (
Value v : llvm::drop_begin(values)) {
209 if (v.getDefiningOp() != op)
219 assert(!values.empty() &&
"expected non-empty value vector");
225 auto it = mapping.find(from);
226 if (it == mapping.end()) {
239struct RewriterState {
240 RewriterState(
unsigned numRewrites,
unsigned numIgnoredOperations,
241 unsigned numReplacedOps)
242 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
243 numReplacedOps(numReplacedOps) {}
246 unsigned numRewrites;
249 unsigned numIgnoredOperations;
252 unsigned numReplacedOps;
259static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
262static void notifyIRErased(RewriterBase::Listener *listener,
Block &
b) {
263 for (Operation &op :
b)
264 notifyIRErased(listener, op);
270static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
273 notifyIRErased(listener,
b);
303 UnresolvedMaterialization,
308 virtual ~IRRewrite() =
default;
311 virtual void rollback() = 0;
325 virtual void commit(RewriterBase &rewriter) {}
328 virtual void cleanup(RewriterBase &rewriter) {}
330 Kind getKind()
const {
return kind; }
332 static bool classof(
const IRRewrite *
rewrite) {
return true; }
335 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
336 : kind(kind), rewriterImpl(rewriterImpl) {}
338 const ConversionConfig &getConfig()
const;
341 ConversionPatternRewriterImpl &rewriterImpl;
345class BlockRewrite :
public IRRewrite {
348 Block *getBlock()
const {
return block; }
350 static bool classof(
const IRRewrite *
rewrite) {
351 return rewrite->getKind() >= Kind::CreateBlock &&
352 rewrite->getKind() <= Kind::BlockTypeConversion;
356 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
358 : IRRewrite(kind, rewriterImpl), block(block) {}
365class ValueRewrite :
public IRRewrite {
368 Value getValue()
const {
return value; }
370 static bool classof(
const IRRewrite *
rewrite) {
371 return rewrite->getKind() == Kind::ReplaceValue;
375 ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
377 : IRRewrite(kind, rewriterImpl), value(value) {}
386class CreateBlockRewrite :
public BlockRewrite {
388 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block)
389 : BlockRewrite(
Kind::CreateBlock, rewriterImpl, block) {}
391 static bool classof(
const IRRewrite *
rewrite) {
392 return rewrite->getKind() == Kind::CreateBlock;
395 void commit(RewriterBase &rewriter)
override {
401 void rollback()
override {
404 auto &blockOps = block->getOperations();
405 while (!blockOps.empty())
406 blockOps.remove(blockOps.begin());
407 block->dropAllUses();
408 if (block->getParent())
419class EraseBlockRewrite :
public BlockRewrite {
421 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block)
422 : BlockRewrite(
Kind::EraseBlock, rewriterImpl, block),
423 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
425 static bool classof(
const IRRewrite *
rewrite) {
426 return rewrite->getKind() == Kind::EraseBlock;
429 ~EraseBlockRewrite()
override {
431 "rewrite was neither rolled back nor committed/cleaned up");
434 void rollback()
override {
437 assert(block &&
"expected block");
442 blockList.insert(before, block);
446 void commit(RewriterBase &rewriter)
override {
447 assert(block &&
"expected block");
451 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
452 notifyIRErased(listener, *block);
455 void cleanup(RewriterBase &rewriter)
override {
457 for (
auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
459 assert(block->empty() &&
"expected empty block");
462 block->dropAllDefinedValueUses();
473 Block *insertBeforeBlock;
479class InlineBlockRewrite :
public BlockRewrite {
481 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block,
483 : BlockRewrite(
Kind::InlineBlock, rewriterImpl, block),
484 sourceBlock(sourceBlock),
485 firstInlinedInst(sourceBlock->empty() ?
nullptr
486 : &sourceBlock->front()),
487 lastInlinedInst(sourceBlock->empty() ?
nullptr : &sourceBlock->back()) {
493 assert(!getConfig().listener &&
494 "InlineBlockRewrite not supported if listener is attached");
497 static bool classof(
const IRRewrite *
rewrite) {
498 return rewrite->getKind() == Kind::InlineBlock;
501 void rollback()
override {
504 if (firstInlinedInst) {
505 assert(lastInlinedInst &&
"expected operation");
518 Operation *firstInlinedInst;
521 Operation *lastInlinedInst;
525class MoveBlockRewrite :
public BlockRewrite {
527 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block,
529 : BlockRewrite(
Kind::MoveBlock, rewriterImpl, block),
530 region(previousRegion),
531 insertBeforeBlock(previousIt == previousRegion->end() ?
nullptr
534 static bool classof(
const IRRewrite *
rewrite) {
535 return rewrite->getKind() == Kind::MoveBlock;
538 void commit(RewriterBase &rewriter)
override {
548 void rollback()
override {
552 if (Region *currentParent = block->
getParent()) {
554 region->getBlocks().splice(before, currentParent->getBlocks(), block);
558 region->
getBlocks().insert(before, block);
567 Block *insertBeforeBlock;
571class BlockTypeConversionRewrite :
public BlockRewrite {
573 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
575 : BlockRewrite(
Kind::BlockTypeConversion, rewriterImpl, origBlock),
576 newBlock(newBlock) {}
578 static bool classof(
const IRRewrite *
rewrite) {
579 return rewrite->getKind() == Kind::BlockTypeConversion;
582 Block *getOrigBlock()
const {
return block; }
584 Block *getNewBlock()
const {
return newBlock; }
586 void commit(RewriterBase &rewriter)
override;
588 void rollback()
override;
598class ReplaceValueRewrite :
public ValueRewrite {
600 ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
601 const TypeConverter *converter)
602 : ValueRewrite(
Kind::ReplaceValue, rewriterImpl, value),
603 converter(converter) {}
605 static bool classof(
const IRRewrite *
rewrite) {
606 return rewrite->getKind() == Kind::ReplaceValue;
609 void commit(RewriterBase &rewriter)
override;
611 void rollback()
override;
615 const TypeConverter *converter;
619class OperationRewrite :
public IRRewrite {
622 Operation *getOperation()
const {
return op; }
624 static bool classof(
const IRRewrite *
rewrite) {
625 return rewrite->getKind() >= Kind::MoveOperation &&
626 rewrite->getKind() <= Kind::UnresolvedMaterialization;
630 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
632 : IRRewrite(kind, rewriterImpl), op(op) {}
639class MoveOperationRewrite :
public OperationRewrite {
641 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
642 Operation *op, OpBuilder::InsertPoint previous)
643 : OperationRewrite(
Kind::MoveOperation, rewriterImpl, op),
644 block(previous.getBlock()),
645 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
647 : &*previous.getPoint()) {}
649 static bool classof(
const IRRewrite *
rewrite) {
650 return rewrite->getKind() == Kind::MoveOperation;
653 void commit(RewriterBase &rewriter)
override {
659 op, OpBuilder::InsertPoint(block,
664 void rollback()
override {
677 Operation *insertBeforeOp;
682class ModifyOperationRewrite :
public OperationRewrite {
684 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
686 : OperationRewrite(
Kind::ModifyOperation, rewriterImpl, op),
687 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
688 operands(op->operand_begin(), op->operand_end()),
689 successors(op->successor_begin(), op->successor_end()) {
692 propertiesStorage = operator new(op->getPropertiesStorageSize());
693 OpaqueProperties propCopy(propertiesStorage);
694 name.initOpProperties(propCopy, prop);
698 static bool classof(
const IRRewrite *
rewrite) {
699 return rewrite->getKind() == Kind::ModifyOperation;
702 ~ModifyOperationRewrite()
override {
703 assert(!propertiesStorage &&
704 "rewrite was neither committed nor rolled back");
707 void commit(RewriterBase &rewriter)
override {
710 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
713 if (propertiesStorage) {
714 OpaqueProperties propCopy(propertiesStorage);
718 operator delete(propertiesStorage);
719 propertiesStorage =
nullptr;
723 void rollback()
override {
727 for (
const auto &it : llvm::enumerate(successors))
729 if (propertiesStorage) {
730 OpaqueProperties propCopy(propertiesStorage);
733 operator delete(propertiesStorage);
734 propertiesStorage =
nullptr;
741 DictionaryAttr attrs;
742 SmallVector<Value, 8> operands;
743 SmallVector<Block *, 2> successors;
744 void *propertiesStorage =
nullptr;
751class ReplaceOperationRewrite :
public OperationRewrite {
753 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
754 Operation *op,
const TypeConverter *converter)
755 : OperationRewrite(
Kind::ReplaceOperation, rewriterImpl, op),
756 converter(converter) {}
758 static bool classof(
const IRRewrite *
rewrite) {
759 return rewrite->getKind() == Kind::ReplaceOperation;
762 void commit(RewriterBase &rewriter)
override;
764 void rollback()
override;
766 void cleanup(RewriterBase &rewriter)
override;
771 const TypeConverter *converter;
774class CreateOperationRewrite :
public OperationRewrite {
776 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
778 : OperationRewrite(
Kind::CreateOperation, rewriterImpl, op) {}
780 static bool classof(
const IRRewrite *
rewrite) {
781 return rewrite->getKind() == Kind::CreateOperation;
784 void commit(RewriterBase &rewriter)
override {
790 void rollback()
override;
794enum MaterializationKind {
805class UnresolvedMaterializationInfo {
807 UnresolvedMaterializationInfo() =
default;
808 UnresolvedMaterializationInfo(
const TypeConverter *converter,
809 MaterializationKind kind, Type originalType)
810 : converterAndKind(converter, kind), originalType(originalType) {}
813 const TypeConverter *getConverter()
const {
814 return converterAndKind.getPointer();
818 MaterializationKind getMaterializationKind()
const {
819 return converterAndKind.getInt();
823 Type getOriginalType()
const {
return originalType; }
828 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
839class UnresolvedMaterializationRewrite :
public OperationRewrite {
841 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
842 UnrealizedConversionCastOp op,
844 : OperationRewrite(
Kind::UnresolvedMaterialization, rewriterImpl, op),
845 mappedValues(std::move(mappedValues)) {}
847 static bool classof(
const IRRewrite *
rewrite) {
848 return rewrite->getKind() == Kind::UnresolvedMaterialization;
851 void rollback()
override;
853 UnrealizedConversionCastOp getOperation()
const {
854 return cast<UnrealizedConversionCastOp>(op);
864#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
867template <
typename RewriteTy,
typename R>
868static bool hasRewrite(R &&rewrites, Operation *op) {
869 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
870 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
871 return rewriteTy && rewriteTy->getOperation() == op;
877template <
typename RewriteTy,
typename R>
878static bool hasRewrite(R &&rewrites,
Block *block) {
879 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
880 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
881 return rewriteTy && rewriteTy->getBlock() == block;
893 const ConversionConfig &
config,
903 RewriterState getCurrentState();
907 void applyRewrites();
912 void resetState(RewriterState state, StringRef patternName =
"");
916 template <
typename RewriteTy,
typename... Args>
918 assert(
config.allowPatternRollback &&
"appending rewrites is not allowed");
920 std::make_unique<RewriteTy>(*
this, std::forward<Args>(args)...));
926 void undoRewrites(
unsigned numRewritesToKeep = 0, StringRef patternName =
"");
932 LogicalResult remapValues(StringRef valueDiagTag,
933 std::optional<Location> inputLoc,
ValueRange values,
950 bool skipPureTypeConversions =
false)
const;
964 TypeConverter::SignatureConversion *entryConversion);
972 Block *applySignatureConversion(
974 TypeConverter::SignatureConversion &signatureConversion);
994 void eraseBlock(
Block *block);
1032 Value findOrBuildReplacementValue(
Value value,
1040 void notifyOperationInserted(
Operation *op,
1044 void notifyBlockInserted(
Block *block,
Region *previous,
1063 std::function<
void(
Operation *)> opErasedCallback =
nullptr)
1065 opErasedCallback(std::move(opErasedCallback)) {}
1079 assert(block->empty() &&
"expected empty block");
1080 block->dropAllDefinedValueUses();
1088 if (opErasedCallback)
1089 opErasedCallback(op);
1197const ConversionConfig &IRRewrite::getConfig()
const {
1198 return rewriterImpl.
config;
1201void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1205 if (
auto *listener =
1206 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
1207 for (Operation *op : getNewBlock()->getUsers())
1211void BlockTypeConversionRewrite::rollback() {
1212 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1219 if (isa<BlockArgument>(repl)) {
1259 result &= functor(operand);
1264void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1271void ReplaceValueRewrite::rollback() {
1272 rewriterImpl.
mapping.erase({value});
1278void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1280 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener());
1283 SmallVector<Value> replacements =
1285 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1293 for (
auto [
result, newValue] :
1294 llvm::zip_equal(op->
getResults(), replacements))
1300 if (getConfig().unlegalizedOps)
1301 getConfig().unlegalizedOps->erase(op);
1305 notifyIRErased(listener, *op);
1312void ReplaceOperationRewrite::rollback() {
1317void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1321void CreateOperationRewrite::rollback() {
1323 while (!region.getBlocks().empty())
1324 region.getBlocks().remove(region.getBlocks().begin());
1330void UnresolvedMaterializationRewrite::rollback() {
1331 if (!mappedValues.empty())
1332 rewriterImpl.
mapping.erase(mappedValues);
1343 for (
size_t i = 0; i <
rewrites.size(); ++i)
1349 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1350 unresolvedMaterializations.erase(castOp);
1353 rewrite->cleanup(eraseRewriter);
1361 Value from,
TypeRange desiredTypes,
bool skipPureTypeConversions)
const {
1364 assert(!values.empty() &&
"expected non-empty value vector");
1368 if (
config.allowPatternRollback)
1369 return mapping.lookup(values);
1376 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1381 if (castOp.getOutputs() != values)
1383 return castOp.getInputs();
1392 for (
Value v : values) {
1395 llvm::append_range(next, r);
1400 if (next != values) {
1429 if (skipPureTypeConversions) {
1432 match &= !pureConversion;
1435 if (!pureConversion)
1436 lastNonMaterialization = current;
1439 desiredValue = current;
1445 current = std::move(next);
1450 if (!desiredTypes.empty())
1451 return desiredValue;
1452 if (skipPureTypeConversions)
1453 return lastNonMaterialization;
1472 StringRef patternName) {
1477 while (
ignoredOps.size() != state.numIgnoredOperations)
1480 while (
replacedOps.size() != state.numReplacedOps)
1485 StringRef patternName) {
1487 llvm::reverse(llvm::drop_begin(
rewrites, numRewritesToKeep)))
1489 rewrites.resize(numRewritesToKeep);
1493 StringRef valueDiagTag, std::optional<Location> inputLoc,
ValueRange values,
1495 remapped.reserve(llvm::size(values));
1497 for (
const auto &it : llvm::enumerate(values)) {
1498 Value operand = it.value();
1517 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1518 << it.index() <<
", type was " << origType;
1523 if (legalTypes.empty()) {
1524 remapped.push_back({});
1533 remapped.push_back(std::move(repl));
1542 repl, repl, legalTypes,
1544 remapped.push_back(castValues);
1565 TypeConverter::SignatureConversion *entryConversion) {
1567 if (region->
empty())
1572 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1574 std::optional<TypeConverter::SignatureConversion> conversion =
1575 converter.convertBlockSignature(&block);
1584 if (entryConversion)
1587 std::optional<TypeConverter::SignatureConversion> conversion =
1588 converter.convertBlockSignature(®ion->
front());
1596 TypeConverter::SignatureConversion &signatureConversion) {
1597#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1599 if (hasRewrite<BlockTypeConversionRewrite>(
rewrites, block))
1600 llvm::reportFatalInternalError(
"block was already converted");
1607 auto convertedTypes = signatureConversion.getConvertedTypes();
1614 for (
unsigned i = 0; i < origArgCount; ++i) {
1615 auto inputMap = signatureConversion.getInputMapping(i);
1616 if (!inputMap || inputMap->replacedWithValues())
1619 for (
unsigned j = 0;
j < inputMap->size; ++
j)
1620 newLocs[inputMap->inputNo +
j] = origLoc;
1627 convertedTypes, newLocs);
1635 bool fastPath = !
config.listener;
1637 if (
config.allowPatternRollback)
1641 while (!block->
empty())
1648 for (
unsigned i = 0; i != origArgCount; ++i) {
1652 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1653 signatureConversion.getInputMapping(i);
1661 MaterializationKind::Source,
1665 origArgType,
Type(), converter,
1672 if (inputMap->replacedWithValues()) {
1674 assert(inputMap->size == 0 &&
1675 "invalid to provide a replacement value when the argument isn't "
1683 newBlock->
getArguments().slice(inputMap->inputNo, inputMap->size);
1687 if (
config.allowPatternRollback)
1708 assert((!originalType || kind == MaterializationKind::Target) &&
1709 "original type is valid only for target materializations");
1710 assert(
TypeRange(inputs) != outputTypes &&
1711 "materialization is not necessary");
1715 OpBuilder builder(outputTypes.front().getContext());
1717 UnrealizedConversionCastOp convertOp =
1718 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1719 if (
config.attachDebugMaterializationKind) {
1721 kind == MaterializationKind::Source ?
"source" :
"target";
1722 convertOp->setAttr(
"__kind__", builder.
getStringAttr(kindStr));
1729 UnresolvedMaterializationInfo(converter, kind, originalType);
1730 if (
config.allowPatternRollback) {
1731 if (!valuesToMap.empty())
1732 mapping.map(valuesToMap, convertOp.getResults());
1734 std::move(valuesToMap));
1738 return convertOp.getResults();
1743 assert(
config.allowPatternRollback &&
1744 "this code path is valid only in rollback mode");
1751 return repl.front();
1758 [&](
Operation *op) { return replacedOps.contains(op); }) &&
1783 MaterializationKind::Source, ip, value.
getLoc(),
1799 bool wasDetached = !previous.
isSet();
1801 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"' (" << op
1804 logger.getOStream() <<
" (was detached)";
1805 logger.getOStream() <<
"\n";
1811 "attempting to insert into a block within a replaced/erased op");
1815 config.listener->notifyOperationInserted(op, previous);
1824 if (
config.allowPatternRollback) {
1838 if (
config.allowPatternRollback)
1848 assert(!
impl.config.allowPatternRollback &&
1849 "this code path is valid only in 'no rollback' mode");
1851 for (
auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1854 repls.push_back(
Value());
1861 Value srcMat =
impl.buildUnresolvedMaterialization(
1866 repls.push_back(srcMat);
1872 repls.push_back(to[0]);
1881 Value srcMat =
impl.buildUnresolvedMaterialization(
1884 Type(), converter)[0];
1885 repls.push_back(srcMat);
1894 "incorrect number of replacement values");
1896 logger.startLine() <<
"** Replace : '" << op->
getName() <<
"'(" << op
1904 for (
auto [
result, repls] :
1905 llvm::zip_equal(op->
getResults(), newValues)) {
1907 auto logProlog = [&, repls = repls]() {
1908 logger.startLine() <<
" Note: Replacing op result of type "
1909 << resultType <<
" with value(s) of type (";
1910 llvm::interleaveComma(repls,
logger.getOStream(), [&](
Value v) {
1911 logger.getOStream() << v.getType();
1913 logger.getOStream() <<
")";
1919 logger.getOStream() <<
", but the type converter failed to legalize "
1920 "the original type.\n";
1925 logger.getOStream() <<
", but the legalized type(s) is/are (";
1926 llvm::interleaveComma(convertedTypes,
logger.getOStream(),
1927 [&](
Type t) { logger.getOStream() << t; });
1928 logger.getOStream() <<
")\n";
1934 if (!
config.allowPatternRollback) {
1943 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1949 if (
config.unlegalizedOps)
1950 config.unlegalizedOps->erase(op);
1958 assert(!
ignoredOps.contains(op) &&
"operation was already replaced");
1962 "attempting to replace a value that was already replaced");
1967 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1972 "attempting to replace/erase an unresolved materialization");
1988 logger.startLine() <<
"** Replace Value : '" << from <<
"'";
1989 if (
auto blockArg = dyn_cast<BlockArgument>(from)) {
1991 logger.getOStream() <<
" (in region of '" << parentOp->getName()
1992 <<
"' (" << parentOp <<
")";
1994 logger.getOStream() <<
" (unlinked block)";
1998 logger.getOStream() <<
", conditional replacement";
2002 if (!
config.allowPatternRollback) {
2007 Value repl = repls.front();
2024 "attempting to replace a value that was already replaced");
2026 "attempting to replace a op result that was already replaced");
2031 llvm::reportFatalInternalError(
2032 "conditional value replacement is not supported in rollback mode");
2038 if (!
config.allowPatternRollback) {
2045 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
2051 if (
config.unlegalizedOps)
2052 config.unlegalizedOps->erase(op);
2061 "attempting to erase a block within a replaced/erased op");
2077 bool wasDetached = !previous;
2083 logger.startLine() <<
"** Insert Block into : '" << parent->
getName()
2084 <<
"' (" << parent <<
")";
2087 <<
"** Insert Block into detached Region (nullptr parent op)";
2090 logger.getOStream() <<
" (was detached)";
2091 logger.getOStream() <<
"\n";
2097 "attempting to insert into a region within a replaced/erased op");
2102 config.listener->notifyBlockInserted(block, previous, previousIt);
2106 if (
config.allowPatternRollback) {
2120 if (
config.allowPatternRollback)
2134 reasonCallback(
diag);
2135 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
2136 if (
config.notifyCallback)
2145ConversionPatternRewriter::ConversionPatternRewriter(
2149 *this, config, opConverter)) {
2150 setListener(
impl.get());
// Destructor for ConversionPatternRewriter, explicitly defaulted and defined
// out of line (outside the class body).
// NOTE(review): presumably kept out of line so that `impl` (which the
// constructor initializes and `getImpl()` exposes as a
// detail::ConversionPatternRewriterImpl) is destroyed in a translation unit
// where that type is complete — confirm against the header declaration.
2153ConversionPatternRewriter::~ConversionPatternRewriter() =
default;
2155const ConversionConfig &ConversionPatternRewriter::getConfig()
const {
2156 return impl->config;
2160 assert(op && newOp &&
"expected non-null op");
2164void ConversionPatternRewriter::replaceOp(Operation *op,
ValueRange newValues) {
2166 "incorrect # of replacement values");
2170 if (getInsertionPoint() == op->getIterator())
2173 SmallVector<SmallVector<Value>> newVals =
2174 llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2175 return v ? SmallVector<Value>{v} : SmallVector<Value>();
2177 impl->replaceOp(op, std::move(newVals));
2180void ConversionPatternRewriter::replaceOpWithMultiple(
2181 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2183 "incorrect # of replacement values");
2187 if (getInsertionPoint() == op->getIterator())
2190 impl->replaceOp(op, std::move(newValues));
2193void ConversionPatternRewriter::eraseOp(Operation *op) {
2195 impl->logger.startLine()
2196 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
2201 if (getInsertionPoint() == op->getIterator())
2204 SmallVector<SmallVector<Value>> nullRepls(op->
getNumResults(), {});
2205 impl->replaceOp(op, std::move(nullRepls));
2208void ConversionPatternRewriter::eraseBlock(
Block *block) {
2209 impl->eraseBlock(block);
2212Block *ConversionPatternRewriter::applySignatureConversion(
2213 Block *block, TypeConverter::SignatureConversion &conversion,
2214 const TypeConverter *converter) {
2215 assert(!impl->wasOpReplaced(block->
getParentOp()) &&
2216 "attempting to apply a signature conversion to a block within a "
2217 "replaced/erased op");
2218 return impl->applySignatureConversion(block, converter, conversion);
2221FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
2222 Region *region,
const TypeConverter &converter,
2223 TypeConverter::SignatureConversion *entryConversion) {
2224 assert(!impl->wasOpReplaced(region->
getParentOp()) &&
2225 "attempting to apply a signature conversion to a block within a "
2226 "replaced/erased op");
2227 return impl->convertRegionTypes(region, converter, entryConversion);
2230void ConversionPatternRewriter::replaceAllUsesWith(Value from,
ValueRange to) {
2231 impl->replaceValueUses(from, to, impl->currentTypeConverter);
2234void ConversionPatternRewriter::replaceUsesWithIf(
2236 bool *allUsesReplaced) {
2237 assert(!allUsesReplaced &&
2238 "allUsesReplaced is not supported in a dialect conversion");
2239 impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
2242Value ConversionPatternRewriter::getRemappedValue(Value key) {
2243 SmallVector<ValueVector> remappedValues;
2244 if (
failed(impl->remapValues(
"value", std::nullopt, key,
2247 assert(remappedValues.front().size() == 1 &&
"1:N conversion not supported");
2248 return remappedValues.front().front();
2252ConversionPatternRewriter::getRemappedValues(
ValueRange keys,
2253 SmallVectorImpl<Value> &results) {
2256 SmallVector<ValueVector> remapped;
2257 if (
failed(impl->remapValues(
"value", std::nullopt, keys,
2260 for (
const auto &values : remapped) {
2261 assert(values.size() == 1 &&
"1:N conversion not supported");
2262 results.push_back(values.front());
2267void ConversionPatternRewriter::inlineBlockBefore(
Block *source,
Block *dest,
2272 "incorrect # of argument replacement values");
2273 assert(!impl->wasOpReplaced(source->
getParentOp()) &&
2274 "attempting to inline a block from a replaced/erased op");
2275 assert(!impl->wasOpReplaced(dest->
getParentOp()) &&
2276 "attempting to inline a block into a replaced/erased op");
2277 auto opIgnored = [&](Operation *op) {
return impl->isOpIgnored(op); };
2280 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
2281 "expected 'source' to have no predecessors");
2290 bool fastPath = !getConfig().listener;
2292 if (fastPath && impl->config.allowPatternRollback)
2293 impl->inlineBlockBefore(source, dest, before);
2296 for (
auto it : llvm::zip(source->
getArguments(), argValues))
2297 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2304 while (!source->
empty())
2305 moveOpBefore(&source->
front(), dest, before);
2310 if (getInsertionBlock() == source)
2311 setInsertionPoint(dest, getInsertionPoint());
2317void ConversionPatternRewriter::startOpModification(Operation *op) {
2318 if (!impl->config.allowPatternRollback) {
2323 assert(!impl->wasOpReplaced(op) &&
2324 "attempting to modify a replaced/erased op");
2326 impl->pendingRootUpdates.insert(op);
2328 impl->appendRewrite<ModifyOperationRewrite>(op);
2331void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
2332 impl->patternModifiedOps.insert(op);
2333 if (!impl->config.allowPatternRollback) {
2335 if (getConfig().listener)
2336 getConfig().listener->notifyOperationModified(op);
2343 assert(!impl->wasOpReplaced(op) &&
2344 "attempting to modify a replaced/erased op");
2345 assert(impl->pendingRootUpdates.erase(op) &&
2346 "operation did not have a pending in-place update");
2350void ConversionPatternRewriter::cancelOpModification(Operation *op) {
2351 if (!impl->config.allowPatternRollback) {
2356 assert(impl->pendingRootUpdates.erase(op) &&
2357 "operation did not have a pending in-place update");
2360 auto it = llvm::find_if(
2361 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &
rewrite) {
2362 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2363 return modifyRewrite && modifyRewrite->getOperation() == op;
2365 assert(it != impl->rewrites.rend() &&
"no root update started on op");
2367 int updateIdx = std::prev(impl->rewrites.rend()) - it;
2368 impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2371detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
2379FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
2380 ArrayRef<ValueRange> operands)
const {
2381 SmallVector<Value> oneToOneOperands;
2382 oneToOneOperands.reserve(operands.size());
2384 if (operand.size() != 1)
2387 oneToOneOperands.push_back(operand.front());
2389 return std::move(oneToOneOperands);
2393ConversionPattern::matchAndRewrite(Operation *op,
2394 PatternRewriter &rewriter)
const {
2395 auto &dialectRewriter =
static_cast<ConversionPatternRewriter &
>(rewriter);
2396 auto &rewriterImpl = dialectRewriter.getImpl();
2400 getTypeConverter());
2403 SmallVector<ValueVector> remapped;
2408 SmallVector<ValueRange> remappedAsRange =
2409 llvm::to_vector_of<ValueRange>(remapped);
2410 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2419using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2422class OperationLegalizer {
2424 using LegalizationAction = ConversionTarget::LegalizationAction;
2426 OperationLegalizer(ConversionPatternRewriter &rewriter,
2427 const ConversionTarget &targetInfo,
2428 const FrozenRewritePatternSet &patterns);
2431 bool isIllegal(Operation *op)
const;
2435 LogicalResult legalize(Operation *op);
2438 const ConversionTarget &getTarget() {
return target; }
2442 LogicalResult legalizeWithFold(Operation *op);
2446 LogicalResult legalizeWithPattern(Operation *op);
2450 bool canApplyPattern(Operation *op,
const Pattern &pattern);
2454 legalizePatternResult(Operation *op,
const Pattern &pattern,
2455 const RewriterState &curState,
2472 void buildLegalizationGraph(
2473 LegalizationPatterns &anyOpLegalizerPatterns,
2484 void computeLegalizationGraphBenefit(
2485 LegalizationPatterns &anyOpLegalizerPatterns,
2490 unsigned computeOpLegalizationDepth(
2497 unsigned applyCostModelToPatterns(
2498 LegalizationPatterns &patterns,
2503 SmallPtrSet<const Pattern *, 8> appliedPatterns;
2506 ConversionPatternRewriter &rewriter;
2509 const ConversionTarget &
target;
2512 PatternApplicator applicator;
2516OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2517 const ConversionTarget &targetInfo,
2518 const FrozenRewritePatternSet &patterns)
2519 : rewriter(rewriter),
target(targetInfo), applicator(patterns) {
2523 LegalizationPatterns anyOpLegalizerPatterns;
2525 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2526 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2529bool OperationLegalizer::isIllegal(Operation *op)
const {
2530 return target.isIllegal(op);
2533LogicalResult OperationLegalizer::legalize(Operation *op) {
2535 const char *logLineComment =
2536 "//===-------------------------------------------===//\n";
2538 auto &logger = rewriter.getImpl().logger;
2542 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2545 logger.getOStream() <<
"\n";
2546 logger.startLine() << logLineComment;
2547 logger.startLine() <<
"Legalizing operation : ";
2552 logger.getOStream() <<
"'" << op->
getName() <<
"' ";
2553 logger.getOStream() <<
"(" << op <<
") {\n";
2558 logger.startLine() << OpWithFlags(op,
2559 OpPrintingFlags().printGenericOpForm())
2566 logSuccess(logger,
"operation marked 'ignored' during conversion");
2567 logger.startLine() << logLineComment;
2573 if (
auto legalityInfo =
target.isLegal(op)) {
2576 logger,
"operation marked legal by the target{0}",
2577 legalityInfo->isRecursivelyLegal
2578 ?
"; NOTE: operation is recursively legal; skipping internals"
2580 logger.startLine() << logLineComment;
2585 if (legalityInfo->isRecursivelyLegal) {
2586 op->
walk([&](Operation *nested) {
2588 rewriter.getImpl().ignoredOps.
insert(nested);
2597 const ConversionConfig &config = rewriter.getConfig();
2598 if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2599 if (succeeded(legalizeWithFold(op))) {
2602 logger.startLine() << logLineComment;
2609 if (succeeded(legalizeWithPattern(op))) {
2612 logger.startLine() << logLineComment;
2619 if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2620 if (succeeded(legalizeWithFold(op))) {
2623 logger.startLine() << logLineComment;
2630 logFailure(logger,
"no matched legalization pattern");
2631 logger.startLine() << logLineComment;
2638template <
typename T>
2640 T
result = std::move(obj);
2645LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2646 auto &rewriterImpl = rewriter.getImpl();
2648 rewriterImpl.
logger.startLine() <<
"* Fold {\n";
2649 rewriterImpl.
logger.indent();
2654 llvm::scope_exit cleanup([&]() {
2664 SmallVector<Value, 2> replacementValues;
2665 SmallVector<Operation *, 2> newOps;
2668 if (
failed(rewriter.
tryFold(op, replacementValues, &newOps))) {
2677 if (replacementValues.empty())
2678 return legalize(op);
2681 rewriter.
replaceOp(op, replacementValues);
2684 for (Operation *newOp : newOps) {
2685 if (
failed(legalize(newOp))) {
2687 "failed to legalize generated constant '{0}'",
2689 if (!rewriter.getConfig().allowPatternRollback) {
2691 llvm::reportFatalInternalError(
2693 "' folder rollback of IR modifications requested");
2711 auto newOpNames = llvm::map_range(
2713 auto modifiedOpNames = llvm::map_range(
2715 llvm::reportFatalInternalError(
"pattern '" + pattern.
getDebugName() +
2716 "' produced IR that could not be legalized. " +
2717 "new ops: {" + llvm::join(newOpNames,
", ") +
2718 "}, " +
"modified ops: {" +
2719 llvm::join(modifiedOpNames,
", ") +
"}");
2722LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2723 auto &rewriterImpl = rewriter.getImpl();
2724 const ConversionConfig &config = rewriter.getConfig();
2726#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2728 std::optional<OperationFingerPrint> topLevelFingerPrint;
2729 if (!rewriterImpl.
config.allowPatternRollback) {
2736 topLevelFingerPrint = OperationFingerPrint(checkOp);
2742 rewriterImpl.
logger.startLine()
2743 <<
"WARNING: Multi-threadeding is enabled. Some dialect "
2744 "conversion expensive checks are skipped in multithreading "
2753 auto canApply = [&](
const Pattern &pattern) {
2754 bool canApply = canApplyPattern(op, pattern);
2755 if (canApply && config.listener)
2756 config.listener->notifyPatternBegin(pattern, op);
2762 auto onFailure = [&](
const Pattern &pattern) {
2764 if (!rewriterImpl.
config.allowPatternRollback) {
2771#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2773 if (checkOp && topLevelFingerPrint) {
2774 OperationFingerPrint fingerPrintAfterPattern(checkOp);
2775 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2776 llvm::reportFatalInternalError(
2777 "pattern '" + pattern.getDebugName() +
2778 "' returned failure but IR did change");
2786 if (rewriterImpl.
config.notifyCallback) {
2788 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
2794 if (config.listener)
2795 config.listener->notifyPatternEnd(pattern, failure());
2796 rewriterImpl.
resetState(curState, pattern.getDebugName());
2797 appliedPatterns.erase(&pattern);
2802 auto onSuccess = [&](
const Pattern &pattern) {
2804 if (!rewriterImpl.
config.allowPatternRollback) {
2818 legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
2819 appliedPatterns.erase(&pattern);
2821 if (!rewriterImpl.
config.allowPatternRollback)
2823 rewriterImpl.
resetState(curState, pattern.getDebugName());
2825 if (config.listener)
2826 config.listener->notifyPatternEnd(pattern,
result);
2831 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2835bool OperationLegalizer::canApplyPattern(Operation *op,
2836 const Pattern &pattern) {
2838 auto &os = rewriter.getImpl().logger;
2839 os.getOStream() <<
"\n";
2840 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
2842 os.getOStream() <<
")' {\n";
2849 !appliedPatterns.insert(&pattern).second) {
2851 logFailure(rewriter.getImpl().logger,
"pattern was already applied"));
2857LogicalResult OperationLegalizer::legalizePatternResult(
2858 Operation *op,
const Pattern &pattern,
const RewriterState &curState,
2861 [[maybe_unused]]
auto &impl = rewriter.getImpl();
2862 assert(impl.pendingRootUpdates.empty() &&
"dangling root updates");
2864#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2865 if (impl.config.allowPatternRollback) {
2867 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2868 auto replacedRoot = [&] {
2869 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2871 auto updatedRootInPlace = [&] {
2872 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2874 if (!replacedRoot() && !updatedRootInPlace())
2875 llvm::reportFatalInternalError(
2876 "expected pattern to replace the root operation "
2877 "or modify it in place");
2882 if (
failed(legalizePatternRootUpdates(modifiedOps)) ||
2883 failed(legalizePatternCreatedOperations(newOps))) {
2887 LLVM_DEBUG(
logSuccess(impl.logger,
"pattern applied successfully"));
2891LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2893 for (Operation *op : newOps) {
2894 if (
failed(legalize(op))) {
2895 LLVM_DEBUG(
logFailure(rewriter.getImpl().logger,
2896 "failed to legalize generated operation '{0}'({1})",
2904LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2906 for (Operation *op : modifiedOps) {
2907 if (
failed(legalize(op))) {
2910 "failed to legalize operation updated in-place '{0}'",
2922void OperationLegalizer::buildLegalizationGraph(
2923 LegalizationPatterns &anyOpLegalizerPatterns,
2934 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2935 std::optional<OperationName> root = pattern.
getRootKind();
2941 anyOpLegalizerPatterns.push_back(&pattern);
2946 if (
target.getOpAction(*root) == LegalizationAction::Legal)
2951 invalidPatterns[*root].insert(&pattern);
2953 parentOps[op].insert(*root);
2956 patternWorklist.insert(&pattern);
2964 if (!anyOpLegalizerPatterns.empty()) {
2965 for (
const Pattern *pattern : patternWorklist)
2966 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2970 while (!patternWorklist.empty()) {
2971 auto *pattern = patternWorklist.pop_back_val();
2975 std::optional<LegalizationAction> action = target.getOpAction(op);
2976 return !legalizerPatterns.count(op) &&
2977 (!action || action == LegalizationAction::Illegal);
2983 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2984 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
2988 for (
auto op : parentOps[*pattern->
getRootKind()])
2989 patternWorklist.set_union(invalidPatterns[op]);
2993void OperationLegalizer::computeLegalizationGraphBenefit(
2994 LegalizationPatterns &anyOpLegalizerPatterns,
3000 for (
auto &opIt : legalizerPatterns)
3001 if (!minOpPatternDepth.count(opIt.first))
3002 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
3008 if (!anyOpLegalizerPatterns.empty())
3009 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
3015 applicator.applyCostModel([&](
const Pattern &pattern) {
3016 ArrayRef<const Pattern *> orderedPatternList;
3017 if (std::optional<OperationName> rootName = pattern.
getRootKind())
3018 orderedPatternList = legalizerPatterns[*rootName];
3020 orderedPatternList = anyOpLegalizerPatterns;
3023 auto *it = llvm::find(orderedPatternList, &pattern);
3024 if (it == orderedPatternList.end())
3028 return PatternBenefit(std::distance(it, orderedPatternList.end()));
3032unsigned OperationLegalizer::computeOpLegalizationDepth(
3036 auto depthIt = minOpPatternDepth.find(op);
3037 if (depthIt != minOpPatternDepth.end())
3038 return depthIt->second;
3042 auto opPatternsIt = legalizerPatterns.find(op);
3043 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3048 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
3052 unsigned minDepth = applyCostModelToPatterns(
3053 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3054 minOpPatternDepth[op] = minDepth;
3058unsigned OperationLegalizer::applyCostModelToPatterns(
3059 LegalizationPatterns &patterns,
3062 unsigned minDepth = std::numeric_limits<unsigned>::max();
3065 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3066 patternsByDepth.reserve(patterns.size());
3067 for (
const Pattern *pattern : patterns) {
3070 unsigned generatedOpDepth = computeOpLegalizationDepth(
3071 generatedOp, minOpPatternDepth, legalizerPatterns);
3072 depth = std::max(depth, generatedOpDepth + 1);
3074 patternsByDepth.emplace_back(pattern, depth);
3077 minDepth = std::min(minDepth, depth);
3082 if (patternsByDepth.size() == 1)
3086 llvm::stable_sort(patternsByDepth,
3087 [](
const std::pair<const Pattern *, unsigned> &
lhs,
3088 const std::pair<const Pattern *, unsigned> &
rhs) {
3091 if (
lhs.second !=
rhs.second)
3092 return lhs.second <
rhs.second;
3095 auto lhsBenefit =
lhs.first->getBenefit();
3096 auto rhsBenefit =
rhs.first->getBenefit();
3097 return lhsBenefit > rhsBenefit;
3102 for (
auto &patternIt : patternsByDepth)
3103 patterns.push_back(patternIt.first);
3117template <
typename RangeT>
3120 function_ref<
bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3129 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3130 if (castOp.getInputs().empty())
3133 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3136 if (inputCastOp.getOutputs() != castOp.getInputs())
3142 while (!worklist.empty()) {
3143 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3147 UnrealizedConversionCastOp nextCast = castOp;
3149 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3150 if (llvm::any_of(nextCast.getInputs(), [&](
Value v) {
3151 return v.getDefiningOp() == castOp;
3159 castOp.replaceAllUsesWith(nextCast.getInputs());
3162 nextCast = getInputCast(nextCast);
3172 auto markOpLive = [&](
Operation *rootOp) {
3174 worklist.push_back(rootOp);
3175 while (!worklist.empty()) {
3176 Operation *op = worklist.pop_back_val();
3177 if (liveOps.insert(op).second) {
3180 if (
auto castOp = v.
getDefiningOp<UnrealizedConversionCastOp>())
3181 if (isCastOpOfInterestFn(castOp))
3182 worklist.push_back(castOp);
3188 for (UnrealizedConversionCastOp op : castOps) {
3191 if (liveOps.contains(op.getOperation()))
3195 if (llvm::any_of(op->getUsers(), [&](
Operation *user) {
3196 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3197 return !castOp || !isCastOpOfInterestFn(castOp);
3203 for (UnrealizedConversionCastOp op : castOps) {
3204 if (liveOps.contains(op)) {
3206 if (remainingCastOps)
3207 remainingCastOps->push_back(op);
3218 ArrayRef<UnrealizedConversionCastOp> castOps,
3219 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3221 DenseSet<UnrealizedConversionCastOp> castOpSet;
3222 for (UnrealizedConversionCastOp op : castOps)
3223 castOpSet.insert(op);
3228 const DenseSet<UnrealizedConversionCastOp> &castOps,
3229 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3231 llvm::make_range(castOps.begin(), castOps.end()),
3232 [&](UnrealizedConversionCastOp castOp) {
3233 return castOps.contains(castOp);
3245 [&](UnrealizedConversionCastOp castOp) {
3246 return castOps.contains(castOp);
3263 const ConversionConfig &config,
3264 OpConversionMode mode)
3265 : rewriter(ctx, config, *this), opLegalizer(rewriter,
target, patterns),
3274 template <
typename Fn>
3276 bool isRecursiveLegalization =
false);
3278 bool isRecursiveLegalization =
false) {
3280 ops, [&]() {}, isRecursiveLegalization);
3288 LogicalResult convert(
Operation *op,
bool isRecursiveLegalization =
false);
3294 ConversionPatternRewriter rewriter;
3297 OperationLegalizer opLegalizer;
3300 OpConversionMode mode;
3305 bool isRecursiveLegalization) {
3306 const ConversionConfig &config = rewriter.getConfig();
3307 auto emitFailedToLegalizeDiag = [&](
bool wasExplicitlyIllegal) {
3309 <<
"failed to legalize operation '"
3311 if (wasExplicitlyIllegal)
3312 diag <<
" that was explicitly marked illegal";
3317 if (failed(opLegalizer.legalize(op))) {
3320 if (mode == OpConversionMode::Full) {
3321 if (!isRecursiveLegalization)
3322 emitFailedToLegalizeDiag(
false);
3328 if (mode == OpConversionMode::Partial) {
3329 if (opLegalizer.isIllegal(op)) {
3330 if (!isRecursiveLegalization)
3331 emitFailedToLegalizeDiag(
true);
3334 if (config.unlegalizedOps && !isRecursiveLegalization)
3335 config.unlegalizedOps->insert(op);
3337 }
else if (mode == OpConversionMode::Analysis) {
3341 if (config.legalizableOps && !isRecursiveLegalization)
3342 config.legalizableOps->insert(op);
3349 UnrealizedConversionCastOp op,
3350 const UnresolvedMaterializationInfo &info) {
3351 assert(!op.use_empty() &&
3352 "expected that dead materializations have already been DCE'd");
3359 switch (info.getMaterializationKind()) {
3360 case MaterializationKind::Target:
3361 newMaterialization = converter->materializeTargetConversion(
3362 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3363 info.getOriginalType());
3365 case MaterializationKind::Source:
3366 assert(op->getNumResults() == 1 &&
"expected single result");
3367 Value sourceMat = converter->materializeSourceConversion(
3368 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3370 newMaterialization.push_back(sourceMat);
3373 if (!newMaterialization.empty()) {
3375 ValueRange newMaterializationRange(newMaterialization);
3376 assert(
TypeRange(newMaterializationRange) == op.getResultTypes() &&
3377 "materialization callback produced value of incorrect type");
3379 rewriter.
replaceOp(op, newMaterialization);
3385 <<
"failed to legalize unresolved materialization "
3387 << inputOperands.
getTypes() <<
") to ("
3388 << op.getResultTypes()
3389 <<
") that remained live after conversion";
3390 diag.attachNote(op->getUsers().begin()->getLoc())
3391 <<
"see existing live user here: " << *op->getUsers().begin();
3395template <
typename Fn>
3398 bool isRecursiveLegalization) {
3406 toConvert.push_back(op);
3409 auto legalityInfo =
target.isLegal(op);
3410 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3416 if (failed(
convert(op, isRecursiveLegalization))) {
3425LogicalResult ConversionPatternRewriter::legalize(
Operation *op) {
3426 return impl->opConverter.legalizeOperations(op,
3430LogicalResult ConversionPatternRewriter::legalize(
Region *r) {
3446 std::optional<TypeConverter::SignatureConversion> conversion =
3447 converter->convertBlockSignature(&r->front());
3450 applySignatureConversion(&r->front(), *conversion, converter);
3455 return impl->opConverter.legalizeOperations(ops,
3464 if (rewriterImpl.
config.allowPatternRollback) {
3488 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3492 if (rewriter.getConfig().buildMaterializations) {
3496 rewriter.getConfig().listener);
3497 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3498 auto it = materializations.find(castOp);
3499 assert(it != materializations.end() &&
"inconsistent state");
3513void TypeConverter::SignatureConversion::addInputs(
unsigned origInputNo,
3515 assert(!types.empty() &&
"expected valid types");
3516 remapInput(origInputNo, argTypes.size(), types.size());
3520void TypeConverter::SignatureConversion::addInputs(
ArrayRef<Type> types) {
3521 assert(!types.empty() &&
3522 "1->0 type remappings don't need to be added explicitly");
3523 argTypes.append(types.begin(), types.end());
3526void TypeConverter::SignatureConversion::remapInput(
unsigned origInputNo,
3527 unsigned newInputNo,
3528 unsigned newInputCount) {
3529 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3530 assert(newInputCount != 0 &&
"expected valid input count");
3531 remappedInputs[origInputNo] =
3532 InputMapping{newInputNo, newInputCount, {}};
3535void TypeConverter::SignatureConversion::remapInput(
3536 unsigned origInputNo, ArrayRef<Value> replacements) {
3537 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3538 remappedInputs[origInputNo] = InputMapping{
3540 SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3551TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3552 SmallVectorImpl<Type> &results)
const {
3553 assert(typeOrValue &&
"expected non-null type");
3554 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3555 : cast<Type>(typeOrValue);
3557 std::shared_lock<
decltype(cacheMutex)> cacheReadLock(cacheMutex,
3560 cacheReadLock.lock();
3561 auto existingIt = cachedDirectConversions.find(t);
3562 if (existingIt != cachedDirectConversions.end()) {
3563 if (existingIt->second)
3564 results.push_back(existingIt->second);
3565 return success(existingIt->second !=
nullptr);
3567 auto multiIt = cachedMultiConversions.find(t);
3568 if (multiIt != cachedMultiConversions.end()) {
3569 results.append(multiIt->second.begin(), multiIt->second.end());
3575 size_t currentCount = results.size();
3579 auto isCacheable = [&](
int index) {
3580 int numberOfConversionsUntilContextAware =
3581 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3582 return index < numberOfConversionsUntilContextAware;
3585 std::unique_lock<
decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3588 for (
auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3589 const ConversionCallbackFn &converter = indexedConverter.value();
3590 std::optional<LogicalResult>
result = converter(typeOrValue, results);
3592 assert(results.size() == currentCount &&
3593 "failed type conversion should not change results");
3596 if (!isCacheable(indexedConverter.index()))
3599 cacheWriteLock.lock();
3600 if (!succeeded(*
result)) {
3601 assert(results.size() == currentCount &&
3602 "failed type conversion should not change results");
3603 cachedDirectConversions.try_emplace(t,
nullptr);
3606 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3607 if (newTypes.size() == 1)
3608 cachedDirectConversions.try_emplace(t, newTypes.front());
3610 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3616LogicalResult TypeConverter::convertType(Type t,
3617 SmallVectorImpl<Type> &results)
const {
3618 return convertTypeImpl(t, results);
3621LogicalResult TypeConverter::convertType(Value v,
3622 SmallVectorImpl<Type> &results)
const {
3623 return convertTypeImpl(v, results);
3626Type TypeConverter::convertType(Type t)
const {
3628 SmallVector<Type, 1> results;
3629 if (
failed(convertType(t, results)))
3633 return results.size() == 1 ? results.front() :
nullptr;
3636Type TypeConverter::convertType(Value v)
const {
3638 SmallVector<Type, 1> results;
3639 if (
failed(convertType(v, results)))
3643 return results.size() == 1 ? results.front() :
nullptr;
3647TypeConverter::convertTypes(
TypeRange types,
3648 SmallVectorImpl<Type> &results)
const {
3649 for (Type type : types)
3650 if (
failed(convertType(type, results)))
3656TypeConverter::convertTypes(
ValueRange values,
3657 SmallVectorImpl<Type> &results)
const {
3658 for (Value value : values)
3659 if (
failed(convertType(value, results)))
3664bool TypeConverter::isLegal(Type type)
const {
3665 return convertType(type) == type;
3668bool TypeConverter::isLegal(Value value)
const {
3669 return convertType(value) == value.
getType();
3672bool TypeConverter::isLegal(Operation *op)
const {
3676bool TypeConverter::isLegal(Region *region)
const {
3677 return llvm::all_of(
3681bool TypeConverter::isSignatureLegal(FunctionType ty)
const {
3682 if (!isLegal(ty.getInputs()))
3684 if (!isLegal(ty.getResults()))
3690TypeConverter::convertSignatureArg(
unsigned inputNo, Type type,
3691 SignatureConversion &
result)
const {
3693 SmallVector<Type, 1> convertedTypes;
3694 if (
failed(convertType(type, convertedTypes)))
3698 if (convertedTypes.empty())
3702 result.addInputs(inputNo, convertedTypes);
3706TypeConverter::convertSignatureArgs(
TypeRange types,
3707 SignatureConversion &
result,
3708 unsigned origInputOffset)
const {
3709 for (
unsigned i = 0, e = types.size(); i != e; ++i)
3710 if (
failed(convertSignatureArg(origInputOffset + i, types[i],
result)))
3715TypeConverter::convertSignatureArg(
unsigned inputNo, Value value,
3716 SignatureConversion &
result)
const {
3718 SmallVector<Type, 1> convertedTypes;
3719 if (
failed(convertType(value, convertedTypes)))
3723 if (convertedTypes.empty())
3727 result.addInputs(inputNo, convertedTypes);
3731TypeConverter::convertSignatureArgs(
ValueRange values,
3732 SignatureConversion &
result,
3733 unsigned origInputOffset)
const {
3734 for (
unsigned i = 0, e = values.size(); i != e; ++i)
3735 if (
failed(convertSignatureArg(origInputOffset + i, values[i],
result)))
3740Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
3741 Location loc, Type resultType,
3743 for (
const SourceMaterializationCallbackFn &fn :
3744 llvm::reverse(sourceMaterializations))
3745 if (Value
result = fn(builder, resultType, inputs, loc))
3750Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
3751 Location loc, Type resultType,
3753 Type originalType)
const {
3754 SmallVector<Value>
result = materializeTargetConversion(
3755 builder, loc,
TypeRange(resultType), inputs, originalType);
3758 assert(
result.size() == 1 &&
"expected single result");
3762SmallVector<Value> TypeConverter::materializeTargetConversion(
3764 Type originalType)
const {
3765 for (
const TargetMaterializationCallbackFn &fn :
3766 llvm::reverse(targetMaterializations)) {
3767 SmallVector<Value>
result =
3768 fn(builder, resultTypes, inputs, loc, originalType);
3772 "callback produced incorrect number of values or values with "
3779std::optional<TypeConverter::SignatureConversion>
3780TypeConverter::convertBlockSignature(
Block *block)
const {
3783 return std::nullopt;
3790TypeConverter::AttributeConversionResult
3791TypeConverter::AttributeConversionResult::result(Attribute attr) {
3792 return AttributeConversionResult(attr, resultTag);
3795TypeConverter::AttributeConversionResult
3796TypeConverter::AttributeConversionResult::na() {
3797 return AttributeConversionResult(
nullptr, naTag);
3800TypeConverter::AttributeConversionResult
3801TypeConverter::AttributeConversionResult::abort() {
3802 return AttributeConversionResult(
nullptr, abortTag);
3805bool TypeConverter::AttributeConversionResult::hasResult()
const {
3806 return impl.getInt() == resultTag;
3809bool TypeConverter::AttributeConversionResult::isNa()
const {
3810 return impl.getInt() == naTag;
3813bool TypeConverter::AttributeConversionResult::isAbort()
const {
3814 return impl.getInt() == abortTag;
3817Attribute TypeConverter::AttributeConversionResult::getResult()
const {
3818 assert(hasResult() &&
"Cannot get result from N/A or abort");
3819 return impl.getPointer();
3822std::optional<Attribute>
3823TypeConverter::convertTypeAttribute(Type type, Attribute attr)
const {
3824 for (
const TypeAttributeConversionCallbackFn &fn :
3825 llvm::reverse(typeAttributeConversions)) {
3826 AttributeConversionResult res = fn(type, attr);
3827 if (res.hasResult())
3828 return res.getResult();
3830 return std::nullopt;
3832 return std::nullopt;
3841 ConversionPatternRewriter &rewriter) {
3842 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3847 TypeConverter::SignatureConversion
result(type.getNumInputs());
3849 if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
result)) ||
3850 failed(typeConverter.convertTypes(type.getResults(), newResults)))
3852 if (!funcOp.getFunctionBody().empty())
3853 rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(),
result,
3857 auto newType = FunctionType::get(rewriter.getContext(),
3858 result.getConvertedTypes(), newResults);
3860 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3869struct FunctionOpInterfaceSignatureConversion :
public ConversionPattern {
3870 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3872 const TypeConverter &converter,
3873 PatternBenefit benefit)
3874 : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3877 matchAndRewrite(Operation *op, ArrayRef<Value> ,
3878 ConversionPatternRewriter &rewriter)
const override {
3879 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3884struct AnyFunctionOpInterfaceSignatureConversion
3885 :
public OpInterfaceConversionPattern<FunctionOpInterface> {
3886 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3889 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> ,
3890 ConversionPatternRewriter &rewriter)
const override {
3896FailureOr<Operation *>
3897mlir::convertOpResultTypes(Operation *op,
ValueRange operands,
3898 const TypeConverter &converter,
3899 ConversionPatternRewriter &rewriter) {
3900 assert(op &&
"Invalid op");
3901 Location loc = op->
getLoc();
3902 if (converter.isLegal(op))
3903 return rewriter.notifyMatchFailure(loc,
"op already legal");
3905 OperationState newOp(loc, op->
getName());
3906 newOp.addOperands(operands);
3908 SmallVector<Type> newResultTypes;
3910 return rewriter.notifyMatchFailure(loc,
"couldn't convert return types");
3912 newOp.addTypes(newResultTypes);
3913 newOp.addAttributes(op->
getAttrs());
3914 return rewriter.create(newOp);
3917void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3918 StringRef functionLikeOpName, RewritePatternSet &patterns,
3919 const TypeConverter &converter, PatternBenefit benefit) {
3920 patterns.
add<FunctionOpInterfaceSignatureConversion>(
3921 functionLikeOpName, patterns.
getContext(), converter, benefit);
3924void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3925 RewritePatternSet &patterns,
const TypeConverter &converter,
3926 PatternBenefit benefit) {
3927 patterns.
add<AnyFunctionOpInterfaceSignatureConversion>(
3935void ConversionTarget::setOpAction(OperationName op,
3936 LegalizationAction action) {
3937 legalOperations[op].action = action;
3940void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3941 LegalizationAction action) {
3942 for (StringRef dialect : dialectNames)
3943 legalDialects[dialect] = action;
3946auto ConversionTarget::getOpAction(OperationName op)
const
3947 -> std::optional<LegalizationAction> {
3948 std::optional<LegalizationInfo> info = getOpInfo(op);
3949 return info ? info->action : std::optional<LegalizationAction>();
3952auto ConversionTarget::isLegal(Operation *op)
const
3953 -> std::optional<LegalOpDetails> {
3954 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3956 return std::nullopt;
3959 auto isOpLegal = [&] {
3961 if (info->action == LegalizationAction::Dynamic) {
3962 std::optional<bool>
result = info->legalityFn(op);
3968 return info->action == LegalizationAction::Legal;
3971 return std::nullopt;
3974 LegalOpDetails legalityDetails;
3975 if (info->isRecursivelyLegal) {
3976 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
3977 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3978 legalityDetails.isRecursivelyLegal =
3979 legalityFnIt->second(op).value_or(
true);
3981 legalityDetails.isRecursivelyLegal =
true;
3984 return legalityDetails;
3987bool ConversionTarget::isIllegal(Operation *op)
const {
3988 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3992 if (info->action == LegalizationAction::Dynamic) {
3993 std::optional<bool>
result = info->legalityFn(op);
4000 return info->action == LegalizationAction::Illegal;
4004 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
4005 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
4009 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
4011 if (std::optional<bool>
result = newCl(op))
4019void ConversionTarget::setLegalityCallback(
4020 OperationName name,
const DynamicLegalityCallbackFn &callback) {
4021 assert(callback &&
"expected valid legality callback");
4022 auto *infoIt = legalOperations.find(name);
4023 assert(infoIt != legalOperations.end() &&
4024 infoIt->second.action == LegalizationAction::Dynamic &&
4025 "expected operation to already be marked as dynamically legal");
4026 infoIt->second.legalityFn =
4030void ConversionTarget::markOpRecursivelyLegal(
4031 OperationName name,
const DynamicLegalityCallbackFn &callback) {
4032 auto *infoIt = legalOperations.find(name);
4033 assert(infoIt != legalOperations.end() &&
4034 infoIt->second.action != LegalizationAction::Illegal &&
4035 "expected operation to already be marked as legal");
4036 infoIt->second.isRecursivelyLegal =
true;
4039 std::move(opRecursiveLegalityFns[name]), callback);
4041 opRecursiveLegalityFns.erase(name);
4044void ConversionTarget::setLegalityCallback(
4045 ArrayRef<StringRef> dialects,
const DynamicLegalityCallbackFn &callback) {
4046 assert(callback &&
"expected valid legality callback");
4047 for (StringRef dialect : dialects)
4049 std::move(dialectLegalityFns[dialect]), callback);
4052void ConversionTarget::setLegalityCallback(
4053 const DynamicLegalityCallbackFn &callback) {
4054 assert(callback &&
"expected valid legality callback");
4058auto ConversionTarget::getOpInfo(OperationName op)
const
4059 -> std::optional<LegalizationInfo> {
4061 const auto *it = legalOperations.find(op);
4062 if (it != legalOperations.end())
4065 auto dialectIt = legalDialects.find(op.getDialectNamespace());
4066 if (dialectIt != legalDialects.end()) {
4067 DynamicLegalityCallbackFn callback;
4068 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4069 if (dialectFn != dialectLegalityFns.end())
4070 callback = dialectFn->second;
4071 return LegalizationInfo{dialectIt->second,
false,
4075 if (unknownLegalityFn)
4076 return LegalizationInfo{LegalizationAction::Dynamic,
4077 false, unknownLegalityFn};
4078 return std::nullopt;
4081#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4086void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
4087 auto &rewriterImpl =
4088 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4092void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
4093 auto &rewriterImpl =
4094 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4100static FailureOr<SmallVector<Value>>
4101pdllConvertValues(ConversionPatternRewriter &rewriter,
ValueRange values) {
4102 SmallVector<Value> mappedValues;
4103 if (
failed(rewriter.getRemappedValues(values, mappedValues)))
4105 return std::move(mappedValues);
4108void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
4111 [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
4112 auto results = pdllConvertValues(
4113 static_cast<ConversionPatternRewriter &
>(rewriter), value);
4116 return results->front();
4119 "convertValues", [](PatternRewriter &rewriter,
ValueRange values) {
4120 return pdllConvertValues(
4121 static_cast<ConversionPatternRewriter &
>(rewriter), values);
4125 [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
4126 auto &rewriterImpl =
4127 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4128 if (
const TypeConverter *converter =
4130 if (Type newType = converter->convertType(type))
4138 [](PatternRewriter &rewriter,
4139 TypeRange types) -> FailureOr<SmallVector<Type>> {
4140 auto &rewriterImpl =
4141 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4144 return SmallVector<Type>(types);
4146 SmallVector<Type> remappedTypes;
4147 if (
failed(converter->convertTypes(types, remappedTypes)))
4149 return std::move(remappedTypes);
4164 static constexpr StringLiteral
tag =
"apply-conversion";
4165 static constexpr StringLiteral
desc =
4166 "Encapsulate the application of a dialect conversion";
4174 ConversionConfig config,
4175 OpConversionMode mode) {
4179 LogicalResult status =
success();
4184 patterns, config, mode);
4195LogicalResult mlir::applyPartialConversion(
4196 ArrayRef<Operation *> ops,
const ConversionTarget &
target,
4197 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4199 OpConversionMode::Partial);
4202mlir::applyPartialConversion(Operation *op,
const ConversionTarget &
target,
4203 const FrozenRewritePatternSet &patterns,
4204 ConversionConfig config) {
4205 return applyPartialConversion(llvm::ArrayRef(op),
target, patterns, config);
4212LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
4213 const ConversionTarget &
target,
4214 const FrozenRewritePatternSet &patterns,
4215 ConversionConfig config) {
4218LogicalResult mlir::applyFullConversion(Operation *op,
4219 const ConversionTarget &
target,
4220 const FrozenRewritePatternSet &patterns,
4221 ConversionConfig config) {
4222 return applyFullConversion(llvm::ArrayRef(op),
target, patterns, config);
4239 "expected top-level op to be isolated from above");
4242 "expected ops to have a common ancestor");
4251 for (
Operation *op : ops.drop_front()) {
4255 assert(commonAncestor &&
4256 "expected to find a common isolated from above ancestor");
4260 return commonAncestor;
4263LogicalResult mlir::applyAnalysisConversion(
4264 ArrayRef<Operation *> ops, ConversionTarget &
target,
4265 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4267 if (config.legalizableOps)
4268 assert(config.legalizableOps->empty() &&
"expected empty set");
4274 Operation *clonedAncestor = commonAncestor->
clone(mapping);
4278 inverseOperationMap[it.second] = it.first;
4281 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4282 ops, [&](Operation *op) {
return mapping.
lookup(op); });
4284 OpConversionMode::Analysis);
4288 if (config.legalizableOps) {
4290 for (Operation *op : *config.legalizableOps)
4291 originalLegalizableOps.insert(inverseOperationMap[op]);
4292 *config.legalizableOps = std::move(originalLegalizableOps);
4296 clonedAncestor->
erase();
4301mlir::applyAnalysisConversion(Operation *op, ConversionTarget &
target,
4302 const FrozenRewritePatternSet &patterns,
4303 ConversionConfig config) {
4304 return applyAnalysisConversion(llvm::ArrayRef(op),
target, patterns, config);
static void setInsertionPointAfter(OpBuilder &b, Value value)
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl, function_ref< bool(OpOperand &)> functor=nullptr)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value > > &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
static std::string diag(const llvm::Value &value)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
This is the type of Action that is dispatched when a conversion is applied.
tracing::ActionImpl< ApplyConversionAction > Base
static constexpr StringLiteral desc
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
static constexpr StringLiteral tag
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
StringAttr getStringAttr(const Twine &bytes)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
auto lookup(T from) const
Lookup a mapped value within the map.
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Block::iterator getPoint() const
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
void destroyOpProperties(OpaqueProperties properties) const
This hooks destroy the op properties.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
CRTP Implementation of an action.
ArrayRef< IRUnit > irUnits
Set of IR units (operations, regions, blocks, values) that are associated with this action.
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
static void reconcileUnrealizedCasts(const DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
llvm::SetVector< T, Vector, Set, N > SetVector
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
This iterator enumerates elements according to their dominance relationship.
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
const ConversionTarget & getTarget()
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, bool isRecursiveLegalization=false)
LogicalResult convert(Operation *op, bool isRecursiveLegalization=false)
Converts a single operation.
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, Fn onFailure, bool isRecursiveLegalization=false)
Legalizes the given operations (and their nested operations) to the conversion target.
LogicalResult applyConversion(ArrayRef< Operation * > ops)
Applies the conversion to the given operations (and their nested operations).
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
bool wasErased(void *ptr) const
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
llvm::impl::raw_ldbg_ostream os
A raw output stream used to prefix the debug log.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
void replaceValueUses(Value from, ValueRange to, const TypeConverter *converter, function_ref< bool(OpOperand &)> functor=nullptr)
Replace the uses of the given value with the given values.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config, OperationConverter &opConverter)
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
ConversionValueMapping mapping
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
void replaceOp(Operation *op, SmallVector< SmallVector< Value > > &&newValues)
Replace the results of the given operation with the given values and erase the operation.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
OperationConverter & opConverter
The operation converter to use for recursive legalization.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.