10#include "mlir/Config/mlir-config.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/FormatVariadic.h"
25#include "llvm/Support/SaveAndRestore.h"
26#include "llvm/Support/ScopedPrinter.h"
33#define DEBUG_TYPE "dialect-conversion"
36template <
typename... Args>
37static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
40 os.startLine() <<
"} -> SUCCESS";
42 os.getOStream() <<
" : "
43 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
44 os.getOStream() <<
"\n";
49template <
typename... Args>
50static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
53 os.startLine() <<
"} -> FAILURE : "
54 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
64 if (
OpResult inputRes = dyn_cast<OpResult>(value))
65 insertPt = ++inputRes.getOwner()->getIterator();
72 assert(!vals.empty() &&
"expected at least one value");
75 for (
Value v : vals.drop_front()) {
89 assert(dom &&
"unable to find valid insertion point");
97enum OpConversionMode {
123struct ValueVectorMapInfo {
126 static ::llvm::hash_code getHashValue(
const ValueVector &val) {
127 return ::llvm::hash_combine_range(val);
136struct ConversionValueMapping {
139 bool isMappedTo(
Value value)
const {
return mappedTo.contains(value); }
144 template <
typename T>
145 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
148 template <
typename OldVal,
typename NewVal>
149 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
150 map(OldVal &&oldVal, NewVal &&newVal) {
154 assert(next != oldVal &&
"inserting cyclic mapping");
155 auto it = mapping.find(next);
156 if (it == mapping.end())
161 mappedTo.insert_range(newVal);
163 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
167 template <
typename OldVal,
typename NewVal>
168 std::enable_if_t<!IsValueVector<OldVal>::value ||
169 !IsValueVector<NewVal>::value>
170 map(OldVal &&oldVal, NewVal &&newVal) {
171 if constexpr (IsValueVector<OldVal>{}) {
172 map(std::forward<OldVal>(oldVal),
ValueVector{newVal});
173 }
else if constexpr (IsValueVector<NewVal>{}) {
174 map(
ValueVector{oldVal}, std::forward<NewVal>(newVal));
185 void erase(
const ValueVector &value) { mapping.erase(value); }
205 assert(!values.empty() &&
"expected non-empty value vector");
206 Operation *op = values.front().getDefiningOp();
207 for (
Value v : llvm::drop_begin(values)) {
208 if (v.getDefiningOp() != op)
218 assert(!values.empty() &&
"expected non-empty value vector");
224 auto it = mapping.find(from);
225 if (it == mapping.end()) {
238struct RewriterState {
239 RewriterState(
unsigned numRewrites,
unsigned numIgnoredOperations,
240 unsigned numReplacedOps)
241 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
242 numReplacedOps(numReplacedOps) {}
245 unsigned numRewrites;
248 unsigned numIgnoredOperations;
251 unsigned numReplacedOps;
258static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
261static void notifyIRErased(RewriterBase::Listener *listener,
Block &
b) {
262 for (Operation &op :
b)
263 notifyIRErased(listener, op);
269static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
272 notifyIRErased(listener,
b);
302 UnresolvedMaterialization,
307 virtual ~IRRewrite() =
default;
310 virtual void rollback() = 0;
324 virtual void commit(RewriterBase &rewriter) {}
327 virtual void cleanup(RewriterBase &rewriter) {}
329 Kind getKind()
const {
return kind; }
331 static bool classof(
const IRRewrite *
rewrite) {
return true; }
334 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
335 : kind(kind), rewriterImpl(rewriterImpl) {}
337 const ConversionConfig &getConfig()
const;
340 ConversionPatternRewriterImpl &rewriterImpl;
344class BlockRewrite :
public IRRewrite {
347 Block *getBlock()
const {
return block; }
349 static bool classof(
const IRRewrite *
rewrite) {
350 return rewrite->getKind() >= Kind::CreateBlock &&
351 rewrite->getKind() <= Kind::BlockTypeConversion;
355 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
357 : IRRewrite(kind, rewriterImpl), block(block) {}
364class ValueRewrite :
public IRRewrite {
367 Value getValue()
const {
return value; }
369 static bool classof(
const IRRewrite *
rewrite) {
370 return rewrite->getKind() == Kind::ReplaceValue;
374 ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
376 : IRRewrite(kind, rewriterImpl), value(value) {}
385class CreateBlockRewrite :
public BlockRewrite {
387 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block)
388 : BlockRewrite(
Kind::CreateBlock, rewriterImpl, block) {}
390 static bool classof(
const IRRewrite *
rewrite) {
391 return rewrite->getKind() == Kind::CreateBlock;
394 void commit(RewriterBase &rewriter)
override {
400 void rollback()
override {
403 auto &blockOps = block->getOperations();
404 while (!blockOps.empty())
405 blockOps.remove(blockOps.begin());
406 block->dropAllUses();
407 if (block->getParent())
418class EraseBlockRewrite :
public BlockRewrite {
420 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block)
421 : BlockRewrite(
Kind::EraseBlock, rewriterImpl, block),
422 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
424 static bool classof(
const IRRewrite *
rewrite) {
425 return rewrite->getKind() == Kind::EraseBlock;
428 ~EraseBlockRewrite()
override {
430 "rewrite was neither rolled back nor committed/cleaned up");
433 void rollback()
override {
436 assert(block &&
"expected block");
441 blockList.insert(before, block);
445 void commit(RewriterBase &rewriter)
override {
446 assert(block &&
"expected block");
450 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
451 notifyIRErased(listener, *block);
454 void cleanup(RewriterBase &rewriter)
override {
456 for (
auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
458 assert(block->empty() &&
"expected empty block");
461 block->dropAllDefinedValueUses();
472 Block *insertBeforeBlock;
478class InlineBlockRewrite :
public BlockRewrite {
480 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block,
482 : BlockRewrite(
Kind::InlineBlock, rewriterImpl, block),
483 sourceBlock(sourceBlock),
484 firstInlinedInst(sourceBlock->empty() ?
nullptr
485 : &sourceBlock->front()),
486 lastInlinedInst(sourceBlock->empty() ?
nullptr : &sourceBlock->back()) {
492 assert(!getConfig().listener &&
493 "InlineBlockRewrite not supported if listener is attached");
496 static bool classof(
const IRRewrite *
rewrite) {
497 return rewrite->getKind() == Kind::InlineBlock;
500 void rollback()
override {
503 if (firstInlinedInst) {
504 assert(lastInlinedInst &&
"expected operation");
517 Operation *firstInlinedInst;
520 Operation *lastInlinedInst;
524class MoveBlockRewrite :
public BlockRewrite {
526 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block,
528 : BlockRewrite(
Kind::MoveBlock, rewriterImpl, block),
529 region(previousRegion),
530 insertBeforeBlock(previousIt == previousRegion->end() ?
nullptr
533 static bool classof(
const IRRewrite *
rewrite) {
534 return rewrite->getKind() == Kind::MoveBlock;
537 void commit(RewriterBase &rewriter)
override {
547 void rollback()
override {
560 Block *insertBeforeBlock;
564class BlockTypeConversionRewrite :
public BlockRewrite {
566 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
568 : BlockRewrite(
Kind::BlockTypeConversion, rewriterImpl, origBlock),
569 newBlock(newBlock) {}
571 static bool classof(
const IRRewrite *
rewrite) {
572 return rewrite->getKind() == Kind::BlockTypeConversion;
575 Block *getOrigBlock()
const {
return block; }
577 Block *getNewBlock()
const {
return newBlock; }
579 void commit(RewriterBase &rewriter)
override;
581 void rollback()
override;
591class ReplaceValueRewrite :
public ValueRewrite {
593 ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
594 const TypeConverter *converter)
595 : ValueRewrite(
Kind::ReplaceValue, rewriterImpl, value),
596 converter(converter) {}
598 static bool classof(
const IRRewrite *
rewrite) {
599 return rewrite->getKind() == Kind::ReplaceValue;
602 void commit(RewriterBase &rewriter)
override;
604 void rollback()
override;
608 const TypeConverter *converter;
612class OperationRewrite :
public IRRewrite {
615 Operation *getOperation()
const {
return op; }
617 static bool classof(
const IRRewrite *
rewrite) {
618 return rewrite->getKind() >= Kind::MoveOperation &&
619 rewrite->getKind() <= Kind::UnresolvedMaterialization;
623 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
625 : IRRewrite(kind, rewriterImpl), op(op) {}
632class MoveOperationRewrite :
public OperationRewrite {
634 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
635 Operation *op, OpBuilder::InsertPoint previous)
636 : OperationRewrite(
Kind::MoveOperation, rewriterImpl, op),
637 block(previous.getBlock()),
638 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
640 : &*previous.getPoint()) {}
642 static bool classof(
const IRRewrite *
rewrite) {
643 return rewrite->getKind() == Kind::MoveOperation;
646 void commit(RewriterBase &rewriter)
override {
652 op, OpBuilder::InsertPoint(block,
657 void rollback()
override {
670 Operation *insertBeforeOp;
675class ModifyOperationRewrite :
public OperationRewrite {
677 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
679 : OperationRewrite(
Kind::ModifyOperation, rewriterImpl, op),
680 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
681 operands(op->operand_begin(), op->operand_end()),
682 successors(op->successor_begin(), op->successor_end()) {
685 propertiesStorage = operator new(op->getPropertiesStorageSize());
686 OpaqueProperties propCopy(propertiesStorage);
687 name.initOpProperties(propCopy, prop);
691 static bool classof(
const IRRewrite *
rewrite) {
692 return rewrite->getKind() == Kind::ModifyOperation;
695 ~ModifyOperationRewrite()
override {
696 assert(!propertiesStorage &&
697 "rewrite was neither committed nor rolled back");
700 void commit(RewriterBase &rewriter)
override {
703 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
706 if (propertiesStorage) {
707 OpaqueProperties propCopy(propertiesStorage);
711 operator delete(propertiesStorage);
712 propertiesStorage =
nullptr;
716 void rollback()
override {
720 for (
const auto &it : llvm::enumerate(successors))
722 if (propertiesStorage) {
723 OpaqueProperties propCopy(propertiesStorage);
726 operator delete(propertiesStorage);
727 propertiesStorage =
nullptr;
734 DictionaryAttr attrs;
735 SmallVector<Value, 8> operands;
736 SmallVector<Block *, 2> successors;
737 void *propertiesStorage =
nullptr;
744class ReplaceOperationRewrite :
public OperationRewrite {
746 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
747 Operation *op,
const TypeConverter *converter)
748 : OperationRewrite(
Kind::ReplaceOperation, rewriterImpl, op),
749 converter(converter) {}
751 static bool classof(
const IRRewrite *
rewrite) {
752 return rewrite->getKind() == Kind::ReplaceOperation;
755 void commit(RewriterBase &rewriter)
override;
757 void rollback()
override;
759 void cleanup(RewriterBase &rewriter)
override;
764 const TypeConverter *converter;
767class CreateOperationRewrite :
public OperationRewrite {
769 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
771 : OperationRewrite(
Kind::CreateOperation, rewriterImpl, op) {}
773 static bool classof(
const IRRewrite *
rewrite) {
774 return rewrite->getKind() == Kind::CreateOperation;
777 void commit(RewriterBase &rewriter)
override {
783 void rollback()
override;
787enum MaterializationKind {
798class UnresolvedMaterializationInfo {
800 UnresolvedMaterializationInfo() =
default;
801 UnresolvedMaterializationInfo(
const TypeConverter *converter,
802 MaterializationKind kind, Type originalType)
803 : converterAndKind(converter, kind), originalType(originalType) {}
806 const TypeConverter *getConverter()
const {
807 return converterAndKind.getPointer();
811 MaterializationKind getMaterializationKind()
const {
812 return converterAndKind.getInt();
816 Type getOriginalType()
const {
return originalType; }
821 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
832class UnresolvedMaterializationRewrite :
public OperationRewrite {
834 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
835 UnrealizedConversionCastOp op,
837 : OperationRewrite(
Kind::UnresolvedMaterialization, rewriterImpl, op),
838 mappedValues(std::move(mappedValues)) {}
840 static bool classof(
const IRRewrite *
rewrite) {
841 return rewrite->getKind() == Kind::UnresolvedMaterialization;
844 void rollback()
override;
846 UnrealizedConversionCastOp getOperation()
const {
847 return cast<UnrealizedConversionCastOp>(op);
857#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
860template <
typename RewriteTy,
typename R>
861static bool hasRewrite(R &&rewrites, Operation *op) {
862 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
863 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
864 return rewriteTy && rewriteTy->getOperation() == op;
870template <
typename RewriteTy,
typename R>
871static bool hasRewrite(R &&rewrites,
Block *block) {
872 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
873 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
874 return rewriteTy && rewriteTy->getBlock() == block;
886 const ConversionConfig &
config,
896 RewriterState getCurrentState();
900 void applyRewrites();
905 void resetState(RewriterState state, StringRef patternName =
"");
909 template <
typename RewriteTy,
typename... Args>
911 assert(
config.allowPatternRollback &&
"appending rewrites is not allowed");
913 std::make_unique<RewriteTy>(*
this, std::forward<Args>(args)...));
919 void undoRewrites(
unsigned numRewritesToKeep = 0, StringRef patternName =
"");
925 LogicalResult remapValues(StringRef valueDiagTag,
926 std::optional<Location> inputLoc,
ValueRange values,
943 bool skipPureTypeConversions =
false)
const;
957 TypeConverter::SignatureConversion *entryConversion);
965 Block *applySignatureConversion(
967 TypeConverter::SignatureConversion &signatureConversion);
987 void eraseBlock(
Block *block);
1025 Value findOrBuildReplacementValue(
Value value,
1033 void notifyOperationInserted(
Operation *op,
1037 void notifyBlockInserted(
Block *block,
Region *previous,
1056 std::function<
void(
Operation *)> opErasedCallback =
nullptr)
1058 opErasedCallback(std::move(opErasedCallback)) {}
1072 assert(block->empty() &&
"expected empty block");
1073 block->dropAllDefinedValueUses();
1081 if (opErasedCallback)
1082 opErasedCallback(op);
1190const ConversionConfig &IRRewrite::getConfig()
const {
1191 return rewriterImpl.
config;
1194void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1198 if (
auto *listener =
1199 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
1200 for (Operation *op : getNewBlock()->getUsers())
1204void BlockTypeConversionRewrite::rollback() {
1205 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1212 if (isa<BlockArgument>(repl)) {
1252 result &= functor(operand);
1257void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1264void ReplaceValueRewrite::rollback() {
1265 rewriterImpl.
mapping.erase({value});
1271void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1273 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener());
1276 SmallVector<Value> replacements =
1278 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1286 for (
auto [
result, newValue] :
1287 llvm::zip_equal(op->
getResults(), replacements))
1293 if (getConfig().unlegalizedOps)
1294 getConfig().unlegalizedOps->erase(op);
1298 notifyIRErased(listener, *op);
1305void ReplaceOperationRewrite::rollback() {
1310void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1314void CreateOperationRewrite::rollback() {
1316 while (!region.getBlocks().empty())
1317 region.getBlocks().remove(region.getBlocks().begin());
1323void UnresolvedMaterializationRewrite::rollback() {
1324 if (!mappedValues.empty())
1325 rewriterImpl.
mapping.erase(mappedValues);
1336 for (
size_t i = 0; i <
rewrites.size(); ++i)
1342 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1343 unresolvedMaterializations.erase(castOp);
1346 rewrite->cleanup(eraseRewriter);
1354 Value from,
TypeRange desiredTypes,
bool skipPureTypeConversions)
const {
1357 assert(!values.empty() &&
"expected non-empty value vector");
1361 if (
config.allowPatternRollback)
1362 return mapping.lookup(values);
1369 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1374 if (castOp.getOutputs() != values)
1376 return castOp.getInputs();
1385 for (
Value v : values) {
1388 llvm::append_range(next, r);
1393 if (next != values) {
1422 if (skipPureTypeConversions) {
1425 match &= !pureConversion;
1428 if (!pureConversion)
1429 lastNonMaterialization = current;
1432 desiredValue = current;
1438 current = std::move(next);
1443 if (!desiredTypes.empty())
1444 return desiredValue;
1445 if (skipPureTypeConversions)
1446 return lastNonMaterialization;
1465 StringRef patternName) {
1470 while (
ignoredOps.size() != state.numIgnoredOperations)
1473 while (
replacedOps.size() != state.numReplacedOps)
1478 StringRef patternName) {
1480 llvm::reverse(llvm::drop_begin(
rewrites, numRewritesToKeep)))
1482 rewrites.resize(numRewritesToKeep);
1486 StringRef valueDiagTag, std::optional<Location> inputLoc,
ValueRange values,
1488 remapped.reserve(llvm::size(values));
1490 for (
const auto &it : llvm::enumerate(values)) {
1491 Value operand = it.value();
1510 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1511 << it.index() <<
", type was " << origType;
1516 if (legalTypes.empty()) {
1517 remapped.push_back({});
1526 remapped.push_back(std::move(repl));
1535 repl, repl, legalTypes,
1537 remapped.push_back(castValues);
1558 TypeConverter::SignatureConversion *entryConversion) {
1560 if (region->
empty())
1565 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1567 std::optional<TypeConverter::SignatureConversion> conversion =
1568 converter.convertBlockSignature(&block);
1577 if (entryConversion)
1580 std::optional<TypeConverter::SignatureConversion> conversion =
1581 converter.convertBlockSignature(®ion->
front());
1589 TypeConverter::SignatureConversion &signatureConversion) {
1590#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1592 if (hasRewrite<BlockTypeConversionRewrite>(
rewrites, block))
1593 llvm::report_fatal_error(
"block was already converted");
1600 auto convertedTypes = signatureConversion.getConvertedTypes();
1607 for (
unsigned i = 0; i < origArgCount; ++i) {
1608 auto inputMap = signatureConversion.getInputMapping(i);
1609 if (!inputMap || inputMap->replacedWithValues())
1612 for (
unsigned j = 0;
j < inputMap->size; ++
j)
1613 newLocs[inputMap->inputNo +
j] = origLoc;
1620 convertedTypes, newLocs);
1628 bool fastPath = !
config.listener;
1630 if (
config.allowPatternRollback)
1634 while (!block->
empty())
1641 for (
unsigned i = 0; i != origArgCount; ++i) {
1645 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1646 signatureConversion.getInputMapping(i);
1654 MaterializationKind::Source,
1658 origArgType,
Type(), converter,
1665 if (inputMap->replacedWithValues()) {
1667 assert(inputMap->size == 0 &&
1668 "invalid to provide a replacement value when the argument isn't "
1676 newBlock->
getArguments().slice(inputMap->inputNo, inputMap->size);
1680 if (
config.allowPatternRollback)
1701 assert((!originalType || kind == MaterializationKind::Target) &&
1702 "original type is valid only for target materializations");
1703 assert(
TypeRange(inputs) != outputTypes &&
1704 "materialization is not necessary");
1708 OpBuilder builder(outputTypes.front().getContext());
1710 UnrealizedConversionCastOp convertOp =
1711 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1712 if (
config.attachDebugMaterializationKind) {
1714 kind == MaterializationKind::Source ?
"source" :
"target";
1715 convertOp->setAttr(
"__kind__", builder.
getStringAttr(kindStr));
1722 UnresolvedMaterializationInfo(converter, kind, originalType);
1723 if (
config.allowPatternRollback) {
1724 if (!valuesToMap.empty())
1725 mapping.map(valuesToMap, convertOp.getResults());
1727 std::move(valuesToMap));
1731 return convertOp.getResults();
1736 assert(
config.allowPatternRollback &&
1737 "this code path is valid only in rollback mode");
1744 return repl.front();
1751 [&](
Operation *op) { return replacedOps.contains(op); }) &&
1776 MaterializationKind::Source, ip, value.
getLoc(),
1792 bool wasDetached = !previous.
isSet();
1794 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"' (" << op
1797 logger.getOStream() <<
" (was detached)";
1798 logger.getOStream() <<
"\n";
1804 "attempting to insert into a block within a replaced/erased op");
1808 config.listener->notifyOperationInserted(op, previous);
1817 if (
config.allowPatternRollback) {
1831 if (
config.allowPatternRollback)
1841 assert(!
impl.config.allowPatternRollback &&
1842 "this code path is valid only in 'no rollback' mode");
1844 for (
auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1847 repls.push_back(
Value());
1854 Value srcMat =
impl.buildUnresolvedMaterialization(
1859 repls.push_back(srcMat);
1865 repls.push_back(to[0]);
1874 Value srcMat =
impl.buildUnresolvedMaterialization(
1877 Type(), converter)[0];
1878 repls.push_back(srcMat);
1887 "incorrect number of replacement values");
1889 logger.startLine() <<
"** Replace : '" << op->
getName() <<
"'(" << op
1897 for (
auto [
result, repls] :
1898 llvm::zip_equal(op->
getResults(), newValues)) {
1900 auto logProlog = [&, repls = repls]() {
1901 logger.startLine() <<
" Note: Replacing op result of type "
1902 << resultType <<
" with value(s) of type (";
1903 llvm::interleaveComma(repls,
logger.getOStream(), [&](
Value v) {
1904 logger.getOStream() << v.getType();
1906 logger.getOStream() <<
")";
1912 logger.getOStream() <<
", but the type converter failed to legalize "
1913 "the original type.\n";
1918 logger.getOStream() <<
", but the legalized type(s) is/are (";
1919 llvm::interleaveComma(convertedTypes,
logger.getOStream(),
1920 [&](
Type t) { logger.getOStream() << t; });
1921 logger.getOStream() <<
")\n";
1927 if (!
config.allowPatternRollback) {
1936 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1942 if (
config.unlegalizedOps)
1943 config.unlegalizedOps->erase(op);
1951 assert(!
ignoredOps.contains(op) &&
"operation was already replaced");
1955 "attempting to replace a value that was already replaced");
1960 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1965 "attempting to replace/erase an unresolved materialization");
1981 logger.startLine() <<
"** Replace Value : '" << from <<
"'";
1982 if (
auto blockArg = dyn_cast<BlockArgument>(from)) {
1984 logger.getOStream() <<
" (in region of '" << parentOp->getName()
1985 <<
"' (" << parentOp <<
")";
1987 logger.getOStream() <<
" (unlinked block)";
1991 logger.getOStream() <<
", conditional replacement";
1995 if (!
config.allowPatternRollback) {
2000 Value repl = repls.front();
2017 "attempting to replace a value that was already replaced");
2019 "attempting to replace a op result that was already replaced");
2024 llvm::report_fatal_error(
2025 "conditional value replacement is not supported in rollback mode");
2031 if (!
config.allowPatternRollback) {
2038 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
2044 if (
config.unlegalizedOps)
2045 config.unlegalizedOps->erase(op);
2054 "attempting to erase a block within a replaced/erased op");
2070 bool wasDetached = !previous;
2076 logger.startLine() <<
"** Insert Block into : '" << parent->
getName()
2077 <<
"' (" << parent <<
")";
2080 <<
"** Insert Block into detached Region (nullptr parent op)";
2083 logger.getOStream() <<
" (was detached)";
2084 logger.getOStream() <<
"\n";
2090 "attempting to insert into a region within a replaced/erased op");
2095 config.listener->notifyBlockInserted(block, previous, previousIt);
2099 if (
config.allowPatternRollback) {
2113 if (
config.allowPatternRollback)
2127 reasonCallback(
diag);
2128 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
2129 if (
config.notifyCallback)
2138ConversionPatternRewriter::ConversionPatternRewriter(
2142 *this,
config, opConverter)) {
2143 setListener(
impl.get());
2146ConversionPatternRewriter::~ConversionPatternRewriter() =
default;
2148const ConversionConfig &ConversionPatternRewriter::getConfig()
const {
2149 return impl->config;
2153 assert(op && newOp &&
"expected non-null op");
2157void ConversionPatternRewriter::replaceOp(Operation *op,
ValueRange newValues) {
2159 "incorrect # of replacement values");
2163 if (getInsertionPoint() == op->getIterator())
2166 SmallVector<SmallVector<Value>> newVals =
2167 llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2168 return v ? SmallVector<Value>{v} : SmallVector<Value>();
2170 impl->replaceOp(op, std::move(newVals));
2173void ConversionPatternRewriter::replaceOpWithMultiple(
2174 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2176 "incorrect # of replacement values");
2180 if (getInsertionPoint() == op->getIterator())
2183 impl->replaceOp(op, std::move(newValues));
2186void ConversionPatternRewriter::eraseOp(Operation *op) {
2188 impl->logger.startLine()
2189 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
2194 if (getInsertionPoint() == op->getIterator())
2197 SmallVector<SmallVector<Value>> nullRepls(op->
getNumResults(), {});
2198 impl->replaceOp(op, std::move(nullRepls));
2201void ConversionPatternRewriter::eraseBlock(
Block *block) {
2202 impl->eraseBlock(block);
2205Block *ConversionPatternRewriter::applySignatureConversion(
2206 Block *block, TypeConverter::SignatureConversion &conversion,
2207 const TypeConverter *converter) {
2208 assert(!impl->wasOpReplaced(block->
getParentOp()) &&
2209 "attempting to apply a signature conversion to a block within a "
2210 "replaced/erased op");
2211 return impl->applySignatureConversion(block, converter, conversion);
2214FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
2215 Region *region,
const TypeConverter &converter,
2216 TypeConverter::SignatureConversion *entryConversion) {
2217 assert(!impl->wasOpReplaced(region->
getParentOp()) &&
2218 "attempting to apply a signature conversion to a block within a "
2219 "replaced/erased op");
2220 return impl->convertRegionTypes(region, converter, entryConversion);
2223void ConversionPatternRewriter::replaceAllUsesWith(Value from,
ValueRange to) {
2224 impl->replaceValueUses(from, to, impl->currentTypeConverter);
2227void ConversionPatternRewriter::replaceUsesWithIf(
2229 bool *allUsesReplaced) {
2230 assert(!allUsesReplaced &&
2231 "allUsesReplaced is not supported in a dialect conversion");
2232 impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
2235Value ConversionPatternRewriter::getRemappedValue(Value key) {
2236 SmallVector<ValueVector> remappedValues;
2237 if (
failed(impl->remapValues(
"value", std::nullopt, key,
2240 assert(remappedValues.front().size() == 1 &&
"1:N conversion not supported");
2241 return remappedValues.front().front();
2245ConversionPatternRewriter::getRemappedValues(
ValueRange keys,
2246 SmallVectorImpl<Value> &results) {
2249 SmallVector<ValueVector> remapped;
2250 if (
failed(impl->remapValues(
"value", std::nullopt, keys,
2253 for (
const auto &values : remapped) {
2254 assert(values.size() == 1 &&
"1:N conversion not supported");
2255 results.push_back(values.front());
2260LogicalResult ConversionPatternRewriter::legalize(Region *r) {
2268 SmallVector<Operation *> ops;
2270 for (Operation &op :
b)
2275 if (
const TypeConverter *converter = impl->currentTypeConverter) {
2276 std::optional<TypeConverter::SignatureConversion> conversion =
2277 converter->convertBlockSignature(&r->front());
2280 applySignatureConversion(&r->front(), *conversion, converter);
2284 for (Operation *op : ops)
2285 if (
failed(legalize(op)))
2291void ConversionPatternRewriter::inlineBlockBefore(
Block *source,
Block *dest,
2296 "incorrect # of argument replacement values");
2297 assert(!impl->wasOpReplaced(source->
getParentOp()) &&
2298 "attempting to inline a block from a replaced/erased op");
2299 assert(!impl->wasOpReplaced(dest->
getParentOp()) &&
2300 "attempting to inline a block into a replaced/erased op");
2301 auto opIgnored = [&](Operation *op) {
return impl->isOpIgnored(op); };
2304 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
2305 "expected 'source' to have no predecessors");
2314 bool fastPath = !getConfig().listener;
2316 if (fastPath && impl->config.allowPatternRollback)
2317 impl->inlineBlockBefore(source, dest, before);
2320 for (
auto it : llvm::zip(source->
getArguments(), argValues))
2321 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2328 while (!source->
empty())
2329 moveOpBefore(&source->
front(), dest, before);
2334 if (getInsertionBlock() == source)
2335 setInsertionPoint(dest, getInsertionPoint());
2341void ConversionPatternRewriter::startOpModification(Operation *op) {
2342 if (!impl->config.allowPatternRollback) {
2347 assert(!impl->wasOpReplaced(op) &&
2348 "attempting to modify a replaced/erased op");
2350 impl->pendingRootUpdates.insert(op);
2352 impl->appendRewrite<ModifyOperationRewrite>(op);
2355void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
2356 impl->patternModifiedOps.insert(op);
2357 if (!impl->config.allowPatternRollback) {
2359 if (getConfig().listener)
2360 getConfig().listener->notifyOperationModified(op);
2367 assert(!impl->wasOpReplaced(op) &&
2368 "attempting to modify a replaced/erased op");
2369 assert(impl->pendingRootUpdates.erase(op) &&
2370 "operation did not have a pending in-place update");
2374void ConversionPatternRewriter::cancelOpModification(Operation *op) {
2375 if (!impl->config.allowPatternRollback) {
2380 assert(impl->pendingRootUpdates.erase(op) &&
2381 "operation did not have a pending in-place update");
2384 auto it = llvm::find_if(
2385 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &
rewrite) {
2386 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2387 return modifyRewrite && modifyRewrite->getOperation() == op;
2389 assert(it != impl->rewrites.rend() &&
"no root update started on op");
2391 int updateIdx = std::prev(impl->rewrites.rend()) - it;
2392 impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2395detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
2403FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
2404 ArrayRef<ValueRange> operands)
const {
2405 SmallVector<Value> oneToOneOperands;
2406 oneToOneOperands.reserve(operands.size());
2408 if (operand.size() != 1)
2411 oneToOneOperands.push_back(operand.front());
2413 return std::move(oneToOneOperands);
2417ConversionPattern::matchAndRewrite(Operation *op,
2418 PatternRewriter &rewriter)
const {
2419 auto &dialectRewriter =
static_cast<ConversionPatternRewriter &
>(rewriter);
2420 auto &rewriterImpl = dialectRewriter.getImpl();
2424 getTypeConverter());
2427 SmallVector<ValueVector> remapped;
2432 SmallVector<ValueRange> remappedAsRange =
2433 llvm::to_vector_of<ValueRange>(remapped);
2434 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2443using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2446class OperationLegalizer {
2448 using LegalizationAction = ConversionTarget::LegalizationAction;
2450 OperationLegalizer(ConversionPatternRewriter &rewriter,
2451 const ConversionTarget &targetInfo,
2452 const FrozenRewritePatternSet &
patterns);
2455 bool isIllegal(Operation *op)
const;
2459 LogicalResult legalize(Operation *op);
2462 const ConversionTarget &getTarget() {
return target; }
2466 LogicalResult legalizeWithFold(Operation *op);
2470 LogicalResult legalizeWithPattern(Operation *op);
2474 bool canApplyPattern(Operation *op,
const Pattern &pattern);
2478 legalizePatternResult(Operation *op,
const Pattern &pattern,
2479 const RewriterState &curState,
2496 void buildLegalizationGraph(
2497 LegalizationPatterns &anyOpLegalizerPatterns,
2508 void computeLegalizationGraphBenefit(
2509 LegalizationPatterns &anyOpLegalizerPatterns,
2514 unsigned computeOpLegalizationDepth(
2521 unsigned applyCostModelToPatterns(
2527 SmallPtrSet<const Pattern *, 8> appliedPatterns;
2530 ConversionPatternRewriter &rewriter;
2533 const ConversionTarget &
target;
2536 PatternApplicator applicator;
2540OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2541 const ConversionTarget &targetInfo,
2542 const FrozenRewritePatternSet &
patterns)
2547 LegalizationPatterns anyOpLegalizerPatterns;
2549 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2550 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2553bool OperationLegalizer::isIllegal(Operation *op)
const {
2554 return target.isIllegal(op);
2557LogicalResult OperationLegalizer::legalize(Operation *op) {
2559 const char *logLineComment =
2560 "//===-------------------------------------------===//\n";
2562 auto &logger = rewriter.getImpl().logger;
2566 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2569 logger.getOStream() <<
"\n";
2570 logger.startLine() << logLineComment;
2571 logger.startLine() <<
"Legalizing operation : ";
2576 logger.getOStream() <<
"'" << op->
getName() <<
"' ";
2577 logger.getOStream() <<
"(" << op <<
") {\n";
2582 logger.startLine() << OpWithFlags(op,
2583 OpPrintingFlags().printGenericOpForm())
2590 logSuccess(logger,
"operation marked 'ignored' during conversion");
2591 logger.startLine() << logLineComment;
2597 if (
auto legalityInfo =
target.isLegal(op)) {
2600 logger,
"operation marked legal by the target{0}",
2601 legalityInfo->isRecursivelyLegal
2602 ?
"; NOTE: operation is recursively legal; skipping internals"
2604 logger.startLine() << logLineComment;
2609 if (legalityInfo->isRecursivelyLegal) {
2610 op->
walk([&](Operation *nested) {
2612 rewriter.getImpl().ignoredOps.
insert(nested);
2621 const ConversionConfig &
config = rewriter.getConfig();
2622 if (
config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2623 if (succeeded(legalizeWithFold(op))) {
2626 logger.startLine() << logLineComment;
2633 if (succeeded(legalizeWithPattern(op))) {
2636 logger.startLine() << logLineComment;
2643 if (
config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2644 if (succeeded(legalizeWithFold(op))) {
2647 logger.startLine() << logLineComment;
2654 logFailure(logger,
"no matched legalization pattern");
2655 logger.startLine() << logLineComment;
2662template <
typename T>
2664 T
result = std::move(obj);
2669LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2670 auto &rewriterImpl = rewriter.getImpl();
2672 rewriterImpl.
logger.startLine() <<
"* Fold {\n";
2673 rewriterImpl.
logger.indent();
2678 auto cleanup = llvm::make_scope_exit([&]() {
2688 SmallVector<Value, 2> replacementValues;
2689 SmallVector<Operation *, 2> newOps;
2692 if (
failed(rewriter.
tryFold(op, replacementValues, &newOps))) {
2701 if (replacementValues.empty())
2702 return legalize(op);
2705 rewriter.
replaceOp(op, replacementValues);
2708 for (Operation *newOp : newOps) {
2709 if (
failed(legalize(newOp))) {
2711 "failed to legalize generated constant '{0}'",
2713 if (!rewriter.getConfig().allowPatternRollback) {
2715 llvm::report_fatal_error(
2717 "' folder rollback of IR modifications requested");
2735 auto newOpNames = llvm::map_range(
2737 auto modifiedOpNames = llvm::map_range(
2739 llvm::report_fatal_error(
"pattern '" + pattern.
getDebugName() +
2740 "' produced IR that could not be legalized. " +
2741 "new ops: {" + llvm::join(newOpNames,
", ") +
"}, " +
2743 llvm::join(modifiedOpNames,
", ") +
"}");
2746LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2747 auto &rewriterImpl = rewriter.getImpl();
2748 const ConversionConfig &
config = rewriter.getConfig();
2750#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2752 std::optional<OperationFingerPrint> topLevelFingerPrint;
2753 if (!rewriterImpl.
config.allowPatternRollback) {
2760 topLevelFingerPrint = OperationFingerPrint(checkOp);
2766 rewriterImpl.
logger.startLine()
2767 <<
"WARNING: Multi-threadeding is enabled. Some dialect "
2768 "conversion expensive checks are skipped in multithreading "
2777 auto canApply = [&](
const Pattern &pattern) {
2778 bool canApply = canApplyPattern(op, pattern);
2779 if (canApply &&
config.listener)
2780 config.listener->notifyPatternBegin(pattern, op);
2786 auto onFailure = [&](
const Pattern &pattern) {
2788 if (!rewriterImpl.
config.allowPatternRollback) {
2795#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2797 if (checkOp && topLevelFingerPrint) {
2798 OperationFingerPrint fingerPrintAfterPattern(checkOp);
2799 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2800 llvm::report_fatal_error(
"pattern '" + pattern.getDebugName() +
2801 "' returned failure but IR did change");
2809 if (rewriterImpl.
config.notifyCallback) {
2811 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
2818 config.listener->notifyPatternEnd(pattern, failure());
2819 rewriterImpl.
resetState(curState, pattern.getDebugName());
2820 appliedPatterns.erase(&pattern);
2825 auto onSuccess = [&](
const Pattern &pattern) {
2827 if (!rewriterImpl.
config.allowPatternRollback) {
2841 legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
2842 appliedPatterns.erase(&pattern);
2844 if (!rewriterImpl.
config.allowPatternRollback)
2846 rewriterImpl.
resetState(curState, pattern.getDebugName());
2854 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2858bool OperationLegalizer::canApplyPattern(Operation *op,
2859 const Pattern &pattern) {
2861 auto &os = rewriter.getImpl().logger;
2862 os.getOStream() <<
"\n";
2863 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
2865 os.getOStream() <<
")' {\n";
2872 !appliedPatterns.insert(&pattern).second) {
2874 logFailure(rewriter.getImpl().logger,
"pattern was already applied"));
2880LogicalResult OperationLegalizer::legalizePatternResult(
2881 Operation *op,
const Pattern &pattern,
const RewriterState &curState,
2884 [[maybe_unused]]
auto &impl = rewriter.getImpl();
2885 assert(impl.pendingRootUpdates.empty() &&
"dangling root updates");
2887#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2888 if (impl.config.allowPatternRollback) {
2890 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2891 auto replacedRoot = [&] {
2892 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2894 auto updatedRootInPlace = [&] {
2895 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2897 if (!replacedRoot() && !updatedRootInPlace())
2898 llvm::report_fatal_error(
"expected pattern to replace the root operation "
2899 "or modify it in place");
2904 if (
failed(legalizePatternRootUpdates(modifiedOps)) ||
2905 failed(legalizePatternCreatedOperations(newOps))) {
2909 LLVM_DEBUG(
logSuccess(impl.logger,
"pattern applied successfully"));
2913LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2915 for (Operation *op : newOps) {
2916 if (
failed(legalize(op))) {
2917 LLVM_DEBUG(
logFailure(rewriter.getImpl().logger,
2918 "failed to legalize generated operation '{0}'({1})",
2926LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2928 for (Operation *op : modifiedOps) {
2929 if (
failed(legalize(op))) {
2932 "failed to legalize operation updated in-place '{0}'",
2944void OperationLegalizer::buildLegalizationGraph(
2945 LegalizationPatterns &anyOpLegalizerPatterns,
2956 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2957 std::optional<OperationName> root = pattern.
getRootKind();
2963 anyOpLegalizerPatterns.push_back(&pattern);
2968 if (
target.getOpAction(*root) == LegalizationAction::Legal)
2973 invalidPatterns[*root].insert(&pattern);
2975 parentOps[op].insert(*root);
2978 patternWorklist.insert(&pattern);
2986 if (!anyOpLegalizerPatterns.empty()) {
2987 for (
const Pattern *pattern : patternWorklist)
2988 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2992 while (!patternWorklist.empty()) {
2993 auto *pattern = patternWorklist.pop_back_val();
2997 std::optional<LegalizationAction> action = target.getOpAction(op);
2998 return !legalizerPatterns.count(op) &&
2999 (!action || action == LegalizationAction::Illegal);
3005 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
3006 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
3010 for (
auto op : parentOps[*pattern->
getRootKind()])
3011 patternWorklist.set_union(invalidPatterns[op]);
3015void OperationLegalizer::computeLegalizationGraphBenefit(
3016 LegalizationPatterns &anyOpLegalizerPatterns,
3022 for (
auto &opIt : legalizerPatterns)
3023 if (!minOpPatternDepth.count(opIt.first))
3024 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
3030 if (!anyOpLegalizerPatterns.empty())
3031 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
3037 applicator.applyCostModel([&](
const Pattern &pattern) {
3038 ArrayRef<const Pattern *> orderedPatternList;
3039 if (std::optional<OperationName> rootName = pattern.
getRootKind())
3040 orderedPatternList = legalizerPatterns[*rootName];
3042 orderedPatternList = anyOpLegalizerPatterns;
3045 auto *it = llvm::find(orderedPatternList, &pattern);
3046 if (it == orderedPatternList.end())
3050 return PatternBenefit(std::distance(it, orderedPatternList.end()));
3054unsigned OperationLegalizer::computeOpLegalizationDepth(
3058 auto depthIt = minOpPatternDepth.find(op);
3059 if (depthIt != minOpPatternDepth.end())
3060 return depthIt->second;
3064 auto opPatternsIt = legalizerPatterns.find(op);
3065 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3070 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
3074 unsigned minDepth = applyCostModelToPatterns(
3075 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3076 minOpPatternDepth[op] = minDepth;
3080unsigned OperationLegalizer::applyCostModelToPatterns(
3084 unsigned minDepth = std::numeric_limits<unsigned>::max();
3087 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3088 patternsByDepth.reserve(
patterns.size());
3089 for (
const Pattern *pattern :
patterns) {
3092 unsigned generatedOpDepth = computeOpLegalizationDepth(
3093 generatedOp, minOpPatternDepth, legalizerPatterns);
3094 depth = std::max(depth, generatedOpDepth + 1);
3096 patternsByDepth.emplace_back(pattern, depth);
3099 minDepth = std::min(minDepth, depth);
3104 if (patternsByDepth.size() == 1)
3108 llvm::stable_sort(patternsByDepth,
3109 [](
const std::pair<const Pattern *, unsigned> &
lhs,
3110 const std::pair<const Pattern *, unsigned> &
rhs) {
3113 if (
lhs.second !=
rhs.second)
3114 return lhs.second <
rhs.second;
3117 auto lhsBenefit =
lhs.first->getBenefit();
3118 auto rhsBenefit =
rhs.first->getBenefit();
3119 return lhsBenefit > rhsBenefit;
3124 for (
auto &patternIt : patternsByDepth)
3125 patterns.push_back(patternIt.first);
3139template <
typename RangeT>
3142 function_ref<
bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3151 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3152 if (castOp.getInputs().empty())
3155 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3158 if (inputCastOp.getOutputs() != castOp.getInputs())
3164 while (!worklist.empty()) {
3165 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3169 UnrealizedConversionCastOp nextCast = castOp;
3171 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3172 if (llvm::any_of(nextCast.getInputs(), [&](
Value v) {
3173 return v.getDefiningOp() == castOp;
3181 castOp.replaceAllUsesWith(nextCast.getInputs());
3184 nextCast = getInputCast(nextCast);
3194 auto markOpLive = [&](
Operation *rootOp) {
3196 worklist.push_back(rootOp);
3197 while (!worklist.empty()) {
3198 Operation *op = worklist.pop_back_val();
3199 if (liveOps.insert(op).second) {
3202 if (
auto castOp = v.
getDefiningOp<UnrealizedConversionCastOp>())
3203 if (isCastOpOfInterestFn(castOp))
3204 worklist.push_back(castOp);
3210 for (UnrealizedConversionCastOp op : castOps) {
3213 if (liveOps.contains(op.getOperation()))
3217 if (llvm::any_of(op->getUsers(), [&](
Operation *user) {
3218 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3219 return !castOp || !isCastOpOfInterestFn(castOp);
3225 for (UnrealizedConversionCastOp op : castOps) {
3226 if (liveOps.contains(op)) {
3228 if (remainingCastOps)
3229 remainingCastOps->push_back(op);
3240 ArrayRef<UnrealizedConversionCastOp> castOps,
3241 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3243 DenseSet<UnrealizedConversionCastOp> castOpSet;
3244 for (UnrealizedConversionCastOp op : castOps)
3245 castOpSet.insert(op);
3250 const DenseSet<UnrealizedConversionCastOp> &castOps,
3251 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3253 llvm::make_range(castOps.begin(), castOps.end()),
3254 [&](UnrealizedConversionCastOp castOp) {
3255 return castOps.contains(castOp);
3267 [&](UnrealizedConversionCastOp castOp) {
3268 return castOps.contains(castOp);
3285 const ConversionConfig &
config,
3286 OpConversionMode mode)
3298 LogicalResult convert(
Operation *op,
bool isRecursiveLegalization =
false);
3302 ConversionPatternRewriter rewriter;
3305 OperationLegalizer opLegalizer;
3308 OpConversionMode mode;
3312LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
3313 return impl->opConverter.convert(op,
true);
3317 bool isRecursiveLegalization) {
3318 const ConversionConfig &
config = rewriter.getConfig();
3321 if (failed(opLegalizer.legalize(op))) {
3324 if (mode == OpConversionMode::Full) {
3325 if (!isRecursiveLegalization)
3333 if (mode == OpConversionMode::Partial) {
3334 if (opLegalizer.isIllegal(op)) {
3335 if (!isRecursiveLegalization)
3337 <<
"' that was explicitly marked illegal";
3340 if (
config.unlegalizedOps && !isRecursiveLegalization)
3341 config.unlegalizedOps->insert(op);
3343 }
else if (mode == OpConversionMode::Analysis) {
3347 if (
config.legalizableOps && !isRecursiveLegalization)
3348 config.legalizableOps->insert(op);
3355 UnrealizedConversionCastOp op,
3356 const UnresolvedMaterializationInfo &info) {
3357 assert(!op.use_empty() &&
3358 "expected that dead materializations have already been DCE'd");
3365 switch (info.getMaterializationKind()) {
3366 case MaterializationKind::Target:
3367 newMaterialization = converter->materializeTargetConversion(
3368 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3369 info.getOriginalType());
3371 case MaterializationKind::Source:
3372 assert(op->getNumResults() == 1 &&
"expected single result");
3373 Value sourceMat = converter->materializeSourceConversion(
3374 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3376 newMaterialization.push_back(sourceMat);
3379 if (!newMaterialization.empty()) {
3381 ValueRange newMaterializationRange(newMaterialization);
3382 assert(
TypeRange(newMaterializationRange) == op.getResultTypes() &&
3383 "materialization callback produced value of incorrect type");
3385 rewriter.
replaceOp(op, newMaterialization);
3391 <<
"failed to legalize unresolved materialization "
3393 << inputOperands.
getTypes() <<
") to ("
3394 << op.getResultTypes()
3395 <<
") that remained live after conversion";
3396 diag.attachNote(op->getUsers().begin()->getLoc())
3397 <<
"see existing live user here: " << *op->getUsers().begin();
3406 for (
auto *op : ops) {
3409 toConvert.push_back(op);
3412 auto legalityInfo =
target.isLegal(op);
3413 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3422 for (
auto *op : toConvert) {
3425 if (rewriterImpl.
config.allowPatternRollback) {
3449 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3453 if (rewriter.getConfig().buildMaterializations) {
3457 rewriter.getConfig().listener);
3458 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3459 auto it = materializations.find(castOp);
3460 assert(it != materializations.end() &&
"inconsistent state");
3474void TypeConverter::SignatureConversion::addInputs(
unsigned origInputNo,
3476 assert(!types.empty() &&
"expected valid types");
3477 remapInput(origInputNo, argTypes.size(), types.size());
3481void TypeConverter::SignatureConversion::addInputs(
ArrayRef<Type> types) {
3482 assert(!types.empty() &&
3483 "1->0 type remappings don't need to be added explicitly");
3484 argTypes.append(types.begin(), types.end());
3487void TypeConverter::SignatureConversion::remapInput(
unsigned origInputNo,
3488 unsigned newInputNo,
3489 unsigned newInputCount) {
3490 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3491 assert(newInputCount != 0 &&
"expected valid input count");
3492 remappedInputs[origInputNo] =
3493 InputMapping{newInputNo, newInputCount, {}};
3496void TypeConverter::SignatureConversion::remapInput(
3497 unsigned origInputNo, ArrayRef<Value> replacements) {
3498 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3499 remappedInputs[origInputNo] = InputMapping{
3501 SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3512TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3513 SmallVectorImpl<Type> &results)
const {
3514 assert(typeOrValue &&
"expected non-null type");
3515 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3516 : cast<Type>(typeOrValue);
3518 std::shared_lock<
decltype(cacheMutex)> cacheReadLock(cacheMutex,
3521 cacheReadLock.lock();
3522 auto existingIt = cachedDirectConversions.find(t);
3523 if (existingIt != cachedDirectConversions.end()) {
3524 if (existingIt->second)
3525 results.push_back(existingIt->second);
3526 return success(existingIt->second !=
nullptr);
3528 auto multiIt = cachedMultiConversions.find(t);
3529 if (multiIt != cachedMultiConversions.end()) {
3530 results.append(multiIt->second.begin(), multiIt->second.end());
3536 size_t currentCount = results.size();
3540 auto isCacheable = [&](
int index) {
3541 int numberOfConversionsUntilContextAware =
3542 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3543 return index < numberOfConversionsUntilContextAware;
3546 std::unique_lock<
decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3549 for (
auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3550 const ConversionCallbackFn &converter = indexedConverter.value();
3551 std::optional<LogicalResult>
result = converter(typeOrValue, results);
3553 assert(results.size() == currentCount &&
3554 "failed type conversion should not change results");
3557 if (!isCacheable(indexedConverter.index()))
3560 cacheWriteLock.lock();
3561 if (!succeeded(*
result)) {
3562 assert(results.size() == currentCount &&
3563 "failed type conversion should not change results");
3564 cachedDirectConversions.try_emplace(t,
nullptr);
3567 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3568 if (newTypes.size() == 1)
3569 cachedDirectConversions.try_emplace(t, newTypes.front());
3571 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3577LogicalResult TypeConverter::convertType(Type t,
3578 SmallVectorImpl<Type> &results)
const {
3579 return convertTypeImpl(t, results);
3582LogicalResult TypeConverter::convertType(Value v,
3583 SmallVectorImpl<Type> &results)
const {
3584 return convertTypeImpl(v, results);
3587Type TypeConverter::convertType(Type t)
const {
3589 SmallVector<Type, 1> results;
3590 if (
failed(convertType(t, results)))
3594 return results.size() == 1 ? results.front() :
nullptr;
3597Type TypeConverter::convertType(Value v)
const {
3599 SmallVector<Type, 1> results;
3600 if (
failed(convertType(v, results)))
3604 return results.size() == 1 ? results.front() :
nullptr;
3608TypeConverter::convertTypes(
TypeRange types,
3609 SmallVectorImpl<Type> &results)
const {
3610 for (Type type : types)
3611 if (
failed(convertType(type, results)))
3617TypeConverter::convertTypes(
ValueRange values,
3618 SmallVectorImpl<Type> &results)
const {
3619 for (Value value : values)
3620 if (
failed(convertType(value, results)))
3625bool TypeConverter::isLegal(Type type)
const {
3626 return convertType(type) == type;
3629bool TypeConverter::isLegal(Value value)
const {
3630 return convertType(value) == value.
getType();
3633bool TypeConverter::isLegal(Operation *op)
const {
3637bool TypeConverter::isLegal(Region *region)
const {
3638 return llvm::all_of(
3642bool TypeConverter::isSignatureLegal(FunctionType ty)
const {
3643 if (!isLegal(ty.getInputs()))
3645 if (!isLegal(ty.getResults()))
3651TypeConverter::convertSignatureArg(
unsigned inputNo, Type type,
3652 SignatureConversion &
result)
const {
3654 SmallVector<Type, 1> convertedTypes;
3655 if (
failed(convertType(type, convertedTypes)))
3659 if (convertedTypes.empty())
3663 result.addInputs(inputNo, convertedTypes);
3667TypeConverter::convertSignatureArgs(
TypeRange types,
3668 SignatureConversion &
result,
3669 unsigned origInputOffset)
const {
3670 for (
unsigned i = 0, e = types.size(); i != e; ++i)
3671 if (
failed(convertSignatureArg(origInputOffset + i, types[i],
result)))
3676TypeConverter::convertSignatureArg(
unsigned inputNo, Value value,
3677 SignatureConversion &
result)
const {
3679 SmallVector<Type, 1> convertedTypes;
3680 if (
failed(convertType(value, convertedTypes)))
3684 if (convertedTypes.empty())
3688 result.addInputs(inputNo, convertedTypes);
3692TypeConverter::convertSignatureArgs(
ValueRange values,
3693 SignatureConversion &
result,
3694 unsigned origInputOffset)
const {
3695 for (
unsigned i = 0, e = values.size(); i != e; ++i)
3696 if (
failed(convertSignatureArg(origInputOffset + i, values[i],
result)))
3701Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
3702 Location loc, Type resultType,
3704 for (
const SourceMaterializationCallbackFn &fn :
3705 llvm::reverse(sourceMaterializations))
3706 if (Value
result = fn(builder, resultType, inputs, loc))
3711Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
3712 Location loc, Type resultType,
3714 Type originalType)
const {
3715 SmallVector<Value>
result = materializeTargetConversion(
3716 builder, loc,
TypeRange(resultType), inputs, originalType);
3719 assert(
result.size() == 1 &&
"expected single result");
3723SmallVector<Value> TypeConverter::materializeTargetConversion(
3725 Type originalType)
const {
3726 for (
const TargetMaterializationCallbackFn &fn :
3727 llvm::reverse(targetMaterializations)) {
3728 SmallVector<Value>
result =
3729 fn(builder, resultTypes, inputs, loc, originalType);
3733 "callback produced incorrect number of values or values with "
3740std::optional<TypeConverter::SignatureConversion>
3741TypeConverter::convertBlockSignature(
Block *block)
const {
3744 return std::nullopt;
3751TypeConverter::AttributeConversionResult
3752TypeConverter::AttributeConversionResult::result(Attribute attr) {
3753 return AttributeConversionResult(attr, resultTag);
3756TypeConverter::AttributeConversionResult
3757TypeConverter::AttributeConversionResult::na() {
3758 return AttributeConversionResult(
nullptr, naTag);
3761TypeConverter::AttributeConversionResult
3762TypeConverter::AttributeConversionResult::abort() {
3763 return AttributeConversionResult(
nullptr, abortTag);
3766bool TypeConverter::AttributeConversionResult::hasResult()
const {
3767 return impl.getInt() == resultTag;
3770bool TypeConverter::AttributeConversionResult::isNa()
const {
3771 return impl.getInt() == naTag;
3774bool TypeConverter::AttributeConversionResult::isAbort()
const {
3775 return impl.getInt() == abortTag;
3778Attribute TypeConverter::AttributeConversionResult::getResult()
const {
3779 assert(hasResult() &&
"Cannot get result from N/A or abort");
3780 return impl.getPointer();
3783std::optional<Attribute>
3784TypeConverter::convertTypeAttribute(Type type, Attribute attr)
const {
3785 for (
const TypeAttributeConversionCallbackFn &fn :
3786 llvm::reverse(typeAttributeConversions)) {
3787 AttributeConversionResult res = fn(type, attr);
3788 if (res.hasResult())
3789 return res.getResult();
3791 return std::nullopt;
3793 return std::nullopt;
3802 ConversionPatternRewriter &rewriter) {
3803 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3808 TypeConverter::SignatureConversion
result(type.getNumInputs());
3810 if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
result)) ||
3811 failed(typeConverter.convertTypes(type.getResults(), newResults)))
3813 if (!funcOp.getFunctionBody().empty())
3814 rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(),
result,
3818 auto newType = FunctionType::get(rewriter.getContext(),
3819 result.getConvertedTypes(), newResults);
3821 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3830struct FunctionOpInterfaceSignatureConversion :
public ConversionPattern {
3831 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3833 const TypeConverter &converter,
3834 PatternBenefit benefit)
3835 : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3838 matchAndRewrite(Operation *op, ArrayRef<Value> ,
3839 ConversionPatternRewriter &rewriter)
const override {
3840 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3845struct AnyFunctionOpInterfaceSignatureConversion
3846 :
public OpInterfaceConversionPattern<FunctionOpInterface> {
3847 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3850 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> ,
3851 ConversionPatternRewriter &rewriter)
const override {
3857FailureOr<Operation *>
3858mlir::convertOpResultTypes(Operation *op,
ValueRange operands,
3859 const TypeConverter &converter,
3860 ConversionPatternRewriter &rewriter) {
3861 assert(op &&
"Invalid op");
3862 Location loc = op->
getLoc();
3863 if (converter.isLegal(op))
3864 return rewriter.notifyMatchFailure(loc,
"op already legal");
3866 OperationState newOp(loc, op->
getName());
3867 newOp.addOperands(operands);
3869 SmallVector<Type> newResultTypes;
3871 return rewriter.notifyMatchFailure(loc,
"couldn't convert return types");
3873 newOp.addTypes(newResultTypes);
3874 newOp.addAttributes(op->
getAttrs());
3875 return rewriter.create(newOp);
3878void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3879 StringRef functionLikeOpName, RewritePatternSet &
patterns,
3880 const TypeConverter &converter, PatternBenefit benefit) {
3881 patterns.add<FunctionOpInterfaceSignatureConversion>(
3882 functionLikeOpName,
patterns.getContext(), converter, benefit);
3885void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3886 RewritePatternSet &
patterns,
const TypeConverter &converter,
3887 PatternBenefit benefit) {
3888 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3889 converter,
patterns.getContext(), benefit);
3896void ConversionTarget::setOpAction(OperationName op,
3897 LegalizationAction action) {
3898 legalOperations[op].action = action;
3901void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3902 LegalizationAction action) {
3903 for (StringRef dialect : dialectNames)
3904 legalDialects[dialect] = action;
3907auto ConversionTarget::getOpAction(OperationName op)
const
3908 -> std::optional<LegalizationAction> {
3909 std::optional<LegalizationInfo> info = getOpInfo(op);
3910 return info ? info->action : std::optional<LegalizationAction>();
3913auto ConversionTarget::isLegal(Operation *op)
const
3914 -> std::optional<LegalOpDetails> {
3915 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3917 return std::nullopt;
3920 auto isOpLegal = [&] {
3922 if (info->action == LegalizationAction::Dynamic) {
3923 std::optional<bool>
result = info->legalityFn(op);
3929 return info->action == LegalizationAction::Legal;
3932 return std::nullopt;
3935 LegalOpDetails legalityDetails;
3936 if (info->isRecursivelyLegal) {
3937 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
3938 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3939 legalityDetails.isRecursivelyLegal =
3940 legalityFnIt->second(op).value_or(
true);
3942 legalityDetails.isRecursivelyLegal =
true;
3945 return legalityDetails;
3948bool ConversionTarget::isIllegal(Operation *op)
const {
3949 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3953 if (info->action == LegalizationAction::Dynamic) {
3954 std::optional<bool>
result = info->legalityFn(op);
3961 return info->action == LegalizationAction::Illegal;
3965 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
3966 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
3970 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3972 if (std::optional<bool>
result = newCl(op))
3980void ConversionTarget::setLegalityCallback(
3981 OperationName name,
const DynamicLegalityCallbackFn &callback) {
3982 assert(callback &&
"expected valid legality callback");
3983 auto *infoIt = legalOperations.find(name);
3984 assert(infoIt != legalOperations.end() &&
3985 infoIt->second.action == LegalizationAction::Dynamic &&
3986 "expected operation to already be marked as dynamically legal");
3987 infoIt->second.legalityFn =
3991void ConversionTarget::markOpRecursivelyLegal(
3992 OperationName name,
const DynamicLegalityCallbackFn &callback) {
3993 auto *infoIt = legalOperations.find(name);
3994 assert(infoIt != legalOperations.end() &&
3995 infoIt->second.action != LegalizationAction::Illegal &&
3996 "expected operation to already be marked as legal");
3997 infoIt->second.isRecursivelyLegal =
true;
4000 std::move(opRecursiveLegalityFns[name]), callback);
4002 opRecursiveLegalityFns.erase(name);
4005void ConversionTarget::setLegalityCallback(
4006 ArrayRef<StringRef> dialects,
const DynamicLegalityCallbackFn &callback) {
4007 assert(callback &&
"expected valid legality callback");
4008 for (StringRef dialect : dialects)
4010 std::move(dialectLegalityFns[dialect]), callback);
4013void ConversionTarget::setLegalityCallback(
4014 const DynamicLegalityCallbackFn &callback) {
4015 assert(callback &&
"expected valid legality callback");
4019auto ConversionTarget::getOpInfo(OperationName op)
const
4020 -> std::optional<LegalizationInfo> {
4022 const auto *it = legalOperations.find(op);
4023 if (it != legalOperations.end())
4026 auto dialectIt = legalDialects.find(op.getDialectNamespace());
4027 if (dialectIt != legalDialects.end()) {
4028 DynamicLegalityCallbackFn callback;
4029 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4030 if (dialectFn != dialectLegalityFns.end())
4031 callback = dialectFn->second;
4032 return LegalizationInfo{dialectIt->second,
false,
4036 if (unknownLegalityFn)
4037 return LegalizationInfo{LegalizationAction::Dynamic,
4038 false, unknownLegalityFn};
4039 return std::nullopt;
4042#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4047void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
4048 auto &rewriterImpl =
4049 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4053void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
4054 auto &rewriterImpl =
4055 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4061static FailureOr<SmallVector<Value>>
4062pdllConvertValues(ConversionPatternRewriter &rewriter,
ValueRange values) {
4063 SmallVector<Value> mappedValues;
4064 if (
failed(rewriter.getRemappedValues(values, mappedValues)))
4066 return std::move(mappedValues);
4069void mlir::registerConversionPDLFunctions(RewritePatternSet &
patterns) {
4070 patterns.getPDLPatterns().registerRewriteFunction(
4072 [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
4073 auto results = pdllConvertValues(
4074 static_cast<ConversionPatternRewriter &
>(rewriter), value);
4077 return results->front();
4079 patterns.getPDLPatterns().registerRewriteFunction(
4080 "convertValues", [](PatternRewriter &rewriter,
ValueRange values) {
4081 return pdllConvertValues(
4082 static_cast<ConversionPatternRewriter &
>(rewriter), values);
4084 patterns.getPDLPatterns().registerRewriteFunction(
4086 [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
4087 auto &rewriterImpl =
4088 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4089 if (
const TypeConverter *converter =
4091 if (Type newType = converter->convertType(type))
4097 patterns.getPDLPatterns().registerRewriteFunction(
4099 [](PatternRewriter &rewriter,
4100 TypeRange types) -> FailureOr<SmallVector<Type>> {
4101 auto &rewriterImpl =
4102 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4105 return SmallVector<Type>(types);
4107 SmallVector<Type> remappedTypes;
4108 if (
failed(converter->convertTypes(types, remappedTypes)))
4110 return std::move(remappedTypes);
4125 static constexpr StringLiteral
tag =
"apply-conversion";
4126 static constexpr StringLiteral
desc =
4127 "Encapsulate the application of a dialect conversion";
4136 OpConversionMode mode) {
4140 LogicalResult status =
success();
4156LogicalResult mlir::applyPartialConversion(
4157 ArrayRef<Operation *> ops,
const ConversionTarget &
target,
4158 const FrozenRewritePatternSet &
patterns, ConversionConfig
config) {
4160 OpConversionMode::Partial);
4163mlir::applyPartialConversion(Operation *op,
const ConversionTarget &
target,
4164 const FrozenRewritePatternSet &
patterns,
4165 ConversionConfig
config) {
4173LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
4174 const ConversionTarget &
target,
4175 const FrozenRewritePatternSet &
patterns,
4176 ConversionConfig
config) {
4179LogicalResult mlir::applyFullConversion(Operation *op,
4180 const ConversionTarget &
target,
4181 const FrozenRewritePatternSet &
patterns,
4182 ConversionConfig
config) {
4200 "expected top-level op to be isolated from above");
4203 "expected ops to have a common ancestor");
4212 for (
Operation *op : ops.drop_front()) {
4216 assert(commonAncestor &&
4217 "expected to find a common isolated from above ancestor");
4221 return commonAncestor;
4224LogicalResult mlir::applyAnalysisConversion(
4225 ArrayRef<Operation *> ops, ConversionTarget &
target,
4226 const FrozenRewritePatternSet &
patterns, ConversionConfig
config) {
4228 if (
config.legalizableOps)
4229 assert(
config.legalizableOps->empty() &&
"expected empty set");
4235 Operation *clonedAncestor = commonAncestor->
clone(mapping);
4239 inverseOperationMap[it.second] = it.first;
4242 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4243 ops, [&](Operation *op) {
return mapping.
lookup(op); });
4245 OpConversionMode::Analysis);
4249 if (
config.legalizableOps) {
4251 for (Operation *op : *
config.legalizableOps)
4252 originalLegalizableOps.insert(inverseOperationMap[op]);
4253 *
config.legalizableOps = std::move(originalLegalizableOps);
4257 clonedAncestor->
erase();
4262mlir::applyAnalysisConversion(Operation *op, ConversionTarget &
target,
4263 const FrozenRewritePatternSet &
patterns,
4264 ConversionConfig
config) {
static void setInsertionPointAfter(OpBuilder &b, Value value)
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl, function_ref< bool(OpOperand &)> functor=nullptr)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value > > &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
static std::string diag(const llvm::Value &value)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
This is the type of Action that is dispatched when a conversion is applied.
tracing::ActionImpl< ApplyConversionAction > Base
static constexpr StringLiteral desc
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
static constexpr StringLiteral tag
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
StringAttr getStringAttr(const Twine &bytes)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
auto lookup(T from) const
Lookup a mapped value within the map.
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Block::iterator getPoint() const
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
void destroyOpProperties(OpaqueProperties properties) const
This hooks destroy the op properties.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
CRTP Implementation of an action.
ArrayRef< IRUnit > irUnits
Set of IR units (operations, regions, blocks, values) that are associated with this action.
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
static void reconcileUnrealizedCasts(const DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
llvm::SetVector< T, Vector, Set, N > SetVector
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
This iterator enumerates elements according to their dominance relationship.
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
LogicalResult convert(Operation *op, bool isRecursiveLegalization=false)
Converts a single operation.
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
bool wasErased(void *ptr) const
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
llvm::impl::raw_ldbg_ostream os
A raw output stream used to prefix the debug log.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
void replaceValueUses(Value from, ValueRange to, const TypeConverter *converter, function_ref< bool(OpOperand &)> functor=nullptr)
Replace the uses of the given value with the given values.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config, OperationConverter &opConverter)
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
ConversionValueMapping mapping
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
void replaceOp(Operation *op, SmallVector< SmallVector< Value > > &&newValues)
Replace the results of the given operation with the given values and erase the operation.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
OperationConverter & opConverter
The operation converter to use for recursive legalization.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.