10#include "mlir/Config/mlir-config.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/ErrorHandling.h"
25#include "llvm/Support/FormatVariadic.h"
26#include "llvm/Support/SaveAndRestore.h"
27#include "llvm/Support/ScopedPrinter.h"
34#define DEBUG_TYPE "dialect-conversion"
37template <
typename... Args>
38static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
41 os.startLine() <<
"} -> SUCCESS";
43 os.getOStream() <<
" : "
44 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
45 os.getOStream() <<
"\n";
50template <
typename... Args>
51static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
54 os.startLine() <<
"} -> FAILURE : "
55 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
65 if (
OpResult inputRes = dyn_cast<OpResult>(value))
66 insertPt = ++inputRes.getOwner()->getIterator();
73 assert(!vals.empty() &&
"expected at least one value");
76 for (
Value v : vals.drop_front()) {
90 assert(dom &&
"unable to find valid insertion point");
98enum OpConversionMode {
125struct ValueVectorMapInfo {
127 static ::llvm::hash_code getHashValue(
const ValueVector &val) {
128 return ::llvm::hash_combine_range(val);
137struct ConversionValueMapping {
140 bool isMappedTo(
Value value)
const {
return mappedTo.contains(value); }
145 template <
typename T>
146 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
149 template <
typename OldVal,
typename NewVal>
150 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
151 map(OldVal &&oldVal, NewVal &&newVal) {
155 assert(next != oldVal &&
"inserting cyclic mapping");
156 auto it = mapping.find(next);
157 if (it == mapping.end())
162 mappedTo.insert_range(newVal);
164 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
168 template <
typename OldVal,
typename NewVal>
169 std::enable_if_t<!IsValueVector<OldVal>::value ||
170 !IsValueVector<NewVal>::value>
171 map(OldVal &&oldVal, NewVal &&newVal) {
172 if constexpr (IsValueVector<OldVal>{}) {
173 map(std::forward<OldVal>(oldVal),
ValueVector{newVal});
174 }
else if constexpr (IsValueVector<NewVal>{}) {
175 map(
ValueVector{oldVal}, std::forward<NewVal>(newVal));
186 void erase(
const ValueVector &value) { mapping.erase(value); }
206 assert(!values.empty() &&
"expected non-empty value vector");
207 Operation *op = values.front().getDefiningOp();
208 for (
Value v : llvm::drop_begin(values)) {
209 if (v.getDefiningOp() != op)
219 assert(!values.empty() &&
"expected non-empty value vector");
225 auto it = mapping.find(from);
226 if (it == mapping.end()) {
239struct RewriterState {
240 RewriterState(
unsigned numRewrites,
unsigned numIgnoredOperations,
241 unsigned numReplacedOps)
242 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
243 numReplacedOps(numReplacedOps) {}
246 unsigned numRewrites;
249 unsigned numIgnoredOperations;
252 unsigned numReplacedOps;
259static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
262static void notifyIRErased(RewriterBase::Listener *listener,
Block &
b) {
263 for (Operation &op :
b)
264 notifyIRErased(listener, op);
270static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
273 notifyIRErased(listener,
b);
303 UnresolvedMaterialization,
308 virtual ~IRRewrite() =
default;
311 virtual void rollback() = 0;
325 virtual void commit(RewriterBase &rewriter) {}
328 virtual void cleanup(RewriterBase &rewriter) {}
330 Kind getKind()
const {
return kind; }
332 static bool classof(
const IRRewrite *
rewrite) {
return true; }
335 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
336 : kind(kind), rewriterImpl(rewriterImpl) {}
338 const ConversionConfig &getConfig()
const;
341 ConversionPatternRewriterImpl &rewriterImpl;
345class BlockRewrite :
public IRRewrite {
348 Block *getBlock()
const {
return block; }
350 static bool classof(
const IRRewrite *
rewrite) {
351 return rewrite->getKind() >= Kind::CreateBlock &&
352 rewrite->getKind() <= Kind::BlockTypeConversion;
356 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
358 : IRRewrite(kind, rewriterImpl), block(block) {}
365class ValueRewrite :
public IRRewrite {
368 Value getValue()
const {
return value; }
370 static bool classof(
const IRRewrite *
rewrite) {
371 return rewrite->getKind() == Kind::ReplaceValue;
375 ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
377 : IRRewrite(kind, rewriterImpl), value(value) {}
386class CreateBlockRewrite :
public BlockRewrite {
388 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block)
389 : BlockRewrite(
Kind::CreateBlock, rewriterImpl, block) {}
391 static bool classof(
const IRRewrite *
rewrite) {
392 return rewrite->getKind() == Kind::CreateBlock;
395 void commit(RewriterBase &rewriter)
override {
401 void rollback()
override {
404 auto &blockOps = block->getOperations();
405 while (!blockOps.empty())
406 blockOps.remove(blockOps.begin());
407 block->dropAllUses();
408 if (block->getParent())
419class EraseBlockRewrite :
public BlockRewrite {
421 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block)
422 : BlockRewrite(
Kind::EraseBlock, rewriterImpl, block),
423 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
425 static bool classof(
const IRRewrite *
rewrite) {
426 return rewrite->getKind() == Kind::EraseBlock;
429 ~EraseBlockRewrite()
override {
431 "rewrite was neither rolled back nor committed/cleaned up");
434 void rollback()
override {
437 assert(block &&
"expected block");
442 blockList.insert(before, block);
446 void commit(RewriterBase &rewriter)
override {
447 assert(block &&
"expected block");
451 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
452 notifyIRErased(listener, *block);
455 void cleanup(RewriterBase &rewriter)
override {
457 for (
auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
459 assert(block->empty() &&
"expected empty block");
462 block->dropAllDefinedValueUses();
473 Block *insertBeforeBlock;
479class InlineBlockRewrite :
public BlockRewrite {
481 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block,
483 : BlockRewrite(
Kind::InlineBlock, rewriterImpl, block),
484 sourceBlock(sourceBlock),
485 firstInlinedInst(sourceBlock->empty() ?
nullptr
486 : &sourceBlock->front()),
487 lastInlinedInst(sourceBlock->empty() ?
nullptr : &sourceBlock->back()) {
493 assert(!getConfig().listener &&
494 "InlineBlockRewrite not supported if listener is attached");
497 static bool classof(
const IRRewrite *
rewrite) {
498 return rewrite->getKind() == Kind::InlineBlock;
501 void rollback()
override {
504 if (firstInlinedInst) {
505 assert(lastInlinedInst &&
"expected operation");
518 Operation *firstInlinedInst;
521 Operation *lastInlinedInst;
525class MoveBlockRewrite :
public BlockRewrite {
527 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block,
529 : BlockRewrite(
Kind::MoveBlock, rewriterImpl, block),
530 region(previousRegion),
531 insertBeforeBlock(previousIt == previousRegion->end() ?
nullptr
534 static bool classof(
const IRRewrite *
rewrite) {
535 return rewrite->getKind() == Kind::MoveBlock;
538 void commit(RewriterBase &rewriter)
override {
548 void rollback()
override {
552 if (Region *currentParent = block->
getParent()) {
554 region->getBlocks().splice(before, currentParent->getBlocks(), block);
558 region->
getBlocks().insert(before, block);
567 Block *insertBeforeBlock;
571class BlockTypeConversionRewrite :
public BlockRewrite {
573 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
575 : BlockRewrite(
Kind::BlockTypeConversion, rewriterImpl, origBlock),
576 newBlock(newBlock) {}
578 static bool classof(
const IRRewrite *
rewrite) {
579 return rewrite->getKind() == Kind::BlockTypeConversion;
582 Block *getOrigBlock()
const {
return block; }
584 Block *getNewBlock()
const {
return newBlock; }
586 void commit(RewriterBase &rewriter)
override;
588 void rollback()
override;
598class ReplaceValueRewrite :
public ValueRewrite {
600 ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
601 const TypeConverter *converter)
602 : ValueRewrite(
Kind::ReplaceValue, rewriterImpl, value),
603 converter(converter) {}
605 static bool classof(
const IRRewrite *
rewrite) {
606 return rewrite->getKind() == Kind::ReplaceValue;
609 void commit(RewriterBase &rewriter)
override;
611 void rollback()
override;
615 const TypeConverter *converter;
619class OperationRewrite :
public IRRewrite {
622 Operation *getOperation()
const {
return op; }
624 static bool classof(
const IRRewrite *
rewrite) {
625 return rewrite->getKind() >= Kind::MoveOperation &&
626 rewrite->getKind() <= Kind::UnresolvedMaterialization;
630 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
632 : IRRewrite(kind, rewriterImpl), op(op) {}
639class MoveOperationRewrite :
public OperationRewrite {
641 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
642 Operation *op, OpBuilder::InsertPoint previous)
643 : OperationRewrite(
Kind::MoveOperation, rewriterImpl, op),
644 block(previous.getBlock()),
645 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
647 : &*previous.getPoint()) {}
649 static bool classof(
const IRRewrite *
rewrite) {
650 return rewrite->getKind() == Kind::MoveOperation;
653 void commit(RewriterBase &rewriter)
override {
659 op, OpBuilder::InsertPoint(block,
664 void rollback()
override {
677 Operation *insertBeforeOp;
682class ModifyOperationRewrite :
public OperationRewrite {
684 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
686 : OperationRewrite(
Kind::ModifyOperation, rewriterImpl, op),
687 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
688 operands(op->operand_begin(), op->operand_end()),
689 successors(op->successor_begin(), op->successor_end()) {
692 propertiesStorage = operator new(op->getPropertiesStorageSize());
693 PropertyRef propCopy(name.getOpPropertiesTypeID(), propertiesStorage);
694 name.initOpProperties(propCopy, prop);
698 static bool classof(
const IRRewrite *
rewrite) {
699 return rewrite->getKind() == Kind::ModifyOperation;
702 ~ModifyOperationRewrite()
override {
703 assert(!propertiesStorage &&
704 "rewrite was neither committed nor rolled back");
707 void commit(RewriterBase &rewriter)
override {
710 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
713 if (propertiesStorage) {
718 operator delete(propertiesStorage);
719 propertiesStorage =
nullptr;
723 void rollback()
override {
727 for (
const auto &it : llvm::enumerate(successors))
729 if (propertiesStorage) {
733 operator delete(propertiesStorage);
734 propertiesStorage =
nullptr;
741 DictionaryAttr attrs;
742 SmallVector<Value, 8> operands;
743 SmallVector<Block *, 2> successors;
744 void *propertiesStorage =
nullptr;
751class ReplaceOperationRewrite :
public OperationRewrite {
753 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
754 Operation *op,
const TypeConverter *converter)
755 : OperationRewrite(
Kind::ReplaceOperation, rewriterImpl, op),
756 converter(converter) {}
758 static bool classof(
const IRRewrite *
rewrite) {
759 return rewrite->getKind() == Kind::ReplaceOperation;
762 void commit(RewriterBase &rewriter)
override;
764 void rollback()
override;
766 void cleanup(RewriterBase &rewriter)
override;
771 const TypeConverter *converter;
774class CreateOperationRewrite :
public OperationRewrite {
776 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
778 : OperationRewrite(
Kind::CreateOperation, rewriterImpl, op) {}
780 static bool classof(
const IRRewrite *
rewrite) {
781 return rewrite->getKind() == Kind::CreateOperation;
784 void commit(RewriterBase &rewriter)
override {
790 void rollback()
override;
794enum MaterializationKind {
805class UnresolvedMaterializationInfo {
807 UnresolvedMaterializationInfo() =
default;
808 UnresolvedMaterializationInfo(
const TypeConverter *converter,
809 MaterializationKind kind, Type originalType)
810 : converterAndKind(converter, kind), originalType(originalType) {}
813 const TypeConverter *getConverter()
const {
814 return converterAndKind.getPointer();
818 MaterializationKind getMaterializationKind()
const {
819 return converterAndKind.getInt();
823 Type getOriginalType()
const {
return originalType; }
828 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
839class UnresolvedMaterializationRewrite :
public OperationRewrite {
841 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
842 UnrealizedConversionCastOp op,
844 : OperationRewrite(
Kind::UnresolvedMaterialization, rewriterImpl, op),
845 mappedValues(std::move(mappedValues)) {}
847 static bool classof(
const IRRewrite *
rewrite) {
848 return rewrite->getKind() == Kind::UnresolvedMaterialization;
851 void rollback()
override;
853 UnrealizedConversionCastOp getOperation()
const {
854 return cast<UnrealizedConversionCastOp>(op);
864#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
867template <
typename RewriteTy,
typename R>
868static bool hasRewrite(R &&rewrites, Operation *op) {
869 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
870 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
871 return rewriteTy && rewriteTy->getOperation() == op;
877template <
typename RewriteTy,
typename R>
878static bool hasRewrite(R &&rewrites,
Block *block) {
879 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
880 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
881 return rewriteTy && rewriteTy->getBlock() == block;
893 const ConversionConfig &
config,
903 RewriterState getCurrentState();
907 void applyRewrites();
912 void resetState(RewriterState state, StringRef patternName =
"");
916 template <
typename RewriteTy,
typename... Args>
918 assert(
config.allowPatternRollback &&
"appending rewrites is not allowed");
920 std::make_unique<RewriteTy>(*
this, std::forward<Args>(args)...));
926 void undoRewrites(
unsigned numRewritesToKeep = 0, StringRef patternName =
"");
932 LogicalResult remapValues(StringRef valueDiagTag,
933 std::optional<Location> inputLoc,
ValueRange values,
950 bool skipPureTypeConversions =
false)
const;
964 TypeConverter::SignatureConversion *entryConversion);
972 Block *applySignatureConversion(
974 TypeConverter::SignatureConversion &signatureConversion);
994 void eraseBlock(
Block *block);
1032 Value findOrBuildReplacementValue(
Value value,
1040 void notifyOperationInserted(
Operation *op,
1044 void notifyBlockInserted(
Block *block,
Region *previous,
1063 std::function<
void(
Operation *)> opErasedCallback =
nullptr)
1065 opErasedCallback(std::move(opErasedCallback)) {}
1079 assert(block->empty() &&
"expected empty block");
1080 block->dropAllDefinedValueUses();
1088 if (opErasedCallback)
1089 opErasedCallback(op);
1141 llvm::MapVector<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
1197const ConversionConfig &IRRewrite::getConfig()
const {
1198 return rewriterImpl.
config;
1201void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1205 if (
auto *listener =
1206 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
1207 for (Operation *op : getNewBlock()->getUsers())
1211void BlockTypeConversionRewrite::rollback() {
1212 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1219 if (isa<BlockArgument>(repl)) {
1259 result &= functor(operand);
1264void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1271void ReplaceValueRewrite::rollback() {
1272 rewriterImpl.
mapping.erase({value});
1278void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1280 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener());
1283 SmallVector<Value> replacements =
1285 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1293 for (
auto [
result, newValue] :
1294 llvm::zip_equal(op->
getResults(), replacements))
1300 if (getConfig().unlegalizedOps)
1301 getConfig().unlegalizedOps->erase(op);
1305 notifyIRErased(listener, *op);
1310 llvm::reportFatalInternalError(
1311 "dialect conversion attempted to replace a root operation that has no "
1312 "parent block; the pass must ensure its target op is nested in a "
1317void ReplaceOperationRewrite::rollback() {
1322void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1326void CreateOperationRewrite::rollback() {
1328 while (!region.getBlocks().empty())
1329 region.getBlocks().remove(region.getBlocks().begin());
1335void UnresolvedMaterializationRewrite::rollback() {
1336 if (!mappedValues.empty())
1337 rewriterImpl.
mapping.erase(mappedValues);
1348 for (
size_t i = 0; i <
rewrites.size(); ++i)
1354 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1355 unresolvedMaterializations.erase(castOp);
1358 rewrite->cleanup(eraseRewriter);
1366 Value from,
TypeRange desiredTypes,
bool skipPureTypeConversions)
const {
1369 assert(!values.empty() &&
"expected non-empty value vector");
1373 if (
config.allowPatternRollback)
1374 return mapping.lookup(values);
1381 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1386 if (castOp.getOutputs() != values)
1388 return castOp.getInputs();
1397 for (
Value v : values) {
1400 llvm::append_range(next, r);
1405 if (next != values) {
1434 if (skipPureTypeConversions) {
1437 match &= !pureConversion;
1440 if (!pureConversion)
1441 lastNonMaterialization = current;
1444 desiredValue = current;
1450 current = std::move(next);
1455 if (!desiredTypes.empty())
1456 return desiredValue;
1457 if (skipPureTypeConversions)
1458 return lastNonMaterialization;
1477 StringRef patternName) {
1482 while (
ignoredOps.size() != state.numIgnoredOperations)
1485 while (
replacedOps.size() != state.numReplacedOps)
1490 StringRef patternName) {
1492 llvm::reverse(llvm::drop_begin(
rewrites, numRewritesToKeep)))
1494 rewrites.resize(numRewritesToKeep);
1498 StringRef valueDiagTag, std::optional<Location> inputLoc,
ValueRange values,
1500 remapped.reserve(llvm::size(values));
1502 for (
const auto &it : llvm::enumerate(values)) {
1503 Value operand = it.value();
1522 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1523 << it.index() <<
", type was " << origType;
1528 if (legalTypes.empty()) {
1529 remapped.push_back({});
1538 remapped.push_back(std::move(repl));
1547 repl, repl, legalTypes,
1549 remapped.push_back(castValues);
1570 TypeConverter::SignatureConversion *entryConversion) {
1572 if (region->
empty())
1577 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1579 std::optional<TypeConverter::SignatureConversion> conversion =
1580 converter.convertBlockSignature(&block);
1589 if (entryConversion)
1592 std::optional<TypeConverter::SignatureConversion> conversion =
1593 converter.convertBlockSignature(®ion->
front());
1601 TypeConverter::SignatureConversion &signatureConversion) {
1602#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1604 if (hasRewrite<BlockTypeConversionRewrite>(
rewrites, block))
1605 llvm::reportFatalInternalError(
"block was already converted");
1612 auto convertedTypes = signatureConversion.getConvertedTypes();
1619 for (
unsigned i = 0; i < origArgCount; ++i) {
1620 auto inputMap = signatureConversion.getInputMapping(i);
1621 if (!inputMap || inputMap->replacedWithValues())
1624 for (
unsigned j = 0;
j < inputMap->size; ++
j)
1625 newLocs[inputMap->inputNo +
j] = origLoc;
1632 convertedTypes, newLocs);
1640 bool fastPath = !
config.listener;
1642 if (
config.allowPatternRollback)
1646 while (!block->
empty())
1653 for (
unsigned i = 0; i != origArgCount; ++i) {
1657 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1658 signatureConversion.getInputMapping(i);
1666 MaterializationKind::Source,
1670 origArgType,
Type(), converter,
1677 if (inputMap->replacedWithValues()) {
1679 assert(inputMap->size == 0 &&
1680 "invalid to provide a replacement value when the argument isn't "
1688 newBlock->
getArguments().slice(inputMap->inputNo, inputMap->size);
1692 if (
config.allowPatternRollback)
1713 assert((!originalType || kind == MaterializationKind::Target) &&
1714 "original type is valid only for target materializations");
1715 assert(
TypeRange(inputs) != outputTypes &&
1716 "materialization is not necessary");
1720 OpBuilder builder(outputTypes.front().getContext());
1722 UnrealizedConversionCastOp convertOp =
1723 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1724 if (
config.attachDebugMaterializationKind) {
1726 kind == MaterializationKind::Source ?
"source" :
"target";
1727 convertOp->setAttr(
"__kind__", builder.
getStringAttr(kindStr));
1734 UnresolvedMaterializationInfo(converter, kind, originalType);
1735 if (
config.allowPatternRollback) {
1736 if (!valuesToMap.empty())
1737 mapping.map(valuesToMap, convertOp.getResults());
1739 std::move(valuesToMap));
1743 return convertOp.getResults();
1748 assert(
config.allowPatternRollback &&
1749 "this code path is valid only in rollback mode");
1756 return repl.front();
1763 [&](
Operation *op) { return replacedOps.contains(op); }) &&
1788 MaterializationKind::Source, ip, value.
getLoc(),
1804 bool wasDetached = !previous.
isSet();
1806 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"' (" << op
1809 logger.getOStream() <<
" (was detached)";
1810 logger.getOStream() <<
"\n";
1816 "attempting to insert into a block within a replaced/erased op");
1820 config.listener->notifyOperationInserted(op, previous);
1829 if (
config.allowPatternRollback) {
1843 if (
config.allowPatternRollback)
1853 assert(!
impl.config.allowPatternRollback &&
1854 "this code path is valid only in 'no rollback' mode");
1856 for (
auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1859 repls.push_back(
Value());
1866 Value srcMat =
impl.buildUnresolvedMaterialization(
1871 repls.push_back(srcMat);
1877 repls.push_back(to[0]);
1886 Value srcMat =
impl.buildUnresolvedMaterialization(
1889 Type(), converter)[0];
1890 repls.push_back(srcMat);
1899 "incorrect number of replacement values");
1901 logger.startLine() <<
"** Replace : '" << op->
getName() <<
"'(" << op
1909 for (
auto [
result, repls] :
1910 llvm::zip_equal(op->
getResults(), newValues)) {
1912 auto logProlog = [&, repls = repls]() {
1913 logger.startLine() <<
" Note: Replacing op result of type "
1914 << resultType <<
" with value(s) of type (";
1915 llvm::interleaveComma(repls,
logger.getOStream(), [&](
Value v) {
1916 logger.getOStream() << v.getType();
1918 logger.getOStream() <<
")";
1924 logger.getOStream() <<
", but the type converter failed to legalize "
1925 "the original type.\n";
1930 logger.getOStream() <<
", but the legalized type(s) is/are (";
1931 llvm::interleaveComma(convertedTypes,
logger.getOStream(),
1932 [&](
Type t) { logger.getOStream() << t; });
1933 logger.getOStream() <<
")\n";
1939 if (!
config.allowPatternRollback) {
1948 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1954 if (
config.unlegalizedOps)
1955 config.unlegalizedOps->erase(op);
1963 assert(!
ignoredOps.contains(op) &&
"operation was already replaced");
1967 "attempting to replace a value that was already replaced");
1972 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1977 "attempting to replace/erase an unresolved materialization");
1993 logger.startLine() <<
"** Replace Value : '" << from <<
"'";
1994 if (
auto blockArg = dyn_cast<BlockArgument>(from)) {
1996 logger.getOStream() <<
" (in region of '" << parentOp->getName()
1997 <<
"' (" << parentOp <<
")";
1999 logger.getOStream() <<
" (unlinked block)";
2003 logger.getOStream() <<
", conditional replacement";
2007 if (!
config.allowPatternRollback) {
2012 Value repl = repls.front();
2029 "attempting to replace a value that was already replaced");
2031 "attempting to replace a op result that was already replaced");
2036 llvm::reportFatalInternalError(
2037 "conditional value replacement is not supported in rollback mode");
2043 if (!
config.allowPatternRollback) {
2050 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
2056 if (
config.unlegalizedOps)
2057 config.unlegalizedOps->erase(op);
2066 "attempting to erase a block within a replaced/erased op");
2082 bool wasDetached = !previous;
2088 logger.startLine() <<
"** Insert Block into : '" << parent->
getName()
2089 <<
"' (" << parent <<
")";
2092 <<
"** Insert Block into detached Region (nullptr parent op)";
2095 logger.getOStream() <<
" (was detached)";
2096 logger.getOStream() <<
"\n";
2102 "attempting to insert into a region within a replaced/erased op");
2107 config.listener->notifyBlockInserted(block, previous, previousIt);
2111 if (
config.allowPatternRollback) {
2125 if (
config.allowPatternRollback)
2139 reasonCallback(
diag);
2140 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
2141 if (
config.notifyCallback)
2150ConversionPatternRewriter::ConversionPatternRewriter(
2154 *this, config, opConverter)) {
2155 setListener(
impl.get());
2158ConversionPatternRewriter::~ConversionPatternRewriter() =
default;
2160const ConversionConfig &ConversionPatternRewriter::getConfig()
const {
2161 return impl->config;
2165 assert(op && newOp &&
"expected non-null op");
2169void ConversionPatternRewriter::replaceOp(Operation *op,
ValueRange newValues) {
2171 "incorrect # of replacement values");
2175 if (getInsertionPoint() == op->getIterator())
2178 SmallVector<SmallVector<Value>> newVals =
2179 llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2180 return v ? SmallVector<Value>{v} : SmallVector<Value>();
2182 impl->replaceOp(op, std::move(newVals));
2185void ConversionPatternRewriter::replaceOpWithMultiple(
2186 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2188 "incorrect # of replacement values");
2192 if (getInsertionPoint() == op->getIterator())
2195 impl->replaceOp(op, std::move(newValues));
2198void ConversionPatternRewriter::eraseOp(Operation *op) {
2200 impl->logger.startLine()
2201 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
2206 if (getInsertionPoint() == op->getIterator())
2209 SmallVector<SmallVector<Value>> nullRepls(op->
getNumResults(), {});
2210 impl->replaceOp(op, std::move(nullRepls));
2213void ConversionPatternRewriter::eraseBlock(
Block *block) {
2214 impl->eraseBlock(block);
2217Block *ConversionPatternRewriter::applySignatureConversion(
2218 Block *block, TypeConverter::SignatureConversion &conversion,
2219 const TypeConverter *converter) {
2220 assert(!impl->wasOpReplaced(block->
getParentOp()) &&
2221 "attempting to apply a signature conversion to a block within a "
2222 "replaced/erased op");
2223 return impl->applySignatureConversion(block, converter, conversion);
2226FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
2227 Region *region,
const TypeConverter &converter,
2228 TypeConverter::SignatureConversion *entryConversion) {
2229 assert(!impl->wasOpReplaced(region->
getParentOp()) &&
2230 "attempting to apply a signature conversion to a block within a "
2231 "replaced/erased op");
2232 return impl->convertRegionTypes(region, converter, entryConversion);
2235void ConversionPatternRewriter::replaceAllUsesWith(Value from,
ValueRange to) {
2236 impl->replaceValueUses(from, to, impl->currentTypeConverter);
2239void ConversionPatternRewriter::replaceUsesWithIf(
2241 bool *allUsesReplaced) {
2242 assert(!allUsesReplaced &&
2243 "allUsesReplaced is not supported in a dialect conversion");
2244 impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
2247Value ConversionPatternRewriter::getRemappedValue(Value key) {
2248 SmallVector<ValueVector> remappedValues;
2249 if (
failed(impl->remapValues(
"value", std::nullopt, key,
2252 assert(remappedValues.front().size() == 1 &&
"1:N conversion not supported");
2253 return remappedValues.front().front();
2257ConversionPatternRewriter::getRemappedValues(
ValueRange keys,
2258 SmallVectorImpl<Value> &results) {
2261 SmallVector<ValueVector> remapped;
2262 if (
failed(impl->remapValues(
"value", std::nullopt, keys,
2265 for (
const auto &values : remapped) {
2266 assert(values.size() == 1 &&
"1:N conversion not supported");
2267 results.push_back(values.front());
2272void ConversionPatternRewriter::inlineBlockBefore(
Block *source,
Block *dest,
2277 "incorrect # of argument replacement values");
2278 assert(!impl->wasOpReplaced(source->
getParentOp()) &&
2279 "attempting to inline a block from a replaced/erased op");
2280 assert(!impl->wasOpReplaced(dest->
getParentOp()) &&
2281 "attempting to inline a block into a replaced/erased op");
2282 auto opIgnored = [&](Operation *op) {
return impl->isOpIgnored(op); };
2285 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
2286 "expected 'source' to have no predecessors");
2295 bool fastPath = !getConfig().listener;
2297 if (fastPath && impl->config.allowPatternRollback)
2298 impl->inlineBlockBefore(source, dest, before);
2301 for (
auto it : llvm::zip(source->
getArguments(), argValues))
2302 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2309 while (!source->
empty())
2310 moveOpBefore(&source->
front(), dest, before);
2315 if (getInsertionBlock() == source)
2316 setInsertionPoint(dest, getInsertionPoint());
2322void ConversionPatternRewriter::startOpModification(Operation *op) {
2323 if (!impl->config.allowPatternRollback) {
2328 assert(!impl->wasOpReplaced(op) &&
2329 "attempting to modify a replaced/erased op");
2331 impl->pendingRootUpdates.insert(op);
2333 impl->appendRewrite<ModifyOperationRewrite>(op);
2336void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
2337 impl->patternModifiedOps.insert(op);
2338 if (!impl->config.allowPatternRollback) {
2340 if (getConfig().listener)
2341 getConfig().listener->notifyOperationModified(op);
2348 assert(!impl->wasOpReplaced(op) &&
2349 "attempting to modify a replaced/erased op");
2350 assert(impl->pendingRootUpdates.erase(op) &&
2351 "operation did not have a pending in-place update");
2355void ConversionPatternRewriter::cancelOpModification(Operation *op) {
2356 if (!impl->config.allowPatternRollback) {
2361 assert(impl->pendingRootUpdates.erase(op) &&
2362 "operation did not have a pending in-place update");
2365 auto it = llvm::find_if(
2366 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &
rewrite) {
2367 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2368 return modifyRewrite && modifyRewrite->getOperation() == op;
2370 assert(it != impl->rewrites.rend() &&
"no root update started on op");
2372 int updateIdx = std::prev(impl->rewrites.rend()) - it;
2373 impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2376detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
2384FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
2385 ArrayRef<ValueRange> operands)
const {
2386 SmallVector<Value> oneToOneOperands;
2387 oneToOneOperands.reserve(operands.size());
2389 if (operand.size() != 1)
2392 oneToOneOperands.push_back(operand.front());
2394 return std::move(oneToOneOperands);
2398ConversionPattern::matchAndRewrite(Operation *op,
2399 PatternRewriter &rewriter)
const {
2400 auto &dialectRewriter =
static_cast<ConversionPatternRewriter &
>(rewriter);
2401 auto &rewriterImpl = dialectRewriter.getImpl();
2405 getTypeConverter());
2408 SmallVector<ValueVector> remapped;
2413 SmallVector<ValueRange> remappedAsRange =
2414 llvm::to_vector_of<ValueRange>(remapped);
2415 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2424using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2427class OperationLegalizer {
2429 using LegalizationAction = ConversionTarget::LegalizationAction;
2431 OperationLegalizer(ConversionPatternRewriter &rewriter,
2432 const ConversionTarget &targetInfo,
2433 const FrozenRewritePatternSet &patterns);
2436 bool isIllegal(Operation *op)
const;
2440 LogicalResult legalize(Operation *op);
2443 const ConversionTarget &getTarget() {
return target; }
2447 LogicalResult legalizeWithFold(Operation *op);
2451 LogicalResult legalizeWithPattern(Operation *op);
2455 bool canApplyPattern(Operation *op,
const Pattern &pattern);
2459 legalizePatternResult(Operation *op,
const Pattern &pattern,
2460 const RewriterState &curState,
2477 void buildLegalizationGraph(
2478 LegalizationPatterns &anyOpLegalizerPatterns,
2489 void computeLegalizationGraphBenefit(
2490 LegalizationPatterns &anyOpLegalizerPatterns,
2495 unsigned computeOpLegalizationDepth(
2502 unsigned applyCostModelToPatterns(
2503 LegalizationPatterns &patterns,
2508 SmallPtrSet<const Pattern *, 8> appliedPatterns;
2511 ConversionPatternRewriter &rewriter;
2514 const ConversionTarget &
target;
2517 PatternApplicator applicator;
2521OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2522 const ConversionTarget &targetInfo,
2523 const FrozenRewritePatternSet &patterns)
2524 : rewriter(rewriter),
target(targetInfo), applicator(patterns) {
2528 LegalizationPatterns anyOpLegalizerPatterns;
2530 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2531 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2534bool OperationLegalizer::isIllegal(Operation *op)
const {
2535 return target.isIllegal(op);
2538LogicalResult OperationLegalizer::legalize(Operation *op) {
2540 const char *logLineComment =
2541 "//===-------------------------------------------===//\n";
2543 auto &logger = rewriter.getImpl().logger;
2547 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2550 logger.getOStream() <<
"\n";
2551 logger.startLine() << logLineComment;
2552 logger.startLine() <<
"Legalizing operation : ";
2557 logger.getOStream() <<
"'" << op->
getName() <<
"' ";
2558 logger.getOStream() <<
"(" << op <<
") {\n";
2563 logger.startLine() << OpWithFlags(op,
2564 OpPrintingFlags().printGenericOpForm())
2571 logSuccess(logger,
"operation marked 'ignored' during conversion");
2572 logger.startLine() << logLineComment;
2578 if (
auto legalityInfo =
target.isLegal(op)) {
2581 logger,
"operation marked legal by the target{0}",
2582 legalityInfo->isRecursivelyLegal
2583 ?
"; NOTE: operation is recursively legal; skipping internals"
2585 logger.startLine() << logLineComment;
2590 if (legalityInfo->isRecursivelyLegal) {
2591 op->
walk([&](Operation *nested) {
2593 rewriter.getImpl().ignoredOps.
insert(nested);
2602 const ConversionConfig &config = rewriter.getConfig();
2603 if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2604 if (succeeded(legalizeWithFold(op))) {
2607 logger.startLine() << logLineComment;
2614 if (succeeded(legalizeWithPattern(op))) {
2617 logger.startLine() << logLineComment;
2624 if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2625 if (succeeded(legalizeWithFold(op))) {
2628 logger.startLine() << logLineComment;
2635 logFailure(logger,
"no matched legalization pattern");
2636 logger.startLine() << logLineComment;
2643template <
typename T>
2645 T
result = std::move(obj);
2650LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2651 auto &rewriterImpl = rewriter.getImpl();
2653 rewriterImpl.
logger.startLine() <<
"* Fold {\n";
2654 rewriterImpl.
logger.indent();
2659 llvm::scope_exit cleanup([&]() {
2669 SmallVector<Value, 2> replacementValues;
2670 SmallVector<Operation *, 2> newOps;
2673 if (
failed(rewriter.
tryFold(op, replacementValues, &newOps))) {
2682 if (replacementValues.empty())
2683 return legalize(op);
2686 rewriter.
replaceOp(op, replacementValues);
2689 for (Operation *newOp : newOps) {
2690 if (
failed(legalize(newOp))) {
2692 "failed to legalize generated constant '{0}'",
2694 if (!rewriter.getConfig().allowPatternRollback) {
2696 llvm::reportFatalInternalError(
2698 "' folder rollback of IR modifications requested");
2716 auto newOpNames = llvm::map_range(
2718 auto modifiedOpNames = llvm::map_range(
2720 llvm::reportFatalInternalError(
"pattern '" + pattern.
getDebugName() +
2721 "' produced IR that could not be legalized. " +
2722 "new ops: {" + llvm::join(newOpNames,
", ") +
2723 "}, " +
"modified ops: {" +
2724 llvm::join(modifiedOpNames,
", ") +
"}");
2727LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2728 auto &rewriterImpl = rewriter.getImpl();
2729 const ConversionConfig &config = rewriter.getConfig();
2731#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2733 std::optional<OperationFingerPrint> topLevelFingerPrint;
2734 if (!rewriterImpl.
config.allowPatternRollback) {
2741 topLevelFingerPrint = OperationFingerPrint(checkOp);
2747 rewriterImpl.
logger.startLine()
2748 <<
"WARNING: Multi-threadeding is enabled. Some dialect "
2749 "conversion expensive checks are skipped in multithreading "
2758 auto canApply = [&](
const Pattern &pattern) {
2759 bool canApply = canApplyPattern(op, pattern);
2760 if (canApply && config.listener)
2761 config.listener->notifyPatternBegin(pattern, op);
2767 auto onFailure = [&](
const Pattern &pattern) {
2769 if (!rewriterImpl.
config.allowPatternRollback) {
2776#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2778 if (checkOp && topLevelFingerPrint) {
2779 OperationFingerPrint fingerPrintAfterPattern(checkOp);
2780 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2781 llvm::reportFatalInternalError(
2782 "pattern '" + pattern.getDebugName() +
2783 "' returned failure but IR did change");
2791 if (rewriterImpl.
config.notifyCallback) {
2793 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
2799 if (config.listener)
2800 config.listener->notifyPatternEnd(pattern, failure());
2801 rewriterImpl.
resetState(curState, pattern.getDebugName());
2802 appliedPatterns.erase(&pattern);
2807 auto onSuccess = [&](
const Pattern &pattern) {
2809 if (!rewriterImpl.
config.allowPatternRollback) {
2823 legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
2824 appliedPatterns.erase(&pattern);
2826 if (!rewriterImpl.
config.allowPatternRollback)
2828 rewriterImpl.
resetState(curState, pattern.getDebugName());
2830 if (config.listener)
2831 config.listener->notifyPatternEnd(pattern,
result);
2836 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2840bool OperationLegalizer::canApplyPattern(Operation *op,
2841 const Pattern &pattern) {
2843 auto &os = rewriter.getImpl().logger;
2844 os.getOStream() <<
"\n";
2845 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
2847 os.getOStream() <<
")' {\n";
2854 !appliedPatterns.insert(&pattern).second) {
2856 logFailure(rewriter.getImpl().logger,
"pattern was already applied"));
2862LogicalResult OperationLegalizer::legalizePatternResult(
2863 Operation *op,
const Pattern &pattern,
const RewriterState &curState,
2866 [[maybe_unused]]
auto &impl = rewriter.getImpl();
2867 assert(impl.pendingRootUpdates.empty() &&
"dangling root updates");
2869#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2870 if (impl.config.allowPatternRollback) {
2872 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2873 auto replacedRoot = [&] {
2874 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2876 auto updatedRootInPlace = [&] {
2877 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2879 if (!replacedRoot() && !updatedRootInPlace())
2880 llvm::reportFatalInternalError(
2881 "expected pattern to replace the root operation "
2882 "or modify it in place");
2887 if (
failed(legalizePatternRootUpdates(modifiedOps)) ||
2888 failed(legalizePatternCreatedOperations(newOps))) {
2892 LLVM_DEBUG(
logSuccess(impl.logger,
"pattern applied successfully"));
2896LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2898 for (Operation *op : newOps) {
2899 if (
failed(legalize(op))) {
2900 LLVM_DEBUG(
logFailure(rewriter.getImpl().logger,
2901 "failed to legalize generated operation '{0}'({1})",
2909LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2911 for (Operation *op : modifiedOps) {
2912 if (
failed(legalize(op))) {
2915 "failed to legalize operation updated in-place '{0}'",
2927void OperationLegalizer::buildLegalizationGraph(
2928 LegalizationPatterns &anyOpLegalizerPatterns,
2939 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2940 std::optional<OperationName> root = pattern.
getRootKind();
2946 anyOpLegalizerPatterns.push_back(&pattern);
2951 if (
target.getOpAction(*root) == LegalizationAction::Legal)
2956 invalidPatterns[*root].insert(&pattern);
2958 parentOps[op].insert(*root);
2961 patternWorklist.insert(&pattern);
2969 if (!anyOpLegalizerPatterns.empty()) {
2970 for (
const Pattern *pattern : patternWorklist)
2971 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2975 while (!patternWorklist.empty()) {
2976 auto *pattern = patternWorklist.pop_back_val();
2980 std::optional<LegalizationAction> action = target.getOpAction(op);
2981 return !legalizerPatterns.count(op) &&
2982 (!action || action == LegalizationAction::Illegal);
2988 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2989 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
2993 for (
auto op : parentOps[*pattern->
getRootKind()])
2994 patternWorklist.set_union(invalidPatterns[op]);
2998void OperationLegalizer::computeLegalizationGraphBenefit(
2999 LegalizationPatterns &anyOpLegalizerPatterns,
3005 for (
auto &opIt : legalizerPatterns)
3006 if (!minOpPatternDepth.count(opIt.first))
3007 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
3013 if (!anyOpLegalizerPatterns.empty())
3014 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
3020 applicator.applyCostModel([&](
const Pattern &pattern) {
3021 ArrayRef<const Pattern *> orderedPatternList;
3022 if (std::optional<OperationName> rootName = pattern.
getRootKind())
3023 orderedPatternList = legalizerPatterns[*rootName];
3025 orderedPatternList = anyOpLegalizerPatterns;
3028 auto *it = llvm::find(orderedPatternList, &pattern);
3029 if (it == orderedPatternList.end())
3033 return PatternBenefit(std::distance(it, orderedPatternList.end()));
3037unsigned OperationLegalizer::computeOpLegalizationDepth(
3041 auto depthIt = minOpPatternDepth.find(op);
3042 if (depthIt != minOpPatternDepth.end())
3043 return depthIt->second;
3047 auto opPatternsIt = legalizerPatterns.find(op);
3048 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3053 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
3057 unsigned minDepth = applyCostModelToPatterns(
3058 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3059 minOpPatternDepth[op] = minDepth;
3063unsigned OperationLegalizer::applyCostModelToPatterns(
3064 LegalizationPatterns &patterns,
3067 unsigned minDepth = std::numeric_limits<unsigned>::max();
3070 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3071 patternsByDepth.reserve(patterns.size());
3072 for (
const Pattern *pattern : patterns) {
3075 unsigned generatedOpDepth = computeOpLegalizationDepth(
3076 generatedOp, minOpPatternDepth, legalizerPatterns);
3077 depth = std::max(depth, generatedOpDepth + 1);
3079 patternsByDepth.emplace_back(pattern, depth);
3082 minDepth = std::min(minDepth, depth);
3087 if (patternsByDepth.size() == 1)
3091 llvm::stable_sort(patternsByDepth,
3092 [](
const std::pair<const Pattern *, unsigned> &
lhs,
3093 const std::pair<const Pattern *, unsigned> &
rhs) {
3096 if (
lhs.second !=
rhs.second)
3097 return lhs.second <
rhs.second;
3100 auto lhsBenefit =
lhs.first->getBenefit();
3101 auto rhsBenefit =
rhs.first->getBenefit();
3102 return lhsBenefit > rhsBenefit;
3107 for (
auto &patternIt : patternsByDepth)
3108 patterns.push_back(patternIt.first);
3122template <
typename RangeT>
3125 function_ref<
bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3134 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3135 if (castOp.getInputs().empty())
3138 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3141 if (inputCastOp.getOutputs() != castOp.getInputs())
3147 while (!worklist.empty()) {
3148 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3152 UnrealizedConversionCastOp nextCast = castOp;
3154 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3155 if (llvm::any_of(nextCast.getInputs(), [&](
Value v) {
3156 return v.getDefiningOp() == castOp;
3164 castOp.replaceAllUsesWith(nextCast.getInputs());
3167 nextCast = getInputCast(nextCast);
3177 auto markOpLive = [&](
Operation *rootOp) {
3179 worklist.push_back(rootOp);
3180 while (!worklist.empty()) {
3181 Operation *op = worklist.pop_back_val();
3182 if (liveOps.insert(op).second) {
3185 if (
auto castOp = v.
getDefiningOp<UnrealizedConversionCastOp>())
3186 if (isCastOpOfInterestFn(castOp))
3187 worklist.push_back(castOp);
3193 for (UnrealizedConversionCastOp op : castOps) {
3196 if (liveOps.contains(op.getOperation()))
3200 if (llvm::any_of(op->getUsers(), [&](
Operation *user) {
3201 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3202 return !castOp || !isCastOpOfInterestFn(castOp);
3208 for (UnrealizedConversionCastOp op : castOps) {
3209 if (liveOps.contains(op)) {
3211 if (remainingCastOps)
3212 remainingCastOps->push_back(op);
3223 ArrayRef<UnrealizedConversionCastOp> castOps,
3224 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3226 DenseSet<UnrealizedConversionCastOp> castOpSet;
3227 for (UnrealizedConversionCastOp op : castOps)
3228 castOpSet.insert(op);
3233 const DenseSet<UnrealizedConversionCastOp> &castOps,
3234 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3236 llvm::make_range(castOps.begin(), castOps.end()),
3237 [&](UnrealizedConversionCastOp castOp) {
3238 return castOps.contains(castOp);
3245 const llvm::MapVector<UnrealizedConversionCastOp,
3246 UnresolvedMaterializationInfo> &castOps,
3250 [&](UnrealizedConversionCastOp castOp) {
3251 return castOps.contains(castOp);
3268 const ConversionConfig &config,
3269 OpConversionMode mode)
3270 : rewriter(ctx, config, *this), opLegalizer(rewriter,
target, patterns),
3279 template <
typename Fn>
3281 bool isRecursiveLegalization =
false);
3283 bool isRecursiveLegalization =
false) {
3285 ops, [&]() {}, isRecursiveLegalization);
3293 LogicalResult convert(
Operation *op,
bool isRecursiveLegalization =
false);
3299 ConversionPatternRewriter rewriter;
3302 OperationLegalizer opLegalizer;
3305 OpConversionMode mode;
3310 bool isRecursiveLegalization) {
3311 const ConversionConfig &config = rewriter.getConfig();
3312 auto emitFailedToLegalizeDiag = [&](
bool wasExplicitlyIllegal) {
3314 <<
"failed to legalize operation '"
3316 if (wasExplicitlyIllegal)
3317 diag <<
" that was explicitly marked illegal";
3322 if (failed(opLegalizer.legalize(op))) {
3325 if (mode == OpConversionMode::Full) {
3326 if (!isRecursiveLegalization)
3327 emitFailedToLegalizeDiag(
false);
3333 if (mode == OpConversionMode::Partial) {
3334 if (opLegalizer.isIllegal(op)) {
3335 if (!isRecursiveLegalization)
3336 emitFailedToLegalizeDiag(
true);
3339 if (config.unlegalizedOps && !isRecursiveLegalization)
3340 config.unlegalizedOps->insert(op);
3342 }
else if (mode == OpConversionMode::Analysis) {
3346 if (config.legalizableOps && !isRecursiveLegalization)
3347 config.legalizableOps->insert(op);
3354 UnrealizedConversionCastOp op,
3355 const UnresolvedMaterializationInfo &info) {
3356 assert(!op.use_empty() &&
3357 "expected that dead materializations have already been DCE'd");
3364 switch (info.getMaterializationKind()) {
3365 case MaterializationKind::Target:
3366 newMaterialization = converter->materializeTargetConversion(
3367 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3368 info.getOriginalType());
3370 case MaterializationKind::Source:
3371 assert(op->getNumResults() == 1 &&
"expected single result");
3372 Value sourceMat = converter->materializeSourceConversion(
3373 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3375 newMaterialization.push_back(sourceMat);
3378 if (!newMaterialization.empty()) {
3380 ValueRange newMaterializationRange(newMaterialization);
3381 assert(
TypeRange(newMaterializationRange) == op.getResultTypes() &&
3382 "materialization callback produced value of incorrect type");
3384 rewriter.
replaceOp(op, newMaterialization);
3390 <<
"failed to legalize unresolved materialization "
3392 << inputOperands.
getTypes() <<
") to ("
3393 << op.getResultTypes()
3394 <<
") that remained live after conversion";
3395 diag.attachNote(op->getUsers().begin()->getLoc())
3396 <<
"see existing live user here: " << *op->getUsers().begin();
3400template <
typename Fn>
3403 bool isRecursiveLegalization) {
3411 toConvert.push_back(op);
3414 auto legalityInfo =
target.isLegal(op);
3415 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3421 if (failed(
convert(op, isRecursiveLegalization))) {
3430LogicalResult ConversionPatternRewriter::legalize(
Operation *op) {
3431 return impl->opConverter.legalizeOperations(op,
3435LogicalResult ConversionPatternRewriter::legalize(
Region *r) {
3451 std::optional<TypeConverter::SignatureConversion> conversion =
3452 converter->convertBlockSignature(&r->front());
3455 applySignatureConversion(&r->front(), *conversion, converter);
3460 return impl->opConverter.legalizeOperations(ops,
3470 if (rewriterImpl.
config.allowPatternRollback) {
3488 const llvm::MapVector<UnrealizedConversionCastOp,
3489 UnresolvedMaterializationInfo> &materializations =
3495 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3499 if (rewriter.getConfig().buildMaterializations) {
3503 rewriter.getConfig().listener);
3504 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3505 auto it = materializations.find(castOp);
3506 assert(it != materializations.end() &&
"inconsistent state");
3520void TypeConverter::SignatureConversion::addInputs(
unsigned origInputNo,
3522 assert(!types.empty() &&
"expected valid types");
3523 remapInput(origInputNo, argTypes.size(), types.size());
3527void TypeConverter::SignatureConversion::addInputs(
ArrayRef<Type> types) {
3528 assert(!types.empty() &&
3529 "1->0 type remappings don't need to be added explicitly");
3530 argTypes.append(types.begin(), types.end());
3533void TypeConverter::SignatureConversion::remapInput(
unsigned origInputNo,
3534 unsigned newInputNo,
3535 unsigned newInputCount) {
3536 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3537 assert(newInputCount != 0 &&
"expected valid input count");
3538 remappedInputs[origInputNo] =
3539 InputMapping{newInputNo, newInputCount, {}};
3542void TypeConverter::SignatureConversion::remapInput(
3543 unsigned origInputNo, ArrayRef<Value> replacements) {
3544 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3545 remappedInputs[origInputNo] = InputMapping{
3547 SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3558TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3559 SmallVectorImpl<Type> &results)
const {
3560 assert(typeOrValue &&
"expected non-null type");
3561 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3562 : cast<Type>(typeOrValue);
3564 std::shared_lock<
decltype(cacheMutex)> cacheReadLock(cacheMutex,
3567 cacheReadLock.lock();
3568 auto existingIt = cachedDirectConversions.find(t);
3569 if (existingIt != cachedDirectConversions.end()) {
3570 if (existingIt->second)
3571 results.push_back(existingIt->second);
3572 return success(existingIt->second !=
nullptr);
3574 auto multiIt = cachedMultiConversions.find(t);
3575 if (multiIt != cachedMultiConversions.end()) {
3576 results.append(multiIt->second.begin(), multiIt->second.end());
3582 size_t currentCount = results.size();
3586 auto isCacheable = [&](
int index) {
3587 int numberOfConversionsUntilContextAware =
3588 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3589 return index < numberOfConversionsUntilContextAware;
3592 std::unique_lock<
decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3595 for (
auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3596 const ConversionCallbackFn &converter = indexedConverter.value();
3597 std::optional<LogicalResult>
result = converter(typeOrValue, results);
3599 assert(results.size() == currentCount &&
3600 "failed type conversion should not change results");
3603 if (!isCacheable(indexedConverter.index()))
3606 cacheWriteLock.lock();
3607 if (!succeeded(*
result)) {
3608 assert(results.size() == currentCount &&
3609 "failed type conversion should not change results");
3610 cachedDirectConversions.try_emplace(t,
nullptr);
3613 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3614 if (newTypes.size() == 1)
3615 cachedDirectConversions.try_emplace(t, newTypes.front());
3617 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3623LogicalResult TypeConverter::convertType(Type t,
3624 SmallVectorImpl<Type> &results)
const {
3625 return convertTypeImpl(t, results);
3628LogicalResult TypeConverter::convertType(Value v,
3629 SmallVectorImpl<Type> &results)
const {
3630 return convertTypeImpl(v, results);
3633Type TypeConverter::convertType(Type t)
const {
3635 SmallVector<Type, 1> results;
3636 if (
failed(convertType(t, results)))
3640 return results.size() == 1 ? results.front() :
nullptr;
3643Type TypeConverter::convertType(Value v)
const {
3645 SmallVector<Type, 1> results;
3646 if (
failed(convertType(v, results)))
3650 return results.size() == 1 ? results.front() :
nullptr;
3654TypeConverter::convertTypes(
TypeRange types,
3655 SmallVectorImpl<Type> &results)
const {
3656 for (Type type : types)
3657 if (
failed(convertType(type, results)))
3663TypeConverter::convertTypes(
ValueRange values,
3664 SmallVectorImpl<Type> &results)
const {
3665 for (Value value : values)
3666 if (
failed(convertType(value, results)))
3671bool TypeConverter::isLegal(Type type)
const {
3672 return convertType(type) == type;
3675bool TypeConverter::isLegal(Value value)
const {
3676 return convertType(value) == value.
getType();
3679bool TypeConverter::isLegal(Operation *op)
const {
3683bool TypeConverter::isLegal(Region *region)
const {
3684 return llvm::all_of(
3688bool TypeConverter::isSignatureLegal(FunctionType ty)
const {
3689 if (!isLegal(ty.getInputs()))
3691 if (!isLegal(ty.getResults()))
3697TypeConverter::convertSignatureArg(
unsigned inputNo, Type type,
3698 SignatureConversion &
result)
const {
3700 SmallVector<Type, 1> convertedTypes;
3701 if (
failed(convertType(type, convertedTypes)))
3705 if (convertedTypes.empty())
3709 result.addInputs(inputNo, convertedTypes);
3713TypeConverter::convertSignatureArgs(
TypeRange types,
3714 SignatureConversion &
result,
3715 unsigned origInputOffset)
const {
3716 for (
unsigned i = 0, e = types.size(); i != e; ++i)
3717 if (
failed(convertSignatureArg(origInputOffset + i, types[i],
result)))
3722TypeConverter::convertSignatureArg(
unsigned inputNo, Value value,
3723 SignatureConversion &
result)
const {
3725 SmallVector<Type, 1> convertedTypes;
3726 if (
failed(convertType(value, convertedTypes)))
3730 if (convertedTypes.empty())
3734 result.addInputs(inputNo, convertedTypes);
3738TypeConverter::convertSignatureArgs(
ValueRange values,
3739 SignatureConversion &
result,
3740 unsigned origInputOffset)
const {
3741 for (
unsigned i = 0, e = values.size(); i != e; ++i)
3742 if (
failed(convertSignatureArg(origInputOffset + i, values[i],
result)))
3747Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
3748 Location loc, Type resultType,
3750 for (
const SourceMaterializationCallbackFn &fn :
3751 llvm::reverse(sourceMaterializations))
3752 if (Value
result = fn(builder, resultType, inputs, loc))
3757Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
3758 Location loc, Type resultType,
3760 Type originalType)
const {
3761 SmallVector<Value>
result = materializeTargetConversion(
3762 builder, loc,
TypeRange(resultType), inputs, originalType);
3765 assert(
result.size() == 1 &&
"expected single result");
3769SmallVector<Value> TypeConverter::materializeTargetConversion(
3771 Type originalType)
const {
3772 for (
const TargetMaterializationCallbackFn &fn :
3773 llvm::reverse(targetMaterializations)) {
3774 SmallVector<Value>
result =
3775 fn(builder, resultTypes, inputs, loc, originalType);
3779 "callback produced incorrect number of values or values with "
3786std::optional<TypeConverter::SignatureConversion>
3787TypeConverter::convertBlockSignature(
Block *block)
const {
3790 return std::nullopt;
3797TypeConverter::AttributeConversionResult
3798TypeConverter::AttributeConversionResult::result(Attribute attr) {
3799 return AttributeConversionResult(attr, resultTag);
3802TypeConverter::AttributeConversionResult
3803TypeConverter::AttributeConversionResult::na() {
3804 return AttributeConversionResult(
nullptr, naTag);
3807TypeConverter::AttributeConversionResult
3808TypeConverter::AttributeConversionResult::abort() {
3809 return AttributeConversionResult(
nullptr, abortTag);
3812bool TypeConverter::AttributeConversionResult::hasResult()
const {
3813 return impl.getInt() == resultTag;
3816bool TypeConverter::AttributeConversionResult::isNa()
const {
3817 return impl.getInt() == naTag;
3820bool TypeConverter::AttributeConversionResult::isAbort()
const {
3821 return impl.getInt() == abortTag;
3824Attribute TypeConverter::AttributeConversionResult::getResult()
const {
3825 assert(hasResult() &&
"Cannot get result from N/A or abort");
3826 return impl.getPointer();
3829std::optional<Attribute>
3830TypeConverter::convertTypeAttribute(Type type, Attribute attr)
const {
3831 for (
const TypeAttributeConversionCallbackFn &fn :
3832 llvm::reverse(typeAttributeConversions)) {
3833 AttributeConversionResult res = fn(type, attr);
3834 if (res.hasResult())
3835 return res.getResult();
3837 return std::nullopt;
3839 return std::nullopt;
3848 ConversionPatternRewriter &rewriter) {
3849 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3854 TypeConverter::SignatureConversion funcConversion(type.getNumInputs());
3856 if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
3858 failed(typeConverter.convertTypes(type.getResults(), newResults)))
3867 if (!funcOp.getFunctionBody().empty()) {
3868 Block *entryBlock = &funcOp.getFunctionBody().
front();
3870 unsigned numFuncTypeInputs = type.getNumInputs();
3871 TypeConverter::SignatureConversion blockConversion(numEntryBlockArgs);
3873 if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
3878 for (
unsigned i = numFuncTypeInputs; i < numEntryBlockArgs; ++i)
3880 rewriter.applySignatureConversion(entryBlock, blockConversion,
3884 auto newType = FunctionType::get(
3885 rewriter.getContext(), funcConversion.getConvertedTypes(), newResults);
3887 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3896struct FunctionOpInterfaceSignatureConversion :
public ConversionPattern {
3897 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3899 const TypeConverter &converter,
3900 PatternBenefit benefit)
3901 : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3904 matchAndRewrite(Operation *op, ArrayRef<Value> ,
3905 ConversionPatternRewriter &rewriter)
const override {
3906 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3911struct AnyFunctionOpInterfaceSignatureConversion
3912 :
public OpInterfaceConversionPattern<FunctionOpInterface> {
3913 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3916 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> ,
3917 ConversionPatternRewriter &rewriter)
const override {
3923FailureOr<Operation *>
3924mlir::convertOpResultTypes(Operation *op,
ValueRange operands,
3925 const TypeConverter &converter,
3926 ConversionPatternRewriter &rewriter) {
3927 assert(op &&
"Invalid op");
3928 Location loc = op->
getLoc();
3929 if (converter.isLegal(op))
3930 return rewriter.notifyMatchFailure(loc,
"op already legal");
3932 OperationState newOp(loc, op->
getName());
3933 newOp.addOperands(operands);
3935 SmallVector<Type> newResultTypes;
3937 return rewriter.notifyMatchFailure(loc,
"couldn't convert return types");
3939 newOp.addTypes(newResultTypes);
3940 newOp.addAttributes(op->
getAttrs());
3941 return rewriter.create(newOp);
3944void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3945 StringRef functionLikeOpName, RewritePatternSet &patterns,
3946 const TypeConverter &converter, PatternBenefit benefit) {
3947 patterns.
add<FunctionOpInterfaceSignatureConversion>(
3948 functionLikeOpName, patterns.
getContext(), converter, benefit);
3951void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3952 RewritePatternSet &patterns,
const TypeConverter &converter,
3953 PatternBenefit benefit) {
3954 patterns.
add<AnyFunctionOpInterfaceSignatureConversion>(
3962void ConversionTarget::setOpAction(OperationName op,
3963 LegalizationAction action) {
3964 legalOperations[op].action = action;
3967void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3968 LegalizationAction action) {
3969 for (StringRef dialect : dialectNames)
3970 legalDialects[dialect] = action;
3973auto ConversionTarget::getOpAction(OperationName op)
const
3974 -> std::optional<LegalizationAction> {
3975 std::optional<LegalizationInfo> info = getOpInfo(op);
3976 return info ? info->action : std::optional<LegalizationAction>();
3979auto ConversionTarget::isLegal(Operation *op)
const
3980 -> std::optional<LegalOpDetails> {
3981 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3983 return std::nullopt;
3986 auto isOpLegal = [&] {
3988 if (info->action == LegalizationAction::Dynamic) {
3989 std::optional<bool>
result = info->legalityFn(op);
3995 return info->action == LegalizationAction::Legal;
3998 return std::nullopt;
4001 LegalOpDetails legalityDetails;
4002 if (info->isRecursivelyLegal) {
4003 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
4004 if (legalityFnIt != opRecursiveLegalityFns.end()) {
4005 legalityDetails.isRecursivelyLegal =
4006 legalityFnIt->second(op).value_or(
true);
4008 legalityDetails.isRecursivelyLegal =
true;
4011 return legalityDetails;
4014bool ConversionTarget::isIllegal(Operation *op)
const {
4015 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
4019 if (info->action == LegalizationAction::Dynamic) {
4020 std::optional<bool>
result = info->legalityFn(op);
4027 return info->action == LegalizationAction::Illegal;
4031 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
4032 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
4036 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
4038 if (std::optional<bool>
result = newCl(op))
4046void ConversionTarget::setLegalityCallback(
4047 OperationName name,
const DynamicLegalityCallbackFn &callback) {
4048 assert(callback &&
"expected valid legality callback");
4049 auto *infoIt = legalOperations.find(name);
4050 assert(infoIt != legalOperations.end() &&
4051 infoIt->second.action == LegalizationAction::Dynamic &&
4052 "expected operation to already be marked as dynamically legal");
4053 infoIt->second.legalityFn =
4057void ConversionTarget::markOpRecursivelyLegal(
4058 OperationName name,
const DynamicLegalityCallbackFn &callback) {
4059 auto *infoIt = legalOperations.find(name);
4060 assert(infoIt != legalOperations.end() &&
4061 infoIt->second.action != LegalizationAction::Illegal &&
4062 "expected operation to already be marked as legal");
4063 infoIt->second.isRecursivelyLegal =
true;
4066 std::move(opRecursiveLegalityFns[name]), callback);
4068 opRecursiveLegalityFns.erase(name);
4071void ConversionTarget::setLegalityCallback(
4072 ArrayRef<StringRef> dialects,
const DynamicLegalityCallbackFn &callback) {
4073 assert(callback &&
"expected valid legality callback");
4074 for (StringRef dialect : dialects)
4076 std::move(dialectLegalityFns[dialect]), callback);
4079void ConversionTarget::setLegalityCallback(
4080 const DynamicLegalityCallbackFn &callback) {
4081 assert(callback &&
"expected valid legality callback");
4085auto ConversionTarget::getOpInfo(OperationName op)
const
4086 -> std::optional<LegalizationInfo> {
4088 const auto *it = legalOperations.find(op);
4089 if (it != legalOperations.end())
4092 auto dialectIt = legalDialects.find(op.getDialectNamespace());
4093 if (dialectIt != legalDialects.end()) {
4094 DynamicLegalityCallbackFn callback;
4095 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4096 if (dialectFn != dialectLegalityFns.end())
4097 callback = dialectFn->second;
4098 return LegalizationInfo{dialectIt->second,
false,
4102 if (unknownLegalityFn)
4103 return LegalizationInfo{LegalizationAction::Dynamic,
4104 false, unknownLegalityFn};
4105 return std::nullopt;
4108#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4113void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
4114 auto &rewriterImpl =
4115 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4119void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
4120 auto &rewriterImpl =
4121 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4127static FailureOr<SmallVector<Value>>
4128pdllConvertValues(ConversionPatternRewriter &rewriter,
ValueRange values) {
4129 SmallVector<Value> mappedValues;
4130 if (
failed(rewriter.getRemappedValues(values, mappedValues)))
4132 return std::move(mappedValues);
4135void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
4138 [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
4139 auto results = pdllConvertValues(
4140 static_cast<ConversionPatternRewriter &
>(rewriter), value);
4143 return results->front();
4146 "convertValues", [](PatternRewriter &rewriter,
ValueRange values) {
4147 return pdllConvertValues(
4148 static_cast<ConversionPatternRewriter &
>(rewriter), values);
4152 [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
4153 auto &rewriterImpl =
4154 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4155 if (
const TypeConverter *converter =
4157 if (Type newType = converter->convertType(type))
4165 [](PatternRewriter &rewriter,
4166 TypeRange types) -> FailureOr<SmallVector<Type>> {
4167 auto &rewriterImpl =
4168 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4171 return SmallVector<Type>(types);
4173 SmallVector<Type> remappedTypes;
4174 if (
failed(converter->convertTypes(types, remappedTypes)))
4176 return std::move(remappedTypes);
4191 static constexpr StringLiteral
tag =
"apply-conversion";
4192 static constexpr StringLiteral
desc =
4193 "Encapsulate the application of a dialect conversion";
4201 ConversionConfig config,
4202 OpConversionMode mode) {
4206 LogicalResult status =
success();
4211 patterns, config, mode);
4222LogicalResult mlir::applyPartialConversion(
4223 ArrayRef<Operation *> ops,
const ConversionTarget &
target,
4224 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4226 OpConversionMode::Partial);
4229mlir::applyPartialConversion(Operation *op,
const ConversionTarget &
target,
4230 const FrozenRewritePatternSet &patterns,
4231 ConversionConfig config) {
4232 return applyPartialConversion(llvm::ArrayRef(op),
target, patterns, config);
4239LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
4240 const ConversionTarget &
target,
4241 const FrozenRewritePatternSet &patterns,
4242 ConversionConfig config) {
4245LogicalResult mlir::applyFullConversion(Operation *op,
4246 const ConversionTarget &
target,
4247 const FrozenRewritePatternSet &patterns,
4248 ConversionConfig config) {
4249 return applyFullConversion(llvm::ArrayRef(op),
target, patterns, config);
4266 "expected top-level op to be isolated from above");
4269 "expected ops to have a common ancestor");
4278 for (
Operation *op : ops.drop_front()) {
4282 assert(commonAncestor &&
4283 "expected to find a common isolated from above ancestor");
4287 return commonAncestor;
4290LogicalResult mlir::applyAnalysisConversion(
4291 ArrayRef<Operation *> ops, ConversionTarget &
target,
4292 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4294 if (config.legalizableOps)
4295 assert(config.legalizableOps->empty() &&
"expected empty set");
4301 Operation *clonedAncestor = commonAncestor->
clone(mapping);
4305 inverseOperationMap[it.second] = it.first;
4308 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4309 ops, [&](Operation *op) {
return mapping.
lookup(op); });
4311 OpConversionMode::Analysis);
4315 if (config.legalizableOps) {
4317 for (Operation *op : *config.legalizableOps)
4318 originalLegalizableOps.insert(inverseOperationMap[op]);
4319 *config.legalizableOps = std::move(originalLegalizableOps);
4323 clonedAncestor->
erase();
4328mlir::applyAnalysisConversion(Operation *op, ConversionTarget &
target,
4329 const FrozenRewritePatternSet &patterns,
4330 ConversionConfig config) {
4331 return applyAnalysisConversion(llvm::ArrayRef(op),
target, patterns, config);
static void setInsertionPointAfter(OpBuilder &b, Value value)
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
SmallVector< Value, 2 > ValueVector
A vector of SSA values, optimized for the most common case of one or two values.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl, function_ref< bool(OpOperand &)> functor=nullptr)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value > > &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
static std::string diag(const llvm::Value &value)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
This is the type of Action that is dispatched when a conversion is applied.
tracing::ActionImpl< ApplyConversionAction > Base
static constexpr StringLiteral desc
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
static constexpr StringLiteral tag
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
StringAttr getStringAttr(const Twine &bytes)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
auto lookup(T from) const
Lookup a mapped value within the map.
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Block::iterator getPoint() const
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
type_range getTypes() const
void destroyOpProperties(PropertyRef properties) const
This hooks destroy the op properties.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
TypeID getOpPropertiesTypeID() const
Return the TypeID of the op properties.
Operation is the basic unit of execution within MLIR.
PropertyRef getPropertiesStorage()
Return a generic (but typed) reference to the property type storage.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
void copyProperties(PropertyRef rhs)
Copy properties from an existing other properties object.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
CRTP Implementation of an action.
ArrayRef< IRUnit > irUnits
Set of IR units (operations, regions, blocks, values) that are associated with this action.
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
llvm::SetVector< T, Vector, Set, N > SetVector
static void reconcileUnrealizedCasts(const llvm::MapVector< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
This iterator enumerates elements according to their dominance relationship.
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
const ConversionTarget & getTarget()
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, bool isRecursiveLegalization=false)
LogicalResult convert(Operation *op, bool isRecursiveLegalization=false)
Converts a single operation.
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, Fn onFailure, bool isRecursiveLegalization=false)
Legalizes the given operations (and their nested operations) to the conversion target.
LogicalResult applyConversion(ArrayRef< Operation * > ops)
Applies the conversion to the given operations (and their nested operations).
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
bool wasErased(void *ptr) const
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
llvm::impl::raw_ldbg_ostream os
A raw output stream used to prefix the debug log.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
void replaceValueUses(Value from, ValueRange to, const TypeConverter *converter, function_ref< bool(OpOperand &)> functor=nullptr)
Replace the uses of the given value with the given values.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
llvm::MapVector< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config, OperationConverter &opConverter)
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
ConversionValueMapping mapping
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
void replaceOp(Operation *op, SmallVector< SmallVector< Value > > &&newValues)
Replace the results of the given operation with the given values and erase the operation.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
OperationConverter & opConverter
The operation converter to use for recursive legalization.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.