10#include "mlir/Config/mlir-config.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/FormatVariadic.h"
25#include "llvm/Support/SaveAndRestore.h"
26#include "llvm/Support/ScopedPrinter.h"
33#define DEBUG_TYPE "dialect-conversion"
36template <
typename... Args>
37static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
40 os.startLine() <<
"} -> SUCCESS";
42 os.getOStream() <<
" : "
43 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
44 os.getOStream() <<
"\n";
49template <
typename... Args>
50static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
53 os.startLine() <<
"} -> FAILURE : "
54 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
64 if (
OpResult inputRes = dyn_cast<OpResult>(value))
65 insertPt = ++inputRes.getOwner()->getIterator();
72 assert(!vals.empty() &&
"expected at least one value");
75 for (
Value v : vals.drop_front()) {
89 assert(dom &&
"unable to find valid insertion point");
97enum OpConversionMode {
123struct ValueVectorMapInfo {
126 static ::llvm::hash_code getHashValue(
const ValueVector &val) {
127 return ::llvm::hash_combine_range(val);
136struct ConversionValueMapping {
139 bool isMappedTo(
Value value)
const {
return mappedTo.contains(value); }
144 template <
typename T>
145 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
148 template <
typename OldVal,
typename NewVal>
149 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
150 map(OldVal &&oldVal, NewVal &&newVal) {
154 assert(next != oldVal &&
"inserting cyclic mapping");
155 auto it = mapping.find(next);
156 if (it == mapping.end())
161 mappedTo.insert_range(newVal);
163 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
167 template <
typename OldVal,
typename NewVal>
168 std::enable_if_t<!IsValueVector<OldVal>::value ||
169 !IsValueVector<NewVal>::value>
170 map(OldVal &&oldVal, NewVal &&newVal) {
171 if constexpr (IsValueVector<OldVal>{}) {
172 map(std::forward<OldVal>(oldVal),
ValueVector{newVal});
173 }
else if constexpr (IsValueVector<NewVal>{}) {
174 map(
ValueVector{oldVal}, std::forward<NewVal>(newVal));
185 void erase(
const ValueVector &value) { mapping.erase(value); }
205 assert(!values.empty() &&
"expected non-empty value vector");
206 Operation *op = values.front().getDefiningOp();
207 for (
Value v : llvm::drop_begin(values)) {
208 if (v.getDefiningOp() != op)
218 assert(!values.empty() &&
"expected non-empty value vector");
224 auto it = mapping.find(from);
225 if (it == mapping.end()) {
238struct RewriterState {
239 RewriterState(
unsigned numRewrites,
unsigned numIgnoredOperations,
240 unsigned numReplacedOps)
241 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
242 numReplacedOps(numReplacedOps) {}
245 unsigned numRewrites;
248 unsigned numIgnoredOperations;
251 unsigned numReplacedOps;
/// Forward declaration of the Operation overload of notifyIRErased.
/// It is needed because the two overloads are mutually recursive:
///  - the Block overload calls this one for every operation in the block;
///  - the Operation overload in turn calls the Block overload.
258static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
261static void notifyIRErased(RewriterBase::Listener *listener,
Block &
b) {
262 for (Operation &op :
b)
263 notifyIRErased(listener, op);
269static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
272 notifyIRErased(listener,
b);
302 UnresolvedMaterialization,
307 virtual ~IRRewrite() =
default;
310 virtual void rollback() = 0;
324 virtual void commit(RewriterBase &rewriter) {}
327 virtual void cleanup(RewriterBase &rewriter) {}
329 Kind getKind()
const {
return kind; }
331 static bool classof(
const IRRewrite *
rewrite) {
return true; }
334 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
335 : kind(kind), rewriterImpl(rewriterImpl) {}
337 const ConversionConfig &getConfig()
const;
340 ConversionPatternRewriterImpl &rewriterImpl;
344class BlockRewrite :
public IRRewrite {
347 Block *getBlock()
const {
return block; }
349 static bool classof(
const IRRewrite *
rewrite) {
350 return rewrite->getKind() >= Kind::CreateBlock &&
351 rewrite->getKind() <= Kind::BlockTypeConversion;
355 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
357 : IRRewrite(kind, rewriterImpl), block(block) {}
364class ValueRewrite :
public IRRewrite {
367 Value getValue()
const {
return value; }
369 static bool classof(
const IRRewrite *
rewrite) {
370 return rewrite->getKind() == Kind::ReplaceValue;
374 ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
376 : IRRewrite(kind, rewriterImpl), value(value) {}
385class CreateBlockRewrite :
public BlockRewrite {
387 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block)
388 : BlockRewrite(
Kind::CreateBlock, rewriterImpl, block) {}
390 static bool classof(
const IRRewrite *
rewrite) {
391 return rewrite->getKind() == Kind::CreateBlock;
394 void commit(RewriterBase &rewriter)
override {
400 void rollback()
override {
403 auto &blockOps = block->getOperations();
404 while (!blockOps.empty())
405 blockOps.remove(blockOps.begin());
406 block->dropAllUses();
407 if (block->getParent())
418class EraseBlockRewrite :
public BlockRewrite {
420 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block)
421 : BlockRewrite(
Kind::EraseBlock, rewriterImpl, block),
422 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
424 static bool classof(
const IRRewrite *
rewrite) {
425 return rewrite->getKind() == Kind::EraseBlock;
428 ~EraseBlockRewrite()
override {
430 "rewrite was neither rolled back nor committed/cleaned up");
433 void rollback()
override {
436 assert(block &&
"expected block");
441 blockList.insert(before, block);
445 void commit(RewriterBase &rewriter)
override {
446 assert(block &&
"expected block");
450 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
451 notifyIRErased(listener, *block);
454 void cleanup(RewriterBase &rewriter)
override {
456 for (
auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
458 assert(block->empty() &&
"expected empty block");
461 block->dropAllDefinedValueUses();
472 Block *insertBeforeBlock;
478class InlineBlockRewrite :
public BlockRewrite {
480 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block,
482 : BlockRewrite(
Kind::InlineBlock, rewriterImpl, block),
483 sourceBlock(sourceBlock),
484 firstInlinedInst(sourceBlock->empty() ?
nullptr
485 : &sourceBlock->front()),
486 lastInlinedInst(sourceBlock->empty() ?
nullptr : &sourceBlock->back()) {
492 assert(!getConfig().listener &&
493 "InlineBlockRewrite not supported if listener is attached");
496 static bool classof(
const IRRewrite *
rewrite) {
497 return rewrite->getKind() == Kind::InlineBlock;
500 void rollback()
override {
503 if (firstInlinedInst) {
504 assert(lastInlinedInst &&
"expected operation");
517 Operation *firstInlinedInst;
520 Operation *lastInlinedInst;
524class MoveBlockRewrite :
public BlockRewrite {
526 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block,
528 : BlockRewrite(
Kind::MoveBlock, rewriterImpl, block),
529 region(previousRegion),
530 insertBeforeBlock(previousIt == previousRegion->end() ?
nullptr
533 static bool classof(
const IRRewrite *
rewrite) {
534 return rewrite->getKind() == Kind::MoveBlock;
537 void commit(RewriterBase &rewriter)
override {
547 void rollback()
override {
560 Block *insertBeforeBlock;
564class BlockTypeConversionRewrite :
public BlockRewrite {
566 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
568 : BlockRewrite(
Kind::BlockTypeConversion, rewriterImpl, origBlock),
569 newBlock(newBlock) {}
571 static bool classof(
const IRRewrite *
rewrite) {
572 return rewrite->getKind() == Kind::BlockTypeConversion;
575 Block *getOrigBlock()
const {
return block; }
577 Block *getNewBlock()
const {
return newBlock; }
579 void commit(RewriterBase &rewriter)
override;
581 void rollback()
override;
591class ReplaceValueRewrite :
public ValueRewrite {
593 ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
594 const TypeConverter *converter)
595 : ValueRewrite(
Kind::ReplaceValue, rewriterImpl, value),
596 converter(converter) {}
598 static bool classof(
const IRRewrite *
rewrite) {
599 return rewrite->getKind() == Kind::ReplaceValue;
602 void commit(RewriterBase &rewriter)
override;
604 void rollback()
override;
608 const TypeConverter *converter;
612class OperationRewrite :
public IRRewrite {
615 Operation *getOperation()
const {
return op; }
617 static bool classof(
const IRRewrite *
rewrite) {
618 return rewrite->getKind() >= Kind::MoveOperation &&
619 rewrite->getKind() <= Kind::UnresolvedMaterialization;
623 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
625 : IRRewrite(kind, rewriterImpl), op(op) {}
632class MoveOperationRewrite :
public OperationRewrite {
634 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
635 Operation *op, OpBuilder::InsertPoint previous)
636 : OperationRewrite(
Kind::MoveOperation, rewriterImpl, op),
637 block(previous.getBlock()),
638 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
640 : &*previous.getPoint()) {}
642 static bool classof(
const IRRewrite *
rewrite) {
643 return rewrite->getKind() == Kind::MoveOperation;
646 void commit(RewriterBase &rewriter)
override {
652 op, OpBuilder::InsertPoint(block,
657 void rollback()
override {
670 Operation *insertBeforeOp;
675class ModifyOperationRewrite :
public OperationRewrite {
677 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
679 : OperationRewrite(
Kind::ModifyOperation, rewriterImpl, op),
680 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
681 operands(op->operand_begin(), op->operand_end()),
682 successors(op->successor_begin(), op->successor_end()) {
685 propertiesStorage = operator new(op->getPropertiesStorageSize());
686 OpaqueProperties propCopy(propertiesStorage);
687 name.initOpProperties(propCopy, prop);
691 static bool classof(
const IRRewrite *
rewrite) {
692 return rewrite->getKind() == Kind::ModifyOperation;
695 ~ModifyOperationRewrite()
override {
696 assert(!propertiesStorage &&
697 "rewrite was neither committed nor rolled back");
700 void commit(RewriterBase &rewriter)
override {
703 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
706 if (propertiesStorage) {
707 OpaqueProperties propCopy(propertiesStorage);
711 operator delete(propertiesStorage);
712 propertiesStorage =
nullptr;
716 void rollback()
override {
720 for (
const auto &it : llvm::enumerate(successors))
722 if (propertiesStorage) {
723 OpaqueProperties propCopy(propertiesStorage);
726 operator delete(propertiesStorage);
727 propertiesStorage =
nullptr;
734 DictionaryAttr attrs;
735 SmallVector<Value, 8> operands;
736 SmallVector<Block *, 2> successors;
737 void *propertiesStorage =
nullptr;
744class ReplaceOperationRewrite :
public OperationRewrite {
746 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
747 Operation *op,
const TypeConverter *converter)
748 : OperationRewrite(
Kind::ReplaceOperation, rewriterImpl, op),
749 converter(converter) {}
751 static bool classof(
const IRRewrite *
rewrite) {
752 return rewrite->getKind() == Kind::ReplaceOperation;
755 void commit(RewriterBase &rewriter)
override;
757 void rollback()
override;
759 void cleanup(RewriterBase &rewriter)
override;
764 const TypeConverter *converter;
767class CreateOperationRewrite :
public OperationRewrite {
769 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
771 : OperationRewrite(
Kind::CreateOperation, rewriterImpl, op) {}
773 static bool classof(
const IRRewrite *
rewrite) {
774 return rewrite->getKind() == Kind::CreateOperation;
777 void commit(RewriterBase &rewriter)
override {
783 void rollback()
override;
787enum MaterializationKind {
798class UnresolvedMaterializationInfo {
800 UnresolvedMaterializationInfo() =
default;
801 UnresolvedMaterializationInfo(
const TypeConverter *converter,
802 MaterializationKind kind, Type originalType)
803 : converterAndKind(converter, kind), originalType(originalType) {}
806 const TypeConverter *getConverter()
const {
807 return converterAndKind.getPointer();
811 MaterializationKind getMaterializationKind()
const {
812 return converterAndKind.getInt();
816 Type getOriginalType()
const {
return originalType; }
821 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
832class UnresolvedMaterializationRewrite :
public OperationRewrite {
834 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
835 UnrealizedConversionCastOp op,
837 : OperationRewrite(
Kind::UnresolvedMaterialization, rewriterImpl, op),
838 mappedValues(std::move(mappedValues)) {}
840 static bool classof(
const IRRewrite *
rewrite) {
841 return rewrite->getKind() == Kind::UnresolvedMaterialization;
844 void rollback()
override;
846 UnrealizedConversionCastOp getOperation()
const {
847 return cast<UnrealizedConversionCastOp>(op);
857#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
860template <
typename RewriteTy,
typename R>
861static bool hasRewrite(R &&rewrites, Operation *op) {
862 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
863 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
864 return rewriteTy && rewriteTy->getOperation() == op;
870template <
typename RewriteTy,
typename R>
871static bool hasRewrite(R &&rewrites,
Block *block) {
872 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
873 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
874 return rewriteTy && rewriteTy->getBlock() == block;
886 const ConversionConfig &
config,
896 RewriterState getCurrentState();
900 void applyRewrites();
905 void resetState(RewriterState state, StringRef patternName =
"");
909 template <
typename RewriteTy,
typename... Args>
911 assert(
config.allowPatternRollback &&
"appending rewrites is not allowed");
913 std::make_unique<RewriteTy>(*
this, std::forward<Args>(args)...));
919 void undoRewrites(
unsigned numRewritesToKeep = 0, StringRef patternName =
"");
925 LogicalResult remapValues(StringRef valueDiagTag,
926 std::optional<Location> inputLoc,
ValueRange values,
943 bool skipPureTypeConversions =
false)
const;
957 TypeConverter::SignatureConversion *entryConversion);
965 Block *applySignatureConversion(
967 TypeConverter::SignatureConversion &signatureConversion);
987 void eraseBlock(
Block *block);
1025 Value findOrBuildReplacementValue(
Value value,
1033 void notifyOperationInserted(
Operation *op,
1037 void notifyBlockInserted(
Block *block,
Region *previous,
1056 std::function<
void(
Operation *)> opErasedCallback =
nullptr)
1058 opErasedCallback(std::move(opErasedCallback)) {}
1072 assert(block->empty() &&
"expected empty block");
1073 block->dropAllDefinedValueUses();
1081 if (opErasedCallback)
1082 opErasedCallback(op);
1190const ConversionConfig &IRRewrite::getConfig()
const {
1191 return rewriterImpl.
config;
1194void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1198 if (
auto *listener =
1199 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
1200 for (Operation *op : getNewBlock()->getUsers())
1204void BlockTypeConversionRewrite::rollback() {
1205 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1212 if (isa<BlockArgument>(repl)) {
1252 result &= functor(operand);
1257void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1264void ReplaceValueRewrite::rollback() {
1265 rewriterImpl.
mapping.erase({value});
1271void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1273 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener());
1276 SmallVector<Value> replacements =
1278 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1286 for (
auto [
result, newValue] :
1287 llvm::zip_equal(op->
getResults(), replacements))
1293 if (getConfig().unlegalizedOps)
1294 getConfig().unlegalizedOps->erase(op);
1298 notifyIRErased(listener, *op);
1305void ReplaceOperationRewrite::rollback() {
1310void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1314void CreateOperationRewrite::rollback() {
1316 while (!region.getBlocks().empty())
1317 region.getBlocks().remove(region.getBlocks().begin());
1323void UnresolvedMaterializationRewrite::rollback() {
1324 if (!mappedValues.empty())
1325 rewriterImpl.
mapping.erase(mappedValues);
1336 for (
size_t i = 0; i <
rewrites.size(); ++i)
1342 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1343 unresolvedMaterializations.erase(castOp);
1346 rewrite->cleanup(eraseRewriter);
1354 Value from,
TypeRange desiredTypes,
bool skipPureTypeConversions)
const {
1357 assert(!values.empty() &&
"expected non-empty value vector");
1361 if (
config.allowPatternRollback)
1362 return mapping.lookup(values);
1369 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1374 if (castOp.getOutputs() != values)
1376 return castOp.getInputs();
1385 for (
Value v : values) {
1388 llvm::append_range(next, r);
1393 if (next != values) {
1422 if (skipPureTypeConversions) {
1425 match &= !pureConversion;
1428 if (!pureConversion)
1429 lastNonMaterialization = current;
1432 desiredValue = current;
1438 current = std::move(next);
1443 if (!desiredTypes.empty())
1444 return desiredValue;
1445 if (skipPureTypeConversions)
1446 return lastNonMaterialization;
1465 StringRef patternName) {
1470 while (
ignoredOps.size() != state.numIgnoredOperations)
1473 while (
replacedOps.size() != state.numReplacedOps)
1478 StringRef patternName) {
1480 llvm::reverse(llvm::drop_begin(
rewrites, numRewritesToKeep)))
1482 rewrites.resize(numRewritesToKeep);
1486 StringRef valueDiagTag, std::optional<Location> inputLoc,
ValueRange values,
1488 remapped.reserve(llvm::size(values));
1490 for (
const auto &it : llvm::enumerate(values)) {
1491 Value operand = it.value();
1510 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1511 << it.index() <<
", type was " << origType;
1516 if (legalTypes.empty()) {
1517 remapped.push_back({});
1526 remapped.push_back(std::move(repl));
1535 repl, repl, legalTypes,
1537 remapped.push_back(castValues);
1558 TypeConverter::SignatureConversion *entryConversion) {
1560 if (region->
empty())
1565 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1567 std::optional<TypeConverter::SignatureConversion> conversion =
1568 converter.convertBlockSignature(&block);
1577 if (entryConversion)
1580 std::optional<TypeConverter::SignatureConversion> conversion =
1581 converter.convertBlockSignature(®ion->
front());
1589 TypeConverter::SignatureConversion &signatureConversion) {
1590#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1592 if (hasRewrite<BlockTypeConversionRewrite>(
rewrites, block))
1593 llvm::report_fatal_error(
"block was already converted");
1600 auto convertedTypes = signatureConversion.getConvertedTypes();
1607 for (
unsigned i = 0; i < origArgCount; ++i) {
1608 auto inputMap = signatureConversion.getInputMapping(i);
1609 if (!inputMap || inputMap->replacedWithValues())
1612 for (
unsigned j = 0;
j < inputMap->size; ++
j)
1613 newLocs[inputMap->inputNo +
j] = origLoc;
1620 convertedTypes, newLocs);
1628 bool fastPath = !
config.listener;
1630 if (
config.allowPatternRollback)
1634 while (!block->
empty())
1641 for (
unsigned i = 0; i != origArgCount; ++i) {
1645 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1646 signatureConversion.getInputMapping(i);
1654 MaterializationKind::Source,
1658 origArgType,
Type(), converter,
1665 if (inputMap->replacedWithValues()) {
1667 assert(inputMap->size == 0 &&
1668 "invalid to provide a replacement value when the argument isn't "
1676 newBlock->
getArguments().slice(inputMap->inputNo, inputMap->size);
1680 if (
config.allowPatternRollback)
1701 assert((!originalType || kind == MaterializationKind::Target) &&
1702 "original type is valid only for target materializations");
1703 assert(
TypeRange(inputs) != outputTypes &&
1704 "materialization is not necessary");
1708 OpBuilder builder(outputTypes.front().getContext());
1710 UnrealizedConversionCastOp convertOp =
1711 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1712 if (
config.attachDebugMaterializationKind) {
1714 kind == MaterializationKind::Source ?
"source" :
"target";
1715 convertOp->setAttr(
"__kind__", builder.
getStringAttr(kindStr));
1722 UnresolvedMaterializationInfo(converter, kind, originalType);
1723 if (
config.allowPatternRollback) {
1724 if (!valuesToMap.empty())
1725 mapping.map(valuesToMap, convertOp.getResults());
1727 std::move(valuesToMap));
1731 return convertOp.getResults();
1736 assert(
config.allowPatternRollback &&
1737 "this code path is valid only in rollback mode");
1744 return repl.front();
1751 [&](
Operation *op) { return replacedOps.contains(op); }) &&
1776 MaterializationKind::Source, ip, value.
getLoc(),
1792 bool wasDetached = !previous.
isSet();
1794 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"' (" << op
1797 logger.getOStream() <<
" (was detached)";
1798 logger.getOStream() <<
"\n";
1804 "attempting to insert into a block within a replaced/erased op");
1808 config.listener->notifyOperationInserted(op, previous);
1817 if (
config.allowPatternRollback) {
1831 if (
config.allowPatternRollback)
1841 assert(!
impl.config.allowPatternRollback &&
1842 "this code path is valid only in 'no rollback' mode");
1844 for (
auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1847 repls.push_back(
Value());
1854 Value srcMat =
impl.buildUnresolvedMaterialization(
1859 repls.push_back(srcMat);
1865 repls.push_back(to[0]);
1874 Value srcMat =
impl.buildUnresolvedMaterialization(
1877 Type(), converter)[0];
1878 repls.push_back(srcMat);
1887 "incorrect number of replacement values");
1889 logger.startLine() <<
"** Replace : '" << op->
getName() <<
"'(" << op
1897 for (
auto [
result, repls] :
1898 llvm::zip_equal(op->
getResults(), newValues)) {
1900 auto logProlog = [&, repls = repls]() {
1901 logger.startLine() <<
" Note: Replacing op result of type "
1902 << resultType <<
" with value(s) of type (";
1903 llvm::interleaveComma(repls,
logger.getOStream(), [&](
Value v) {
1904 logger.getOStream() << v.getType();
1906 logger.getOStream() <<
")";
1912 logger.getOStream() <<
", but the type converter failed to legalize "
1913 "the original type.\n";
1918 logger.getOStream() <<
", but the legalized type(s) is/are (";
1919 llvm::interleaveComma(convertedTypes,
logger.getOStream(),
1920 [&](
Type t) { logger.getOStream() << t; });
1921 logger.getOStream() <<
")\n";
1927 if (!
config.allowPatternRollback) {
1936 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1942 if (
config.unlegalizedOps)
1943 config.unlegalizedOps->erase(op);
1951 assert(!
ignoredOps.contains(op) &&
"operation was already replaced");
1955 "attempting to replace a value that was already replaced");
1960 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1965 "attempting to replace/erase an unresolved materialization");
1981 logger.startLine() <<
"** Replace Value : '" << from <<
"'";
1982 if (
auto blockArg = dyn_cast<BlockArgument>(from)) {
1984 logger.getOStream() <<
" (in region of '" << parentOp->getName()
1985 <<
"' (" << parentOp <<
")";
1987 logger.getOStream() <<
" (unlinked block)";
1991 logger.getOStream() <<
", conditional replacement";
1995 if (!
config.allowPatternRollback) {
2000 Value repl = repls.front();
2017 "attempting to replace a value that was already replaced");
2019 "attempting to replace a op result that was already replaced");
2024 llvm::report_fatal_error(
2025 "conditional value replacement is not supported in rollback mode");
2031 if (!
config.allowPatternRollback) {
2038 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
2044 if (
config.unlegalizedOps)
2045 config.unlegalizedOps->erase(op);
2054 "attempting to erase a block within a replaced/erased op");
2070 bool wasDetached = !previous;
2076 logger.startLine() <<
"** Insert Block into : '" << parent->
getName()
2077 <<
"' (" << parent <<
")";
2080 <<
"** Insert Block into detached Region (nullptr parent op)";
2083 logger.getOStream() <<
" (was detached)";
2084 logger.getOStream() <<
"\n";
2090 "attempting to insert into a region within a replaced/erased op");
2095 config.listener->notifyBlockInserted(block, previous, previousIt);
2099 if (
config.allowPatternRollback) {
2113 if (
config.allowPatternRollback)
2127 reasonCallback(
diag);
2128 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
2129 if (
config.notifyCallback)
2138ConversionPatternRewriter::ConversionPatternRewriter(
2142 *this,
config, opConverter)) {
2143 setListener(
impl.get());
/// Out-of-line defaulted destructor.
/// NOTE(review): presumably defined in this file, rather than in the header,
/// so that the owned `impl` object (accessed above via `impl.get()` /
/// `impl->`) is destroyed where ConversionPatternRewriterImpl is a complete
/// type. Confirm against the header declaration.
2146ConversionPatternRewriter::~ConversionPatternRewriter() =
default;
2148const ConversionConfig &ConversionPatternRewriter::getConfig()
const {
2149 return impl->config;
2153 assert(op && newOp &&
"expected non-null op");
2157void ConversionPatternRewriter::replaceOp(Operation *op,
ValueRange newValues) {
2159 "incorrect # of replacement values");
2163 if (getInsertionPoint() == op->getIterator())
2166 SmallVector<SmallVector<Value>> newVals =
2167 llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2168 return v ? SmallVector<Value>{v} : SmallVector<Value>();
2170 impl->replaceOp(op, std::move(newVals));
2173void ConversionPatternRewriter::replaceOpWithMultiple(
2174 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2176 "incorrect # of replacement values");
2180 if (getInsertionPoint() == op->getIterator())
2183 impl->replaceOp(op, std::move(newValues));
2186void ConversionPatternRewriter::eraseOp(Operation *op) {
2188 impl->logger.startLine()
2189 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
2194 if (getInsertionPoint() == op->getIterator())
2197 SmallVector<SmallVector<Value>> nullRepls(op->
getNumResults(), {});
2198 impl->replaceOp(op, std::move(nullRepls));
2201void ConversionPatternRewriter::eraseBlock(
Block *block) {
2202 impl->eraseBlock(block);
2205Block *ConversionPatternRewriter::applySignatureConversion(
2206 Block *block, TypeConverter::SignatureConversion &conversion,
2207 const TypeConverter *converter) {
2208 assert(!impl->wasOpReplaced(block->
getParentOp()) &&
2209 "attempting to apply a signature conversion to a block within a "
2210 "replaced/erased op");
2211 return impl->applySignatureConversion(block, converter, conversion);
2214FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
2215 Region *region,
const TypeConverter &converter,
2216 TypeConverter::SignatureConversion *entryConversion) {
2217 assert(!impl->wasOpReplaced(region->
getParentOp()) &&
2218 "attempting to apply a signature conversion to a block within a "
2219 "replaced/erased op");
2220 return impl->convertRegionTypes(region, converter, entryConversion);
2223void ConversionPatternRewriter::replaceAllUsesWith(Value from,
ValueRange to) {
2224 impl->replaceValueUses(from, to, impl->currentTypeConverter);
2227void ConversionPatternRewriter::replaceUsesWithIf(
2229 bool *allUsesReplaced) {
2230 assert(!allUsesReplaced &&
2231 "allUsesReplaced is not supported in a dialect conversion");
2232 impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
2235Value ConversionPatternRewriter::getRemappedValue(Value key) {
2236 SmallVector<ValueVector> remappedValues;
2237 if (
failed(impl->remapValues(
"value", std::nullopt, key,
2240 assert(remappedValues.front().size() == 1 &&
"1:N conversion not supported");
2241 return remappedValues.front().front();
2245ConversionPatternRewriter::getRemappedValues(
ValueRange keys,
2246 SmallVectorImpl<Value> &results) {
2249 SmallVector<ValueVector> remapped;
2250 if (
failed(impl->remapValues(
"value", std::nullopt, keys,
2253 for (
const auto &values : remapped) {
2254 assert(values.size() == 1 &&
"1:N conversion not supported");
2255 results.push_back(values.front());
2260void ConversionPatternRewriter::inlineBlockBefore(
Block *source,
Block *dest,
2265 "incorrect # of argument replacement values");
2266 assert(!impl->wasOpReplaced(source->
getParentOp()) &&
2267 "attempting to inline a block from a replaced/erased op");
2268 assert(!impl->wasOpReplaced(dest->
getParentOp()) &&
2269 "attempting to inline a block into a replaced/erased op");
2270 auto opIgnored = [&](Operation *op) {
return impl->isOpIgnored(op); };
2273 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
2274 "expected 'source' to have no predecessors");
2283 bool fastPath = !getConfig().listener;
2285 if (fastPath && impl->config.allowPatternRollback)
2286 impl->inlineBlockBefore(source, dest, before);
2289 for (
auto it : llvm::zip(source->
getArguments(), argValues))
2290 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2297 while (!source->
empty())
2298 moveOpBefore(&source->
front(), dest, before);
2303 if (getInsertionBlock() == source)
2304 setInsertionPoint(dest, getInsertionPoint());
2310void ConversionPatternRewriter::startOpModification(Operation *op) {
2311 if (!impl->config.allowPatternRollback) {
2316 assert(!impl->wasOpReplaced(op) &&
2317 "attempting to modify a replaced/erased op");
2319 impl->pendingRootUpdates.insert(op);
2321 impl->appendRewrite<ModifyOperationRewrite>(op);
2324void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
2325 impl->patternModifiedOps.insert(op);
2326 if (!impl->config.allowPatternRollback) {
2328 if (getConfig().listener)
2329 getConfig().listener->notifyOperationModified(op);
2336 assert(!impl->wasOpReplaced(op) &&
2337 "attempting to modify a replaced/erased op");
2338 assert(impl->pendingRootUpdates.erase(op) &&
2339 "operation did not have a pending in-place update");
2343void ConversionPatternRewriter::cancelOpModification(Operation *op) {
2344 if (!impl->config.allowPatternRollback) {
2349 assert(impl->pendingRootUpdates.erase(op) &&
2350 "operation did not have a pending in-place update");
2353 auto it = llvm::find_if(
2354 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &
rewrite) {
2355 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2356 return modifyRewrite && modifyRewrite->getOperation() == op;
2358 assert(it != impl->rewrites.rend() &&
"no root update started on op");
2360 int updateIdx = std::prev(impl->rewrites.rend()) - it;
2361 impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2364detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
2372FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
2373 ArrayRef<ValueRange> operands)
const {
2374 SmallVector<Value> oneToOneOperands;
2375 oneToOneOperands.reserve(operands.size());
2377 if (operand.size() != 1)
2380 oneToOneOperands.push_back(operand.front());
2382 return std::move(oneToOneOperands);
2386ConversionPattern::matchAndRewrite(Operation *op,
2387 PatternRewriter &rewriter)
const {
2388 auto &dialectRewriter =
static_cast<ConversionPatternRewriter &
>(rewriter);
2389 auto &rewriterImpl = dialectRewriter.getImpl();
2393 getTypeConverter());
2396 SmallVector<ValueVector> remapped;
2401 SmallVector<ValueRange> remappedAsRange =
2402 llvm::to_vector_of<ValueRange>(remapped);
2403 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
/// A small list of (non-owning) pattern pointers used while building the
/// legalization graph. It has inline storage for a single pattern.
/// OperationLegalizer uses it for the patterns that apply to any operation
/// (`anyOpLegalizerPatterns`). It is presumably also used for the per-operation
/// pattern lists; verify against buildLegalizationGraph.
2412using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2415class OperationLegalizer {
2417 using LegalizationAction = ConversionTarget::LegalizationAction;
2419 OperationLegalizer(ConversionPatternRewriter &rewriter,
2420 const ConversionTarget &targetInfo,
2421 const FrozenRewritePatternSet &
patterns);
2424 bool isIllegal(Operation *op)
const;
2428 LogicalResult legalize(Operation *op);
2431 const ConversionTarget &getTarget() {
return target; }
2435 LogicalResult legalizeWithFold(Operation *op);
2439 LogicalResult legalizeWithPattern(Operation *op);
2443 bool canApplyPattern(Operation *op,
const Pattern &pattern);
2447 legalizePatternResult(Operation *op,
const Pattern &pattern,
2448 const RewriterState &curState,
2465 void buildLegalizationGraph(
2466 LegalizationPatterns &anyOpLegalizerPatterns,
2477 void computeLegalizationGraphBenefit(
2478 LegalizationPatterns &anyOpLegalizerPatterns,
2483 unsigned computeOpLegalizationDepth(
2490 unsigned applyCostModelToPatterns(
2496 SmallPtrSet<const Pattern *, 8> appliedPatterns;
2499 ConversionPatternRewriter &rewriter;
2502 const ConversionTarget &
target;
2505 PatternApplicator applicator;
2509OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2510 const ConversionTarget &targetInfo,
2511 const FrozenRewritePatternSet &
patterns)
2516 LegalizationPatterns anyOpLegalizerPatterns;
2518 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2519 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2522bool OperationLegalizer::isIllegal(Operation *op)
const {
2523 return target.isIllegal(op);
2526LogicalResult OperationLegalizer::legalize(Operation *op) {
2528 const char *logLineComment =
2529 "//===-------------------------------------------===//\n";
2531 auto &logger = rewriter.getImpl().logger;
2535 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2538 logger.getOStream() <<
"\n";
2539 logger.startLine() << logLineComment;
2540 logger.startLine() <<
"Legalizing operation : ";
2545 logger.getOStream() <<
"'" << op->
getName() <<
"' ";
2546 logger.getOStream() <<
"(" << op <<
") {\n";
2551 logger.startLine() << OpWithFlags(op,
2552 OpPrintingFlags().printGenericOpForm())
2559 logSuccess(logger,
"operation marked 'ignored' during conversion");
2560 logger.startLine() << logLineComment;
2566 if (
auto legalityInfo =
target.isLegal(op)) {
2569 logger,
"operation marked legal by the target{0}",
2570 legalityInfo->isRecursivelyLegal
2571 ?
"; NOTE: operation is recursively legal; skipping internals"
2573 logger.startLine() << logLineComment;
2578 if (legalityInfo->isRecursivelyLegal) {
2579 op->
walk([&](Operation *nested) {
2581 rewriter.getImpl().ignoredOps.
insert(nested);
2590 const ConversionConfig &
config = rewriter.getConfig();
2591 if (
config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2592 if (succeeded(legalizeWithFold(op))) {
2595 logger.startLine() << logLineComment;
2602 if (succeeded(legalizeWithPattern(op))) {
2605 logger.startLine() << logLineComment;
2612 if (
config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2613 if (succeeded(legalizeWithFold(op))) {
2616 logger.startLine() << logLineComment;
2623 logFailure(logger,
"no matched legalization pattern");
2624 logger.startLine() << logLineComment;
2631template <
typename T>
2633 T
result = std::move(obj);
2638LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2639 auto &rewriterImpl = rewriter.getImpl();
2641 rewriterImpl.
logger.startLine() <<
"* Fold {\n";
2642 rewriterImpl.
logger.indent();
2647 auto cleanup = llvm::make_scope_exit([&]() {
2657 SmallVector<Value, 2> replacementValues;
2658 SmallVector<Operation *, 2> newOps;
2661 if (
failed(rewriter.
tryFold(op, replacementValues, &newOps))) {
2670 if (replacementValues.empty())
2671 return legalize(op);
2674 rewriter.
replaceOp(op, replacementValues);
2677 for (Operation *newOp : newOps) {
2678 if (
failed(legalize(newOp))) {
2680 "failed to legalize generated constant '{0}'",
2682 if (!rewriter.getConfig().allowPatternRollback) {
2684 llvm::report_fatal_error(
2686 "' folder rollback of IR modifications requested");
2704 auto newOpNames = llvm::map_range(
2706 auto modifiedOpNames = llvm::map_range(
2708 llvm::report_fatal_error(
"pattern '" + pattern.
getDebugName() +
2709 "' produced IR that could not be legalized. " +
2710 "new ops: {" + llvm::join(newOpNames,
", ") +
"}, " +
2712 llvm::join(modifiedOpNames,
", ") +
"}");
2715LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2716 auto &rewriterImpl = rewriter.getImpl();
2717 const ConversionConfig &
config = rewriter.getConfig();
2719#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2721 std::optional<OperationFingerPrint> topLevelFingerPrint;
2722 if (!rewriterImpl.
config.allowPatternRollback) {
2729 topLevelFingerPrint = OperationFingerPrint(checkOp);
2735 rewriterImpl.
logger.startLine()
2736 <<
"WARNING: Multi-threadeding is enabled. Some dialect "
2737 "conversion expensive checks are skipped in multithreading "
2746 auto canApply = [&](
const Pattern &pattern) {
2747 bool canApply = canApplyPattern(op, pattern);
2748 if (canApply &&
config.listener)
2749 config.listener->notifyPatternBegin(pattern, op);
2755 auto onFailure = [&](
const Pattern &pattern) {
2757 if (!rewriterImpl.
config.allowPatternRollback) {
2764#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2766 if (checkOp && topLevelFingerPrint) {
2767 OperationFingerPrint fingerPrintAfterPattern(checkOp);
2768 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2769 llvm::report_fatal_error(
"pattern '" + pattern.getDebugName() +
2770 "' returned failure but IR did change");
2778 if (rewriterImpl.
config.notifyCallback) {
2780 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
2787 config.listener->notifyPatternEnd(pattern, failure());
2788 rewriterImpl.
resetState(curState, pattern.getDebugName());
2789 appliedPatterns.erase(&pattern);
2794 auto onSuccess = [&](
const Pattern &pattern) {
2796 if (!rewriterImpl.
config.allowPatternRollback) {
2810 legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
2811 appliedPatterns.erase(&pattern);
2813 if (!rewriterImpl.
config.allowPatternRollback)
2815 rewriterImpl.
resetState(curState, pattern.getDebugName());
2823 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2827bool OperationLegalizer::canApplyPattern(Operation *op,
2828 const Pattern &pattern) {
2830 auto &os = rewriter.getImpl().logger;
2831 os.getOStream() <<
"\n";
2832 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
2834 os.getOStream() <<
")' {\n";
2841 !appliedPatterns.insert(&pattern).second) {
2843 logFailure(rewriter.getImpl().logger,
"pattern was already applied"));
2849LogicalResult OperationLegalizer::legalizePatternResult(
2850 Operation *op,
const Pattern &pattern,
const RewriterState &curState,
2853 [[maybe_unused]]
auto &impl = rewriter.getImpl();
2854 assert(impl.pendingRootUpdates.empty() &&
"dangling root updates");
2856#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2857 if (impl.config.allowPatternRollback) {
2859 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2860 auto replacedRoot = [&] {
2861 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2863 auto updatedRootInPlace = [&] {
2864 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2866 if (!replacedRoot() && !updatedRootInPlace())
2867 llvm::report_fatal_error(
"expected pattern to replace the root operation "
2868 "or modify it in place");
2873 if (
failed(legalizePatternRootUpdates(modifiedOps)) ||
2874 failed(legalizePatternCreatedOperations(newOps))) {
2878 LLVM_DEBUG(
logSuccess(impl.logger,
"pattern applied successfully"));
2882LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2884 for (Operation *op : newOps) {
2885 if (
failed(legalize(op))) {
2886 LLVM_DEBUG(
logFailure(rewriter.getImpl().logger,
2887 "failed to legalize generated operation '{0}'({1})",
2895LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2897 for (Operation *op : modifiedOps) {
2898 if (
failed(legalize(op))) {
2901 "failed to legalize operation updated in-place '{0}'",
2913void OperationLegalizer::buildLegalizationGraph(
2914 LegalizationPatterns &anyOpLegalizerPatterns,
2925 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2926 std::optional<OperationName> root = pattern.
getRootKind();
2932 anyOpLegalizerPatterns.push_back(&pattern);
2937 if (
target.getOpAction(*root) == LegalizationAction::Legal)
2942 invalidPatterns[*root].insert(&pattern);
2944 parentOps[op].insert(*root);
2947 patternWorklist.insert(&pattern);
2955 if (!anyOpLegalizerPatterns.empty()) {
2956 for (
const Pattern *pattern : patternWorklist)
2957 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2961 while (!patternWorklist.empty()) {
2962 auto *pattern = patternWorklist.pop_back_val();
2966 std::optional<LegalizationAction> action = target.getOpAction(op);
2967 return !legalizerPatterns.count(op) &&
2968 (!action || action == LegalizationAction::Illegal);
2974 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2975 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
2979 for (
auto op : parentOps[*pattern->
getRootKind()])
2980 patternWorklist.set_union(invalidPatterns[op]);
2984void OperationLegalizer::computeLegalizationGraphBenefit(
2985 LegalizationPatterns &anyOpLegalizerPatterns,
2991 for (
auto &opIt : legalizerPatterns)
2992 if (!minOpPatternDepth.count(opIt.first))
2993 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2999 if (!anyOpLegalizerPatterns.empty())
3000 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
3006 applicator.applyCostModel([&](
const Pattern &pattern) {
3007 ArrayRef<const Pattern *> orderedPatternList;
3008 if (std::optional<OperationName> rootName = pattern.
getRootKind())
3009 orderedPatternList = legalizerPatterns[*rootName];
3011 orderedPatternList = anyOpLegalizerPatterns;
3014 auto *it = llvm::find(orderedPatternList, &pattern);
3015 if (it == orderedPatternList.end())
3019 return PatternBenefit(std::distance(it, orderedPatternList.end()));
3023unsigned OperationLegalizer::computeOpLegalizationDepth(
3027 auto depthIt = minOpPatternDepth.find(op);
3028 if (depthIt != minOpPatternDepth.end())
3029 return depthIt->second;
3033 auto opPatternsIt = legalizerPatterns.find(op);
3034 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3039 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
3043 unsigned minDepth = applyCostModelToPatterns(
3044 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3045 minOpPatternDepth[op] = minDepth;
3049unsigned OperationLegalizer::applyCostModelToPatterns(
3053 unsigned minDepth = std::numeric_limits<unsigned>::max();
3056 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3057 patternsByDepth.reserve(
patterns.size());
3058 for (
const Pattern *pattern :
patterns) {
3061 unsigned generatedOpDepth = computeOpLegalizationDepth(
3062 generatedOp, minOpPatternDepth, legalizerPatterns);
3063 depth = std::max(depth, generatedOpDepth + 1);
3065 patternsByDepth.emplace_back(pattern, depth);
3068 minDepth = std::min(minDepth, depth);
3073 if (patternsByDepth.size() == 1)
3077 llvm::stable_sort(patternsByDepth,
3078 [](
const std::pair<const Pattern *, unsigned> &
lhs,
3079 const std::pair<const Pattern *, unsigned> &
rhs) {
3082 if (
lhs.second !=
rhs.second)
3083 return lhs.second <
rhs.second;
3086 auto lhsBenefit =
lhs.first->getBenefit();
3087 auto rhsBenefit =
rhs.first->getBenefit();
3088 return lhsBenefit > rhsBenefit;
3093 for (
auto &patternIt : patternsByDepth)
3094 patterns.push_back(patternIt.first);
3108template <
typename RangeT>
3111 function_ref<
bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3120 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3121 if (castOp.getInputs().empty())
3124 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3127 if (inputCastOp.getOutputs() != castOp.getInputs())
3133 while (!worklist.empty()) {
3134 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3138 UnrealizedConversionCastOp nextCast = castOp;
3140 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3141 if (llvm::any_of(nextCast.getInputs(), [&](
Value v) {
3142 return v.getDefiningOp() == castOp;
3150 castOp.replaceAllUsesWith(nextCast.getInputs());
3153 nextCast = getInputCast(nextCast);
3163 auto markOpLive = [&](
Operation *rootOp) {
3165 worklist.push_back(rootOp);
3166 while (!worklist.empty()) {
3167 Operation *op = worklist.pop_back_val();
3168 if (liveOps.insert(op).second) {
3171 if (
auto castOp = v.
getDefiningOp<UnrealizedConversionCastOp>())
3172 if (isCastOpOfInterestFn(castOp))
3173 worklist.push_back(castOp);
3179 for (UnrealizedConversionCastOp op : castOps) {
3182 if (liveOps.contains(op.getOperation()))
3186 if (llvm::any_of(op->getUsers(), [&](
Operation *user) {
3187 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3188 return !castOp || !isCastOpOfInterestFn(castOp);
3194 for (UnrealizedConversionCastOp op : castOps) {
3195 if (liveOps.contains(op)) {
3197 if (remainingCastOps)
3198 remainingCastOps->push_back(op);
3209 ArrayRef<UnrealizedConversionCastOp> castOps,
3210 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3212 DenseSet<UnrealizedConversionCastOp> castOpSet;
3213 for (UnrealizedConversionCastOp op : castOps)
3214 castOpSet.insert(op);
3219 const DenseSet<UnrealizedConversionCastOp> &castOps,
3220 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3222 llvm::make_range(castOps.begin(), castOps.end()),
3223 [&](UnrealizedConversionCastOp castOp) {
3224 return castOps.contains(castOp);
3236 [&](UnrealizedConversionCastOp castOp) {
3237 return castOps.contains(castOp);
3254 const ConversionConfig &
config,
3255 OpConversionMode mode)
3265 template <
typename Fn>
3267 bool isRecursiveLegalization =
false);
3269 bool isRecursiveLegalization =
false) {
3271 ops, [&]() {}, isRecursiveLegalization);
3279 LogicalResult convert(
Operation *op,
bool isRecursiveLegalization =
false);
3285 ConversionPatternRewriter rewriter;
3288 OperationLegalizer opLegalizer;
3291 OpConversionMode mode;
3296 bool isRecursiveLegalization) {
3297 const ConversionConfig &
config = rewriter.getConfig();
3300 if (failed(opLegalizer.legalize(op))) {
3303 if (mode == OpConversionMode::Full) {
3304 if (!isRecursiveLegalization)
3312 if (mode == OpConversionMode::Partial) {
3313 if (opLegalizer.isIllegal(op)) {
3314 if (!isRecursiveLegalization)
3316 <<
"' that was explicitly marked illegal";
3319 if (
config.unlegalizedOps && !isRecursiveLegalization)
3320 config.unlegalizedOps->insert(op);
3322 }
else if (mode == OpConversionMode::Analysis) {
3326 if (
config.legalizableOps && !isRecursiveLegalization)
3327 config.legalizableOps->insert(op);
3334 UnrealizedConversionCastOp op,
3335 const UnresolvedMaterializationInfo &info) {
3336 assert(!op.use_empty() &&
3337 "expected that dead materializations have already been DCE'd");
3344 switch (info.getMaterializationKind()) {
3345 case MaterializationKind::Target:
3346 newMaterialization = converter->materializeTargetConversion(
3347 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3348 info.getOriginalType());
3350 case MaterializationKind::Source:
3351 assert(op->getNumResults() == 1 &&
"expected single result");
3352 Value sourceMat = converter->materializeSourceConversion(
3353 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3355 newMaterialization.push_back(sourceMat);
3358 if (!newMaterialization.empty()) {
3360 ValueRange newMaterializationRange(newMaterialization);
3361 assert(
TypeRange(newMaterializationRange) == op.getResultTypes() &&
3362 "materialization callback produced value of incorrect type");
3364 rewriter.
replaceOp(op, newMaterialization);
3370 <<
"failed to legalize unresolved materialization "
3372 << inputOperands.
getTypes() <<
") to ("
3373 << op.getResultTypes()
3374 <<
") that remained live after conversion";
3375 diag.attachNote(op->getUsers().begin()->getLoc())
3376 <<
"see existing live user here: " << *op->getUsers().begin();
3380template <
typename Fn>
3383 bool isRecursiveLegalization) {
3391 toConvert.push_back(op);
3394 auto legalityInfo =
target.isLegal(op);
3395 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3401 if (failed(
convert(op, isRecursiveLegalization))) {
3410LogicalResult ConversionPatternRewriter::legalize(
Operation *op) {
3411 return impl->opConverter.legalizeOperations(op,
3415LogicalResult ConversionPatternRewriter::legalize(
Region *r) {
3431 std::optional<TypeConverter::SignatureConversion> conversion =
3432 converter->convertBlockSignature(&r->front());
3435 applySignatureConversion(&r->front(), *conversion, converter);
3440 return impl->opConverter.legalizeOperations(ops,
3449 if (rewriterImpl.
config.allowPatternRollback) {
3473 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3477 if (rewriter.getConfig().buildMaterializations) {
3481 rewriter.getConfig().listener);
3482 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3483 auto it = materializations.find(castOp);
3484 assert(it != materializations.end() &&
"inconsistent state");
3498void TypeConverter::SignatureConversion::addInputs(
unsigned origInputNo,
3500 assert(!types.empty() &&
"expected valid types");
3501 remapInput(origInputNo, argTypes.size(), types.size());
3505void TypeConverter::SignatureConversion::addInputs(
ArrayRef<Type> types) {
3506 assert(!types.empty() &&
3507 "1->0 type remappings don't need to be added explicitly");
3508 argTypes.append(types.begin(), types.end());
3511void TypeConverter::SignatureConversion::remapInput(
unsigned origInputNo,
3512 unsigned newInputNo,
3513 unsigned newInputCount) {
3514 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3515 assert(newInputCount != 0 &&
"expected valid input count");
3516 remappedInputs[origInputNo] =
3517 InputMapping{newInputNo, newInputCount, {}};
3520void TypeConverter::SignatureConversion::remapInput(
3521 unsigned origInputNo, ArrayRef<Value> replacements) {
3522 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3523 remappedInputs[origInputNo] = InputMapping{
3525 SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3536TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3537 SmallVectorImpl<Type> &results)
const {
3538 assert(typeOrValue &&
"expected non-null type");
3539 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3540 : cast<Type>(typeOrValue);
3542 std::shared_lock<
decltype(cacheMutex)> cacheReadLock(cacheMutex,
3545 cacheReadLock.lock();
3546 auto existingIt = cachedDirectConversions.find(t);
3547 if (existingIt != cachedDirectConversions.end()) {
3548 if (existingIt->second)
3549 results.push_back(existingIt->second);
3550 return success(existingIt->second !=
nullptr);
3552 auto multiIt = cachedMultiConversions.find(t);
3553 if (multiIt != cachedMultiConversions.end()) {
3554 results.append(multiIt->second.begin(), multiIt->second.end());
3560 size_t currentCount = results.size();
3564 auto isCacheable = [&](
int index) {
3565 int numberOfConversionsUntilContextAware =
3566 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3567 return index < numberOfConversionsUntilContextAware;
3570 std::unique_lock<
decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3573 for (
auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3574 const ConversionCallbackFn &converter = indexedConverter.value();
3575 std::optional<LogicalResult>
result = converter(typeOrValue, results);
3577 assert(results.size() == currentCount &&
3578 "failed type conversion should not change results");
3581 if (!isCacheable(indexedConverter.index()))
3584 cacheWriteLock.lock();
3585 if (!succeeded(*
result)) {
3586 assert(results.size() == currentCount &&
3587 "failed type conversion should not change results");
3588 cachedDirectConversions.try_emplace(t,
nullptr);
3591 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3592 if (newTypes.size() == 1)
3593 cachedDirectConversions.try_emplace(t, newTypes.front());
3595 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3601LogicalResult TypeConverter::convertType(Type t,
3602 SmallVectorImpl<Type> &results)
const {
3603 return convertTypeImpl(t, results);
3606LogicalResult TypeConverter::convertType(Value v,
3607 SmallVectorImpl<Type> &results)
const {
3608 return convertTypeImpl(v, results);
3611Type TypeConverter::convertType(Type t)
const {
3613 SmallVector<Type, 1> results;
3614 if (
failed(convertType(t, results)))
3618 return results.size() == 1 ? results.front() :
nullptr;
3621Type TypeConverter::convertType(Value v)
const {
3623 SmallVector<Type, 1> results;
3624 if (
failed(convertType(v, results)))
3628 return results.size() == 1 ? results.front() :
nullptr;
3632TypeConverter::convertTypes(
TypeRange types,
3633 SmallVectorImpl<Type> &results)
const {
3634 for (Type type : types)
3635 if (
failed(convertType(type, results)))
3641TypeConverter::convertTypes(
ValueRange values,
3642 SmallVectorImpl<Type> &results)
const {
3643 for (Value value : values)
3644 if (
failed(convertType(value, results)))
3649bool TypeConverter::isLegal(Type type)
const {
3650 return convertType(type) == type;
3653bool TypeConverter::isLegal(Value value)
const {
3654 return convertType(value) == value.
getType();
3657bool TypeConverter::isLegal(Operation *op)
const {
3661bool TypeConverter::isLegal(Region *region)
const {
3662 return llvm::all_of(
3666bool TypeConverter::isSignatureLegal(FunctionType ty)
const {
3667 if (!isLegal(ty.getInputs()))
3669 if (!isLegal(ty.getResults()))
3675TypeConverter::convertSignatureArg(
unsigned inputNo, Type type,
3676 SignatureConversion &
result)
const {
3678 SmallVector<Type, 1> convertedTypes;
3679 if (
failed(convertType(type, convertedTypes)))
3683 if (convertedTypes.empty())
3687 result.addInputs(inputNo, convertedTypes);
3691TypeConverter::convertSignatureArgs(
TypeRange types,
3692 SignatureConversion &
result,
3693 unsigned origInputOffset)
const {
3694 for (
unsigned i = 0, e = types.size(); i != e; ++i)
3695 if (
failed(convertSignatureArg(origInputOffset + i, types[i],
result)))
3700TypeConverter::convertSignatureArg(
unsigned inputNo, Value value,
3701 SignatureConversion &
result)
const {
3703 SmallVector<Type, 1> convertedTypes;
3704 if (
failed(convertType(value, convertedTypes)))
3708 if (convertedTypes.empty())
3712 result.addInputs(inputNo, convertedTypes);
3716TypeConverter::convertSignatureArgs(
ValueRange values,
3717 SignatureConversion &
result,
3718 unsigned origInputOffset)
const {
3719 for (
unsigned i = 0, e = values.size(); i != e; ++i)
3720 if (
failed(convertSignatureArg(origInputOffset + i, values[i],
result)))
3725Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
3726 Location loc, Type resultType,
3728 for (
const SourceMaterializationCallbackFn &fn :
3729 llvm::reverse(sourceMaterializations))
3730 if (Value
result = fn(builder, resultType, inputs, loc))
3735Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
3736 Location loc, Type resultType,
3738 Type originalType)
const {
3739 SmallVector<Value>
result = materializeTargetConversion(
3740 builder, loc,
TypeRange(resultType), inputs, originalType);
3743 assert(
result.size() == 1 &&
"expected single result");
3747SmallVector<Value> TypeConverter::materializeTargetConversion(
3749 Type originalType)
const {
3750 for (
const TargetMaterializationCallbackFn &fn :
3751 llvm::reverse(targetMaterializations)) {
3752 SmallVector<Value>
result =
3753 fn(builder, resultTypes, inputs, loc, originalType);
3757 "callback produced incorrect number of values or values with "
3764std::optional<TypeConverter::SignatureConversion>
3765TypeConverter::convertBlockSignature(
Block *block)
const {
3768 return std::nullopt;
3775TypeConverter::AttributeConversionResult
3776TypeConverter::AttributeConversionResult::result(Attribute attr) {
3777 return AttributeConversionResult(attr, resultTag);
3780TypeConverter::AttributeConversionResult
3781TypeConverter::AttributeConversionResult::na() {
3782 return AttributeConversionResult(
nullptr, naTag);
3785TypeConverter::AttributeConversionResult
3786TypeConverter::AttributeConversionResult::abort() {
3787 return AttributeConversionResult(
nullptr, abortTag);
3790bool TypeConverter::AttributeConversionResult::hasResult()
const {
3791 return impl.getInt() == resultTag;
3794bool TypeConverter::AttributeConversionResult::isNa()
const {
3795 return impl.getInt() == naTag;
3798bool TypeConverter::AttributeConversionResult::isAbort()
const {
3799 return impl.getInt() == abortTag;
3802Attribute TypeConverter::AttributeConversionResult::getResult()
const {
3803 assert(hasResult() &&
"Cannot get result from N/A or abort");
3804 return impl.getPointer();
3807std::optional<Attribute>
3808TypeConverter::convertTypeAttribute(Type type, Attribute attr)
const {
3809 for (
const TypeAttributeConversionCallbackFn &fn :
3810 llvm::reverse(typeAttributeConversions)) {
3811 AttributeConversionResult res = fn(type, attr);
3812 if (res.hasResult())
3813 return res.getResult();
3815 return std::nullopt;
3817 return std::nullopt;
3826 ConversionPatternRewriter &rewriter) {
3827 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3832 TypeConverter::SignatureConversion
result(type.getNumInputs());
3834 if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
result)) ||
3835 failed(typeConverter.convertTypes(type.getResults(), newResults)))
3837 if (!funcOp.getFunctionBody().empty())
3838 rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(),
result,
3842 auto newType = FunctionType::get(rewriter.getContext(),
3843 result.getConvertedTypes(), newResults);
3845 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3854struct FunctionOpInterfaceSignatureConversion :
public ConversionPattern {
3855 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3857 const TypeConverter &converter,
3858 PatternBenefit benefit)
3859 : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3862 matchAndRewrite(Operation *op, ArrayRef<Value> ,
3863 ConversionPatternRewriter &rewriter)
const override {
3864 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3869struct AnyFunctionOpInterfaceSignatureConversion
3870 :
public OpInterfaceConversionPattern<FunctionOpInterface> {
3871 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3874 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> ,
3875 ConversionPatternRewriter &rewriter)
const override {
3881FailureOr<Operation *>
3882mlir::convertOpResultTypes(Operation *op,
ValueRange operands,
3883 const TypeConverter &converter,
3884 ConversionPatternRewriter &rewriter) {
3885 assert(op &&
"Invalid op");
3886 Location loc = op->
getLoc();
3887 if (converter.isLegal(op))
3888 return rewriter.notifyMatchFailure(loc,
"op already legal");
3890 OperationState newOp(loc, op->
getName());
3891 newOp.addOperands(operands);
3893 SmallVector<Type> newResultTypes;
3895 return rewriter.notifyMatchFailure(loc,
"couldn't convert return types");
3897 newOp.addTypes(newResultTypes);
3898 newOp.addAttributes(op->
getAttrs());
3899 return rewriter.create(newOp);
3902void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3903 StringRef functionLikeOpName, RewritePatternSet &
patterns,
3904 const TypeConverter &converter, PatternBenefit benefit) {
3905 patterns.add<FunctionOpInterfaceSignatureConversion>(
3906 functionLikeOpName,
patterns.getContext(), converter, benefit);
3909void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3910 RewritePatternSet &
patterns,
const TypeConverter &converter,
3911 PatternBenefit benefit) {
3912 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3913 converter,
patterns.getContext(), benefit);
3920void ConversionTarget::setOpAction(OperationName op,
3921 LegalizationAction action) {
3922 legalOperations[op].action = action;
3925void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3926 LegalizationAction action) {
3927 for (StringRef dialect : dialectNames)
3928 legalDialects[dialect] = action;
3931auto ConversionTarget::getOpAction(OperationName op)
const
3932 -> std::optional<LegalizationAction> {
3933 std::optional<LegalizationInfo> info = getOpInfo(op);
3934 return info ? info->action : std::optional<LegalizationAction>();
3937auto ConversionTarget::isLegal(Operation *op)
const
3938 -> std::optional<LegalOpDetails> {
3939 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3941 return std::nullopt;
3944 auto isOpLegal = [&] {
3946 if (info->action == LegalizationAction::Dynamic) {
3947 std::optional<bool>
result = info->legalityFn(op);
3953 return info->action == LegalizationAction::Legal;
3956 return std::nullopt;
3959 LegalOpDetails legalityDetails;
3960 if (info->isRecursivelyLegal) {
3961 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
3962 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3963 legalityDetails.isRecursivelyLegal =
3964 legalityFnIt->second(op).value_or(
true);
3966 legalityDetails.isRecursivelyLegal =
true;
3969 return legalityDetails;
3972bool ConversionTarget::isIllegal(Operation *op)
const {
3973 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3977 if (info->action == LegalizationAction::Dynamic) {
3978 std::optional<bool>
result = info->legalityFn(op);
3985 return info->action == LegalizationAction::Illegal;
3989 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
3990 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
3994 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3996 if (std::optional<bool>
result = newCl(op))
4004void ConversionTarget::setLegalityCallback(
4005 OperationName name,
const DynamicLegalityCallbackFn &callback) {
4006 assert(callback &&
"expected valid legality callback");
4007 auto *infoIt = legalOperations.find(name);
4008 assert(infoIt != legalOperations.end() &&
4009 infoIt->second.action == LegalizationAction::Dynamic &&
4010 "expected operation to already be marked as dynamically legal");
4011 infoIt->second.legalityFn =
4015void ConversionTarget::markOpRecursivelyLegal(
4016 OperationName name,
const DynamicLegalityCallbackFn &callback) {
4017 auto *infoIt = legalOperations.find(name);
4018 assert(infoIt != legalOperations.end() &&
4019 infoIt->second.action != LegalizationAction::Illegal &&
4020 "expected operation to already be marked as legal");
4021 infoIt->second.isRecursivelyLegal =
true;
4024 std::move(opRecursiveLegalityFns[name]), callback);
4026 opRecursiveLegalityFns.erase(name);
4029void ConversionTarget::setLegalityCallback(
4030 ArrayRef<StringRef> dialects,
const DynamicLegalityCallbackFn &callback) {
4031 assert(callback &&
"expected valid legality callback");
4032 for (StringRef dialect : dialects)
4034 std::move(dialectLegalityFns[dialect]), callback);
4037void ConversionTarget::setLegalityCallback(
4038 const DynamicLegalityCallbackFn &callback) {
4039 assert(callback &&
"expected valid legality callback");
4043auto ConversionTarget::getOpInfo(OperationName op)
const
4044 -> std::optional<LegalizationInfo> {
4046 const auto *it = legalOperations.find(op);
4047 if (it != legalOperations.end())
4050 auto dialectIt = legalDialects.find(op.getDialectNamespace());
4051 if (dialectIt != legalDialects.end()) {
4052 DynamicLegalityCallbackFn callback;
4053 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4054 if (dialectFn != dialectLegalityFns.end())
4055 callback = dialectFn->second;
4056 return LegalizationInfo{dialectIt->second,
false,
4060 if (unknownLegalityFn)
4061 return LegalizationInfo{LegalizationAction::Dynamic,
4062 false, unknownLegalityFn};
4063 return std::nullopt;
4066#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4071void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
4072 auto &rewriterImpl =
4073 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4077void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
4078 auto &rewriterImpl =
4079 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4085static FailureOr<SmallVector<Value>>
4086pdllConvertValues(ConversionPatternRewriter &rewriter,
ValueRange values) {
4087 SmallVector<Value> mappedValues;
4088 if (
failed(rewriter.getRemappedValues(values, mappedValues)))
4090 return std::move(mappedValues);
4093void mlir::registerConversionPDLFunctions(RewritePatternSet &
patterns) {
4094 patterns.getPDLPatterns().registerRewriteFunction(
4096 [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
4097 auto results = pdllConvertValues(
4098 static_cast<ConversionPatternRewriter &
>(rewriter), value);
4101 return results->front();
4103 patterns.getPDLPatterns().registerRewriteFunction(
4104 "convertValues", [](PatternRewriter &rewriter,
ValueRange values) {
4105 return pdllConvertValues(
4106 static_cast<ConversionPatternRewriter &
>(rewriter), values);
4108 patterns.getPDLPatterns().registerRewriteFunction(
4110 [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
4111 auto &rewriterImpl =
4112 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4113 if (
const TypeConverter *converter =
4115 if (Type newType = converter->convertType(type))
4121 patterns.getPDLPatterns().registerRewriteFunction(
4123 [](PatternRewriter &rewriter,
4124 TypeRange types) -> FailureOr<SmallVector<Type>> {
4125 auto &rewriterImpl =
4126 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4129 return SmallVector<Type>(types);
4131 SmallVector<Type> remappedTypes;
4132 if (
failed(converter->convertTypes(types, remappedTypes)))
4134 return std::move(remappedTypes);
4149 static constexpr StringLiteral
tag =
"apply-conversion";
4150 static constexpr StringLiteral
desc =
4151 "Encapsulate the application of a dialect conversion";
4160 OpConversionMode mode) {
4164 LogicalResult status =
success();
4180LogicalResult mlir::applyPartialConversion(
4181 ArrayRef<Operation *> ops,
const ConversionTarget &
target,
4182 const FrozenRewritePatternSet &
patterns, ConversionConfig
config) {
4184 OpConversionMode::Partial);
4187mlir::applyPartialConversion(Operation *op,
const ConversionTarget &
target,
4188 const FrozenRewritePatternSet &
patterns,
4189 ConversionConfig
config) {
4197LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
4198 const ConversionTarget &
target,
4199 const FrozenRewritePatternSet &
patterns,
4200 ConversionConfig
config) {
4203LogicalResult mlir::applyFullConversion(Operation *op,
4204 const ConversionTarget &
target,
4205 const FrozenRewritePatternSet &
patterns,
4206 ConversionConfig
config) {
4224 "expected top-level op to be isolated from above");
4227 "expected ops to have a common ancestor");
4236 for (
Operation *op : ops.drop_front()) {
4240 assert(commonAncestor &&
4241 "expected to find a common isolated from above ancestor");
4245 return commonAncestor;
4248LogicalResult mlir::applyAnalysisConversion(
4249 ArrayRef<Operation *> ops, ConversionTarget &
target,
4250 const FrozenRewritePatternSet &
patterns, ConversionConfig
config) {
4252 if (
config.legalizableOps)
4253 assert(
config.legalizableOps->empty() &&
"expected empty set");
4259 Operation *clonedAncestor = commonAncestor->
clone(mapping);
4263 inverseOperationMap[it.second] = it.first;
4266 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4267 ops, [&](Operation *op) {
return mapping.
lookup(op); });
4269 OpConversionMode::Analysis);
4273 if (
config.legalizableOps) {
4275 for (Operation *op : *
config.legalizableOps)
4276 originalLegalizableOps.insert(inverseOperationMap[op]);
4277 *
config.legalizableOps = std::move(originalLegalizableOps);
4281 clonedAncestor->
erase();
4286mlir::applyAnalysisConversion(Operation *op, ConversionTarget &
target,
4287 const FrozenRewritePatternSet &
patterns,
4288 ConversionConfig
config) {
static void setInsertionPointAfter(OpBuilder &b, Value value)
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl, function_ref< bool(OpOperand &)> functor=nullptr)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value > > &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
static std::string diag(const llvm::Value &value)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
This is the type of Action that is dispatched when a conversion is applied.
tracing::ActionImpl< ApplyConversionAction > Base
static constexpr StringLiteral desc
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
static constexpr StringLiteral tag
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
StringAttr getStringAttr(const Twine &bytes)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
auto lookup(T from) const
Lookup a mapped value within the map.
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Block::iterator getPoint() const
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
void destroyOpProperties(OpaqueProperties properties) const
This hooks destroy the op properties.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
CRTP Implementation of an action.
ArrayRef< IRUnit > irUnits
Set of IR units (operations, regions, blocks, values) that are associated with this action.
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
static void reconcileUnrealizedCasts(const DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
llvm::SetVector< T, Vector, Set, N > SetVector
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
This iterator enumerates elements according to their dominance relationship.
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
const ConversionTarget & getTarget()
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, bool isRecursiveLegalization=false)
LogicalResult convert(Operation *op, bool isRecursiveLegalization=false)
Converts a single operation.
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, Fn onFailure, bool isRecursiveLegalization=false)
Legalizes the given operations (and their nested operations) to the conversion target.
LogicalResult applyConversion(ArrayRef< Operation * > ops)
Applies the conversion to the given operations (and their nested operations).
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
bool wasErased(void *ptr) const
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
llvm::impl::raw_ldbg_ostream os
A raw output stream used to prefix the debug log.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
void replaceValueUses(Value from, ValueRange to, const TypeConverter *converter, function_ref< bool(OpOperand &)> functor=nullptr)
Replace the uses of the given value with the given values.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config, OperationConverter &opConverter)
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
ConversionValueMapping mapping
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
void replaceOp(Operation *op, SmallVector< SmallVector< Value > > &&newValues)
Replace the results of the given operation with the given values and erase the operation.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
OperationConverter & opConverter
The operation converter to use for recursive legalization.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.