10#include "mlir/Config/mlir-config.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/FormatVariadic.h"
25#include "llvm/Support/SaveAndRestore.h"
26#include "llvm/Support/ScopedPrinter.h"
33#define DEBUG_TYPE "dialect-conversion"
36template <
typename... Args>
37static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
40 os.startLine() <<
"} -> SUCCESS";
42 os.getOStream() <<
" : "
43 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
44 os.getOStream() <<
"\n";
49template <
typename... Args>
50static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
53 os.startLine() <<
"} -> FAILURE : "
54 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
64 if (
OpResult inputRes = dyn_cast<OpResult>(value))
65 insertPt = ++inputRes.getOwner()->getIterator();
72 assert(!vals.empty() &&
"expected at least one value");
75 for (
Value v : vals.drop_front()) {
89 assert(dom &&
"unable to find valid insertion point");
97enum OpConversionMode {
123struct ValueVectorMapInfo {
126 static ::llvm::hash_code getHashValue(
const ValueVector &val) {
127 return ::llvm::hash_combine_range(val);
136struct ConversionValueMapping {
139 bool isMappedTo(
Value value)
const {
return mappedTo.contains(value); }
144 template <
typename T>
145 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
148 template <
typename OldVal,
typename NewVal>
149 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
150 map(OldVal &&oldVal, NewVal &&newVal) {
154 assert(next != oldVal &&
"inserting cyclic mapping");
155 auto it = mapping.find(next);
156 if (it == mapping.end())
161 mappedTo.insert_range(newVal);
163 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
167 template <
typename OldVal,
typename NewVal>
168 std::enable_if_t<!IsValueVector<OldVal>::value ||
169 !IsValueVector<NewVal>::value>
170 map(OldVal &&oldVal, NewVal &&newVal) {
171 if constexpr (IsValueVector<OldVal>{}) {
172 map(std::forward<OldVal>(oldVal),
ValueVector{newVal});
173 }
else if constexpr (IsValueVector<NewVal>{}) {
174 map(
ValueVector{oldVal}, std::forward<NewVal>(newVal));
185 void erase(
const ValueVector &value) { mapping.erase(value); }
205 assert(!values.empty() &&
"expected non-empty value vector");
206 Operation *op = values.front().getDefiningOp();
207 for (
Value v : llvm::drop_begin(values)) {
208 if (v.getDefiningOp() != op)
218 assert(!values.empty() &&
"expected non-empty value vector");
224 auto it = mapping.find(from);
225 if (it == mapping.end()) {
238struct RewriterState {
// Snapshot of three bookkeeping counters, captured so the rewriter can later
// be reset back to this point (see resetState).
239 RewriterState(
unsigned numRewrites,
unsigned numIgnoredOperations,
240 unsigned numReplacedOps)
241 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
242 numReplacedOps(numReplacedOps) {}
// Number of recorded rewrites at snapshot time. NOTE(review): presumably used
// as the number of rewrites to keep when undoing — confirm against the caller.
245 unsigned numRewrites;
// resetState loops until `ignoredOps.size()` equals this value.
248 unsigned numIgnoredOperations;
// resetState loops until `replacedOps.size()` equals this value.
251 unsigned numReplacedOps;
// Forward declaration: the Block overload below iterates the block's
// operations and calls this Operation overload before it is defined, while the
// Operation overload in turn calls the Block overload (mutual recursion).
258static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
261static void notifyIRErased(RewriterBase::Listener *listener,
Block &
b) {
262 for (Operation &op :
b)
263 notifyIRErased(listener, op);
269static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
272 notifyIRErased(listener,
b);
302 UnresolvedMaterialization,
// Virtual so concrete rewrites can be destroyed through a base pointer; the
// rewriter stores them as std::unique_ptr<IRRewrite>.
307 virtual ~IRRewrite() =
default;
// Revert the IR change recorded by this rewrite. Every concrete rewrite must
// implement it.
310 virtual void rollback() = 0;
// Finalize the rewrite. The default does nothing; subclasses override it
// (e.g. to notify the rewriter's listener).
324 virtual void commit(RewriterBase &rewriter) {}
// Hook invoked after committing. The default does nothing; subclasses override
// it to erase IR that is no longer needed.
327 virtual void cleanup(RewriterBase &rewriter) {}
// Discriminator consulted by the subclasses' classof() for LLVM-style
// isa/dyn_cast.
329 Kind getKind()
const {
return kind; }
// Every rewrite is an IRRewrite, so the base-class check always succeeds.
331 static bool classof(
const IRRewrite *
rewrite) {
return true; }
// Records the rewrite kind and the rewriter implementation it belongs to.
334 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
335 : kind(kind), rewriterImpl(rewriterImpl) {}
// Conversion configuration of the associated rewriter implementation (defined
// out of line; it returns rewriterImpl.config).
337 const ConversionConfig &getConfig()
const;
// Rewriter implementation this rewrite reads from and updates during
// commit/rollback (e.g. its value mapping).
340 ConversionPatternRewriterImpl &rewriterImpl;
344class BlockRewrite :
public IRRewrite {
347 Block *getBlock()
const {
return block; }
349 static bool classof(
const IRRewrite *
rewrite) {
350 return rewrite->getKind() >= Kind::CreateBlock &&
351 rewrite->getKind() <= Kind::BlockTypeConversion;
355 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
357 : IRRewrite(kind, rewriterImpl), block(block) {}
364class ValueRewrite :
public IRRewrite {
367 Value getValue()
const {
return value; }
369 static bool classof(
const IRRewrite *
rewrite) {
370 return rewrite->getKind() == Kind::ReplaceValue;
374 ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
376 : IRRewrite(kind, rewriterImpl), value(value) {}
385class CreateBlockRewrite :
public BlockRewrite {
387 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block)
388 : BlockRewrite(
Kind::CreateBlock, rewriterImpl, block) {}
390 static bool classof(
const IRRewrite *
rewrite) {
391 return rewrite->getKind() == Kind::CreateBlock;
394 void commit(RewriterBase &rewriter)
override {
400 void rollback()
override {
403 auto &blockOps = block->getOperations();
404 while (!blockOps.empty())
405 blockOps.remove(blockOps.begin());
406 block->dropAllUses();
407 if (block->getParent())
418class EraseBlockRewrite :
public BlockRewrite {
420 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block)
421 : BlockRewrite(
Kind::EraseBlock, rewriterImpl, block),
422 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
424 static bool classof(
const IRRewrite *
rewrite) {
425 return rewrite->getKind() == Kind::EraseBlock;
428 ~EraseBlockRewrite()
override {
430 "rewrite was neither rolled back nor committed/cleaned up");
433 void rollback()
override {
436 assert(block &&
"expected block");
441 blockList.insert(before, block);
445 void commit(RewriterBase &rewriter)
override {
446 assert(block &&
"expected block");
450 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
451 notifyIRErased(listener, *block);
454 void cleanup(RewriterBase &rewriter)
override {
456 for (
auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
458 assert(block->empty() &&
"expected empty block");
461 block->dropAllDefinedValueUses();
472 Block *insertBeforeBlock;
478class InlineBlockRewrite :
public BlockRewrite {
480 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block,
482 : BlockRewrite(
Kind::InlineBlock, rewriterImpl, block),
483 sourceBlock(sourceBlock),
484 firstInlinedInst(sourceBlock->empty() ?
nullptr
485 : &sourceBlock->front()),
486 lastInlinedInst(sourceBlock->empty() ?
nullptr : &sourceBlock->back()) {
492 assert(!getConfig().listener &&
493 "InlineBlockRewrite not supported if listener is attached");
496 static bool classof(
const IRRewrite *
rewrite) {
497 return rewrite->getKind() == Kind::InlineBlock;
500 void rollback()
override {
503 if (firstInlinedInst) {
504 assert(lastInlinedInst &&
"expected operation");
517 Operation *firstInlinedInst;
520 Operation *lastInlinedInst;
524class MoveBlockRewrite :
public BlockRewrite {
526 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block,
528 : BlockRewrite(
Kind::MoveBlock, rewriterImpl, block),
529 region(previousRegion),
530 insertBeforeBlock(previousIt == previousRegion->end() ?
nullptr
533 static bool classof(
const IRRewrite *
rewrite) {
534 return rewrite->getKind() == Kind::MoveBlock;
537 void commit(RewriterBase &rewriter)
override {
547 void rollback()
override {
551 if (Region *currentParent = block->
getParent()) {
553 region->getBlocks().splice(before, currentParent->getBlocks(), block);
557 region->
getBlocks().insert(before, block);
566 Block *insertBeforeBlock;
570class BlockTypeConversionRewrite :
public BlockRewrite {
572 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
574 : BlockRewrite(
Kind::BlockTypeConversion, rewriterImpl, origBlock),
575 newBlock(newBlock) {}
577 static bool classof(
const IRRewrite *
rewrite) {
578 return rewrite->getKind() == Kind::BlockTypeConversion;
581 Block *getOrigBlock()
const {
return block; }
583 Block *getNewBlock()
const {
return newBlock; }
585 void commit(RewriterBase &rewriter)
override;
587 void rollback()
override;
597class ReplaceValueRewrite :
public ValueRewrite {
599 ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
600 const TypeConverter *converter)
601 : ValueRewrite(
Kind::ReplaceValue, rewriterImpl, value),
602 converter(converter) {}
604 static bool classof(
const IRRewrite *
rewrite) {
605 return rewrite->getKind() == Kind::ReplaceValue;
608 void commit(RewriterBase &rewriter)
override;
610 void rollback()
override;
614 const TypeConverter *converter;
618class OperationRewrite :
public IRRewrite {
621 Operation *getOperation()
const {
return op; }
623 static bool classof(
const IRRewrite *
rewrite) {
624 return rewrite->getKind() >= Kind::MoveOperation &&
625 rewrite->getKind() <= Kind::UnresolvedMaterialization;
629 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
631 : IRRewrite(kind, rewriterImpl), op(op) {}
638class MoveOperationRewrite :
public OperationRewrite {
640 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
641 Operation *op, OpBuilder::InsertPoint previous)
642 : OperationRewrite(
Kind::MoveOperation, rewriterImpl, op),
643 block(previous.getBlock()),
644 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
646 : &*previous.getPoint()) {}
648 static bool classof(
const IRRewrite *
rewrite) {
649 return rewrite->getKind() == Kind::MoveOperation;
652 void commit(RewriterBase &rewriter)
override {
658 op, OpBuilder::InsertPoint(block,
663 void rollback()
override {
676 Operation *insertBeforeOp;
681class ModifyOperationRewrite :
public OperationRewrite {
683 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
685 : OperationRewrite(
Kind::ModifyOperation, rewriterImpl, op),
686 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
687 operands(op->operand_begin(), op->operand_end()),
688 successors(op->successor_begin(), op->successor_end()) {
691 propertiesStorage = operator new(op->getPropertiesStorageSize());
692 OpaqueProperties propCopy(propertiesStorage);
693 name.initOpProperties(propCopy, prop);
697 static bool classof(
const IRRewrite *
rewrite) {
698 return rewrite->getKind() == Kind::ModifyOperation;
701 ~ModifyOperationRewrite()
override {
702 assert(!propertiesStorage &&
703 "rewrite was neither committed nor rolled back");
706 void commit(RewriterBase &rewriter)
override {
709 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
712 if (propertiesStorage) {
713 OpaqueProperties propCopy(propertiesStorage);
717 operator delete(propertiesStorage);
718 propertiesStorage =
nullptr;
722 void rollback()
override {
726 for (
const auto &it : llvm::enumerate(successors))
728 if (propertiesStorage) {
729 OpaqueProperties propCopy(propertiesStorage);
732 operator delete(propertiesStorage);
733 propertiesStorage =
nullptr;
740 DictionaryAttr attrs;
741 SmallVector<Value, 8> operands;
742 SmallVector<Block *, 2> successors;
743 void *propertiesStorage =
nullptr;
750class ReplaceOperationRewrite :
public OperationRewrite {
752 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
753 Operation *op,
const TypeConverter *converter)
754 : OperationRewrite(
Kind::ReplaceOperation, rewriterImpl, op),
755 converter(converter) {}
757 static bool classof(
const IRRewrite *
rewrite) {
758 return rewrite->getKind() == Kind::ReplaceOperation;
761 void commit(RewriterBase &rewriter)
override;
763 void rollback()
override;
765 void cleanup(RewriterBase &rewriter)
override;
770 const TypeConverter *converter;
773class CreateOperationRewrite :
public OperationRewrite {
775 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
777 : OperationRewrite(
Kind::CreateOperation, rewriterImpl, op) {}
779 static bool classof(
const IRRewrite *
rewrite) {
780 return rewrite->getKind() == Kind::CreateOperation;
783 void commit(RewriterBase &rewriter)
override {
789 void rollback()
override;
793enum MaterializationKind {
804class UnresolvedMaterializationInfo {
806 UnresolvedMaterializationInfo() =
default;
807 UnresolvedMaterializationInfo(
const TypeConverter *converter,
808 MaterializationKind kind, Type originalType)
809 : converterAndKind(converter, kind), originalType(originalType) {}
812 const TypeConverter *getConverter()
const {
813 return converterAndKind.getPointer();
817 MaterializationKind getMaterializationKind()
const {
818 return converterAndKind.getInt();
822 Type getOriginalType()
const {
return originalType; }
827 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
838class UnresolvedMaterializationRewrite :
public OperationRewrite {
840 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
841 UnrealizedConversionCastOp op,
843 : OperationRewrite(
Kind::UnresolvedMaterialization, rewriterImpl, op),
844 mappedValues(std::move(mappedValues)) {}
846 static bool classof(
const IRRewrite *
rewrite) {
847 return rewrite->getKind() == Kind::UnresolvedMaterialization;
850 void rollback()
override;
852 UnrealizedConversionCastOp getOperation()
const {
853 return cast<UnrealizedConversionCastOp>(op);
863#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
866template <
typename RewriteTy,
typename R>
867static bool hasRewrite(R &&rewrites, Operation *op) {
868 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
869 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
870 return rewriteTy && rewriteTy->getOperation() == op;
876template <
typename RewriteTy,
typename R>
877static bool hasRewrite(R &&rewrites,
Block *block) {
878 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
879 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
880 return rewriteTy && rewriteTy->getBlock() == block;
892 const ConversionConfig &
config,
902 RewriterState getCurrentState();
906 void applyRewrites();
911 void resetState(RewriterState state, StringRef patternName =
"");
915 template <
typename RewriteTy,
typename... Args>
917 assert(
config.allowPatternRollback &&
"appending rewrites is not allowed");
919 std::make_unique<RewriteTy>(*
this, std::forward<Args>(args)...));
925 void undoRewrites(
unsigned numRewritesToKeep = 0, StringRef patternName =
"");
931 LogicalResult remapValues(StringRef valueDiagTag,
932 std::optional<Location> inputLoc,
ValueRange values,
949 bool skipPureTypeConversions =
false)
const;
963 TypeConverter::SignatureConversion *entryConversion);
971 Block *applySignatureConversion(
973 TypeConverter::SignatureConversion &signatureConversion);
993 void eraseBlock(
Block *block);
1031 Value findOrBuildReplacementValue(
Value value,
1039 void notifyOperationInserted(
Operation *op,
1043 void notifyBlockInserted(
Block *block,
Region *previous,
1062 std::function<
void(
Operation *)> opErasedCallback =
nullptr)
1064 opErasedCallback(std::move(opErasedCallback)) {}
1078 assert(block->empty() &&
"expected empty block");
1079 block->dropAllDefinedValueUses();
1087 if (opErasedCallback)
1088 opErasedCallback(op);
1196const ConversionConfig &IRRewrite::getConfig()
const {
1197 return rewriterImpl.
config;
1200void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1204 if (
auto *listener =
1205 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
1206 for (Operation *op : getNewBlock()->getUsers())
1210void BlockTypeConversionRewrite::rollback() {
1211 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1218 if (isa<BlockArgument>(repl)) {
1258 result &= functor(operand);
1263void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1270void ReplaceValueRewrite::rollback() {
1271 rewriterImpl.
mapping.erase({value});
1277void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1279 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener());
1282 SmallVector<Value> replacements =
1284 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1292 for (
auto [
result, newValue] :
1293 llvm::zip_equal(op->
getResults(), replacements))
1299 if (getConfig().unlegalizedOps)
1300 getConfig().unlegalizedOps->erase(op);
1304 notifyIRErased(listener, *op);
1311void ReplaceOperationRewrite::rollback() {
1316void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1320void CreateOperationRewrite::rollback() {
1322 while (!region.getBlocks().empty())
1323 region.getBlocks().remove(region.getBlocks().begin());
1329void UnresolvedMaterializationRewrite::rollback() {
1330 if (!mappedValues.empty())
1331 rewriterImpl.
mapping.erase(mappedValues);
1342 for (
size_t i = 0; i <
rewrites.size(); ++i)
1348 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1349 unresolvedMaterializations.erase(castOp);
1352 rewrite->cleanup(eraseRewriter);
1360 Value from,
TypeRange desiredTypes,
bool skipPureTypeConversions)
const {
1363 assert(!values.empty() &&
"expected non-empty value vector");
1367 if (
config.allowPatternRollback)
1368 return mapping.lookup(values);
1375 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1380 if (castOp.getOutputs() != values)
1382 return castOp.getInputs();
1391 for (
Value v : values) {
1394 llvm::append_range(next, r);
1399 if (next != values) {
1428 if (skipPureTypeConversions) {
1431 match &= !pureConversion;
1434 if (!pureConversion)
1435 lastNonMaterialization = current;
1438 desiredValue = current;
1444 current = std::move(next);
1449 if (!desiredTypes.empty())
1450 return desiredValue;
1451 if (skipPureTypeConversions)
1452 return lastNonMaterialization;
1471 StringRef patternName) {
1476 while (
ignoredOps.size() != state.numIgnoredOperations)
1479 while (
replacedOps.size() != state.numReplacedOps)
1484 StringRef patternName) {
1486 llvm::reverse(llvm::drop_begin(
rewrites, numRewritesToKeep)))
1488 rewrites.resize(numRewritesToKeep);
1492 StringRef valueDiagTag, std::optional<Location> inputLoc,
ValueRange values,
1494 remapped.reserve(llvm::size(values));
1496 for (
const auto &it : llvm::enumerate(values)) {
1497 Value operand = it.value();
1516 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1517 << it.index() <<
", type was " << origType;
1522 if (legalTypes.empty()) {
1523 remapped.push_back({});
1532 remapped.push_back(std::move(repl));
1541 repl, repl, legalTypes,
1543 remapped.push_back(castValues);
1564 TypeConverter::SignatureConversion *entryConversion) {
1566 if (region->
empty())
1571 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1573 std::optional<TypeConverter::SignatureConversion> conversion =
1574 converter.convertBlockSignature(&block);
1583 if (entryConversion)
1586 std::optional<TypeConverter::SignatureConversion> conversion =
1587 converter.convertBlockSignature(®ion->
front());
1595 TypeConverter::SignatureConversion &signatureConversion) {
1596#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1598 if (hasRewrite<BlockTypeConversionRewrite>(
rewrites, block))
1599 llvm::report_fatal_error(
"block was already converted");
1606 auto convertedTypes = signatureConversion.getConvertedTypes();
1613 for (
unsigned i = 0; i < origArgCount; ++i) {
1614 auto inputMap = signatureConversion.getInputMapping(i);
1615 if (!inputMap || inputMap->replacedWithValues())
1618 for (
unsigned j = 0;
j < inputMap->size; ++
j)
1619 newLocs[inputMap->inputNo +
j] = origLoc;
1626 convertedTypes, newLocs);
1634 bool fastPath = !
config.listener;
1636 if (
config.allowPatternRollback)
1640 while (!block->
empty())
1647 for (
unsigned i = 0; i != origArgCount; ++i) {
1651 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1652 signatureConversion.getInputMapping(i);
1660 MaterializationKind::Source,
1664 origArgType,
Type(), converter,
1671 if (inputMap->replacedWithValues()) {
1673 assert(inputMap->size == 0 &&
1674 "invalid to provide a replacement value when the argument isn't "
1682 newBlock->
getArguments().slice(inputMap->inputNo, inputMap->size);
1686 if (
config.allowPatternRollback)
1707 assert((!originalType || kind == MaterializationKind::Target) &&
1708 "original type is valid only for target materializations");
1709 assert(
TypeRange(inputs) != outputTypes &&
1710 "materialization is not necessary");
1714 OpBuilder builder(outputTypes.front().getContext());
1716 UnrealizedConversionCastOp convertOp =
1717 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1718 if (
config.attachDebugMaterializationKind) {
1720 kind == MaterializationKind::Source ?
"source" :
"target";
1721 convertOp->setAttr(
"__kind__", builder.
getStringAttr(kindStr));
1728 UnresolvedMaterializationInfo(converter, kind, originalType);
1729 if (
config.allowPatternRollback) {
1730 if (!valuesToMap.empty())
1731 mapping.map(valuesToMap, convertOp.getResults());
1733 std::move(valuesToMap));
1737 return convertOp.getResults();
1742 assert(
config.allowPatternRollback &&
1743 "this code path is valid only in rollback mode");
1750 return repl.front();
1757 [&](
Operation *op) { return replacedOps.contains(op); }) &&
1782 MaterializationKind::Source, ip, value.
getLoc(),
1798 bool wasDetached = !previous.
isSet();
1800 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"' (" << op
1803 logger.getOStream() <<
" (was detached)";
1804 logger.getOStream() <<
"\n";
1810 "attempting to insert into a block within a replaced/erased op");
1814 config.listener->notifyOperationInserted(op, previous);
1823 if (
config.allowPatternRollback) {
1837 if (
config.allowPatternRollback)
1847 assert(!
impl.config.allowPatternRollback &&
1848 "this code path is valid only in 'no rollback' mode");
1850 for (
auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1853 repls.push_back(
Value());
1860 Value srcMat =
impl.buildUnresolvedMaterialization(
1865 repls.push_back(srcMat);
1871 repls.push_back(to[0]);
1880 Value srcMat =
impl.buildUnresolvedMaterialization(
1883 Type(), converter)[0];
1884 repls.push_back(srcMat);
1893 "incorrect number of replacement values");
1895 logger.startLine() <<
"** Replace : '" << op->
getName() <<
"'(" << op
1903 for (
auto [
result, repls] :
1904 llvm::zip_equal(op->
getResults(), newValues)) {
1906 auto logProlog = [&, repls = repls]() {
1907 logger.startLine() <<
" Note: Replacing op result of type "
1908 << resultType <<
" with value(s) of type (";
1909 llvm::interleaveComma(repls,
logger.getOStream(), [&](
Value v) {
1910 logger.getOStream() << v.getType();
1912 logger.getOStream() <<
")";
1918 logger.getOStream() <<
", but the type converter failed to legalize "
1919 "the original type.\n";
1924 logger.getOStream() <<
", but the legalized type(s) is/are (";
1925 llvm::interleaveComma(convertedTypes,
logger.getOStream(),
1926 [&](
Type t) { logger.getOStream() << t; });
1927 logger.getOStream() <<
")\n";
1933 if (!
config.allowPatternRollback) {
1942 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1948 if (
config.unlegalizedOps)
1949 config.unlegalizedOps->erase(op);
1957 assert(!
ignoredOps.contains(op) &&
"operation was already replaced");
1961 "attempting to replace a value that was already replaced");
1966 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1971 "attempting to replace/erase an unresolved materialization");
1987 logger.startLine() <<
"** Replace Value : '" << from <<
"'";
1988 if (
auto blockArg = dyn_cast<BlockArgument>(from)) {
1990 logger.getOStream() <<
" (in region of '" << parentOp->getName()
1991 <<
"' (" << parentOp <<
")";
1993 logger.getOStream() <<
" (unlinked block)";
1997 logger.getOStream() <<
", conditional replacement";
2001 if (!
config.allowPatternRollback) {
2006 Value repl = repls.front();
2023 "attempting to replace a value that was already replaced");
2025 "attempting to replace a op result that was already replaced");
2030 llvm::report_fatal_error(
2031 "conditional value replacement is not supported in rollback mode");
2037 if (!
config.allowPatternRollback) {
2044 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
2050 if (
config.unlegalizedOps)
2051 config.unlegalizedOps->erase(op);
2060 "attempting to erase a block within a replaced/erased op");
2076 bool wasDetached = !previous;
2082 logger.startLine() <<
"** Insert Block into : '" << parent->
getName()
2083 <<
"' (" << parent <<
")";
2086 <<
"** Insert Block into detached Region (nullptr parent op)";
2089 logger.getOStream() <<
" (was detached)";
2090 logger.getOStream() <<
"\n";
2096 "attempting to insert into a region within a replaced/erased op");
2101 config.listener->notifyBlockInserted(block, previous, previousIt);
2105 if (
config.allowPatternRollback) {
2119 if (
config.allowPatternRollback)
2133 reasonCallback(
diag);
2134 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
2135 if (
config.notifyCallback)
2144ConversionPatternRewriter::ConversionPatternRewriter(
2148 *this,
config, opConverter)) {
2149 setListener(
impl.get());
// Defaulted destructor, defined out of line. NOTE(review): presumably placed
// in this file so the type held by `impl` is complete here — confirm.
2152ConversionPatternRewriter::~ConversionPatternRewriter() =
default;
2154const ConversionConfig &ConversionPatternRewriter::getConfig()
const {
2155 return impl->config;
2159 assert(op && newOp &&
"expected non-null op");
2163void ConversionPatternRewriter::replaceOp(Operation *op,
ValueRange newValues) {
2165 "incorrect # of replacement values");
2169 if (getInsertionPoint() == op->getIterator())
2172 SmallVector<SmallVector<Value>> newVals =
2173 llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2174 return v ? SmallVector<Value>{v} : SmallVector<Value>();
2176 impl->replaceOp(op, std::move(newVals));
2179void ConversionPatternRewriter::replaceOpWithMultiple(
2180 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2182 "incorrect # of replacement values");
2186 if (getInsertionPoint() == op->getIterator())
2189 impl->replaceOp(op, std::move(newValues));
2192void ConversionPatternRewriter::eraseOp(Operation *op) {
2194 impl->logger.startLine()
2195 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
2200 if (getInsertionPoint() == op->getIterator())
2203 SmallVector<SmallVector<Value>> nullRepls(op->
getNumResults(), {});
2204 impl->replaceOp(op, std::move(nullRepls));
2207void ConversionPatternRewriter::eraseBlock(
Block *block) {
2208 impl->eraseBlock(block);
2211Block *ConversionPatternRewriter::applySignatureConversion(
2212 Block *block, TypeConverter::SignatureConversion &conversion,
2213 const TypeConverter *converter) {
2214 assert(!impl->wasOpReplaced(block->
getParentOp()) &&
2215 "attempting to apply a signature conversion to a block within a "
2216 "replaced/erased op");
2217 return impl->applySignatureConversion(block, converter, conversion);
2220FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
2221 Region *region,
const TypeConverter &converter,
2222 TypeConverter::SignatureConversion *entryConversion) {
2223 assert(!impl->wasOpReplaced(region->
getParentOp()) &&
2224 "attempting to apply a signature conversion to a block within a "
2225 "replaced/erased op");
2226 return impl->convertRegionTypes(region, converter, entryConversion);
2229void ConversionPatternRewriter::replaceAllUsesWith(Value from,
ValueRange to) {
2230 impl->replaceValueUses(from, to, impl->currentTypeConverter);
2233void ConversionPatternRewriter::replaceUsesWithIf(
2235 bool *allUsesReplaced) {
2236 assert(!allUsesReplaced &&
2237 "allUsesReplaced is not supported in a dialect conversion");
2238 impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
2241Value ConversionPatternRewriter::getRemappedValue(Value key) {
2242 SmallVector<ValueVector> remappedValues;
2243 if (
failed(impl->remapValues(
"value", std::nullopt, key,
2246 assert(remappedValues.front().size() == 1 &&
"1:N conversion not supported");
2247 return remappedValues.front().front();
2251ConversionPatternRewriter::getRemappedValues(
ValueRange keys,
2252 SmallVectorImpl<Value> &results) {
2255 SmallVector<ValueVector> remapped;
2256 if (
failed(impl->remapValues(
"value", std::nullopt, keys,
2259 for (
const auto &values : remapped) {
2260 assert(values.size() == 1 &&
"1:N conversion not supported");
2261 results.push_back(values.front());
2266void ConversionPatternRewriter::inlineBlockBefore(
Block *source,
Block *dest,
2271 "incorrect # of argument replacement values");
2272 assert(!impl->wasOpReplaced(source->
getParentOp()) &&
2273 "attempting to inline a block from a replaced/erased op");
2274 assert(!impl->wasOpReplaced(dest->
getParentOp()) &&
2275 "attempting to inline a block into a replaced/erased op");
2276 auto opIgnored = [&](Operation *op) {
return impl->isOpIgnored(op); };
2279 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
2280 "expected 'source' to have no predecessors");
2289 bool fastPath = !getConfig().listener;
2291 if (fastPath && impl->config.allowPatternRollback)
2292 impl->inlineBlockBefore(source, dest, before);
2295 for (
auto it : llvm::zip(source->
getArguments(), argValues))
2296 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2303 while (!source->
empty())
2304 moveOpBefore(&source->
front(), dest, before);
2309 if (getInsertionBlock() == source)
2310 setInsertionPoint(dest, getInsertionPoint());
2316void ConversionPatternRewriter::startOpModification(Operation *op) {
2317 if (!impl->config.allowPatternRollback) {
2322 assert(!impl->wasOpReplaced(op) &&
2323 "attempting to modify a replaced/erased op");
2325 impl->pendingRootUpdates.insert(op);
2327 impl->appendRewrite<ModifyOperationRewrite>(op);
2330void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
2331 impl->patternModifiedOps.insert(op);
2332 if (!impl->config.allowPatternRollback) {
2334 if (getConfig().listener)
2335 getConfig().listener->notifyOperationModified(op);
2342 assert(!impl->wasOpReplaced(op) &&
2343 "attempting to modify a replaced/erased op");
2344 assert(impl->pendingRootUpdates.erase(op) &&
2345 "operation did not have a pending in-place update");
2349void ConversionPatternRewriter::cancelOpModification(Operation *op) {
2350 if (!impl->config.allowPatternRollback) {
2355 assert(impl->pendingRootUpdates.erase(op) &&
2356 "operation did not have a pending in-place update");
2359 auto it = llvm::find_if(
2360 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &
rewrite) {
2361 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2362 return modifyRewrite && modifyRewrite->getOperation() == op;
2364 assert(it != impl->rewrites.rend() &&
"no root update started on op");
2366 int updateIdx = std::prev(impl->rewrites.rend()) - it;
2367 impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2370detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
2378FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
2379 ArrayRef<ValueRange> operands)
const {
2380 SmallVector<Value> oneToOneOperands;
2381 oneToOneOperands.reserve(operands.size());
2383 if (operand.size() != 1)
2386 oneToOneOperands.push_back(operand.front());
2388 return std::move(oneToOneOperands);
2392ConversionPattern::matchAndRewrite(Operation *op,
2393 PatternRewriter &rewriter)
const {
2394 auto &dialectRewriter =
static_cast<ConversionPatternRewriter &
>(rewriter);
2395 auto &rewriterImpl = dialectRewriter.getImpl();
2399 getTypeConverter());
2402 SmallVector<ValueVector> remapped;
2407 SmallVector<ValueRange> remappedAsRange =
2408 llvm::to_vector_of<ValueRange>(remapped);
2409 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
// List of pointers to patterns (inline capacity of one element), used by
// OperationLegalizer when building the legalization graph.
2418using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2421class OperationLegalizer {
2423 using LegalizationAction = ConversionTarget::LegalizationAction;
2425 OperationLegalizer(ConversionPatternRewriter &rewriter,
2426 const ConversionTarget &targetInfo,
2427 const FrozenRewritePatternSet &
patterns);
// Returns whether the conversion target reports `op` as illegal (forwards to
// ConversionTarget::isIllegal).
2430 bool isIllegal(Operation *op)
const;
// Attempt to legalize `op`. The definition first checks whether the op is
// ignored and whether the target already marks it legal.
2434 LogicalResult legalize(Operation *op);
// The conversion target this legalizer checks operations against.
2437 const ConversionTarget &getTarget() {
return target; }
2441 LogicalResult legalizeWithFold(Operation *op);
2445 LogicalResult legalizeWithPattern(Operation *op);
2449 bool canApplyPattern(Operation *op,
const Pattern &pattern);
2453 legalizePatternResult(Operation *op,
const Pattern &pattern,
2454 const RewriterState &curState,
2471 void buildLegalizationGraph(
2472 LegalizationPatterns &anyOpLegalizerPatterns,
2483 void computeLegalizationGraphBenefit(
2484 LegalizationPatterns &anyOpLegalizerPatterns,
2489 unsigned computeOpLegalizationDepth(
2496 unsigned applyCostModelToPatterns(
2502 SmallPtrSet<const Pattern *, 8> appliedPatterns;
2505 ConversionPatternRewriter &rewriter;
2508 const ConversionTarget &
target;
2511 PatternApplicator applicator;
2515OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2516 const ConversionTarget &targetInfo,
2517 const FrozenRewritePatternSet &
patterns)
2522 LegalizationPatterns anyOpLegalizerPatterns;
2524 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2525 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2528bool OperationLegalizer::isIllegal(Operation *op)
const {
2529 return target.isIllegal(op);
2532LogicalResult OperationLegalizer::legalize(Operation *op) {
2534 const char *logLineComment =
2535 "//===-------------------------------------------===//\n";
2537 auto &logger = rewriter.getImpl().logger;
2541 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2544 logger.getOStream() <<
"\n";
2545 logger.startLine() << logLineComment;
2546 logger.startLine() <<
"Legalizing operation : ";
2551 logger.getOStream() <<
"'" << op->
getName() <<
"' ";
2552 logger.getOStream() <<
"(" << op <<
") {\n";
2557 logger.startLine() << OpWithFlags(op,
2558 OpPrintingFlags().printGenericOpForm())
2565 logSuccess(logger,
"operation marked 'ignored' during conversion");
2566 logger.startLine() << logLineComment;
2572 if (
auto legalityInfo =
target.isLegal(op)) {
2575 logger,
"operation marked legal by the target{0}",
2576 legalityInfo->isRecursivelyLegal
2577 ?
"; NOTE: operation is recursively legal; skipping internals"
2579 logger.startLine() << logLineComment;
2584 if (legalityInfo->isRecursivelyLegal) {
2585 op->
walk([&](Operation *nested) {
2587 rewriter.getImpl().ignoredOps.
insert(nested);
2596 const ConversionConfig &
config = rewriter.getConfig();
2597 if (
config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2598 if (succeeded(legalizeWithFold(op))) {
2601 logger.startLine() << logLineComment;
2608 if (succeeded(legalizeWithPattern(op))) {
2611 logger.startLine() << logLineComment;
2618 if (
config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2619 if (succeeded(legalizeWithFold(op))) {
2622 logger.startLine() << logLineComment;
2629 logFailure(logger,
"no matched legalization pattern");
2630 logger.startLine() << logLineComment;
2637template <
typename T>
2639 T
result = std::move(obj);
2644LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2645 auto &rewriterImpl = rewriter.getImpl();
2647 rewriterImpl.
logger.startLine() <<
"* Fold {\n";
2648 rewriterImpl.
logger.indent();
2653 llvm::scope_exit cleanup([&]() {
2663 SmallVector<Value, 2> replacementValues;
2664 SmallVector<Operation *, 2> newOps;
2667 if (
failed(rewriter.
tryFold(op, replacementValues, &newOps))) {
2676 if (replacementValues.empty())
2677 return legalize(op);
2680 rewriter.
replaceOp(op, replacementValues);
2683 for (Operation *newOp : newOps) {
2684 if (
failed(legalize(newOp))) {
2686 "failed to legalize generated constant '{0}'",
2688 if (!rewriter.getConfig().allowPatternRollback) {
2690 llvm::report_fatal_error(
2692 "' folder rollback of IR modifications requested");
2710 auto newOpNames = llvm::map_range(
2712 auto modifiedOpNames = llvm::map_range(
2714 llvm::report_fatal_error(
"pattern '" + pattern.
getDebugName() +
2715 "' produced IR that could not be legalized. " +
2716 "new ops: {" + llvm::join(newOpNames,
", ") +
"}, " +
2718 llvm::join(modifiedOpNames,
", ") +
"}");
2721LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2722 auto &rewriterImpl = rewriter.getImpl();
2723 const ConversionConfig &
config = rewriter.getConfig();
2725#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2727 std::optional<OperationFingerPrint> topLevelFingerPrint;
2728 if (!rewriterImpl.
config.allowPatternRollback) {
2735 topLevelFingerPrint = OperationFingerPrint(checkOp);
2741 rewriterImpl.
logger.startLine()
2742 <<
"WARNING: Multi-threadeding is enabled. Some dialect "
2743 "conversion expensive checks are skipped in multithreading "
2752 auto canApply = [&](
const Pattern &pattern) {
2753 bool canApply = canApplyPattern(op, pattern);
2754 if (canApply &&
config.listener)
2755 config.listener->notifyPatternBegin(pattern, op);
2761 auto onFailure = [&](
const Pattern &pattern) {
2763 if (!rewriterImpl.
config.allowPatternRollback) {
2770#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2772 if (checkOp && topLevelFingerPrint) {
2773 OperationFingerPrint fingerPrintAfterPattern(checkOp);
2774 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2775 llvm::report_fatal_error(
"pattern '" + pattern.getDebugName() +
2776 "' returned failure but IR did change");
2784 if (rewriterImpl.
config.notifyCallback) {
2786 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
2793 config.listener->notifyPatternEnd(pattern, failure());
2794 rewriterImpl.
resetState(curState, pattern.getDebugName());
2795 appliedPatterns.erase(&pattern);
2800 auto onSuccess = [&](
const Pattern &pattern) {
2802 if (!rewriterImpl.
config.allowPatternRollback) {
2816 legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
2817 appliedPatterns.erase(&pattern);
2819 if (!rewriterImpl.
config.allowPatternRollback)
2821 rewriterImpl.
resetState(curState, pattern.getDebugName());
2829 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2833bool OperationLegalizer::canApplyPattern(Operation *op,
2834 const Pattern &pattern) {
2836 auto &os = rewriter.getImpl().logger;
2837 os.getOStream() <<
"\n";
2838 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
2840 os.getOStream() <<
")' {\n";
2847 !appliedPatterns.insert(&pattern).second) {
2849 logFailure(rewriter.getImpl().logger,
"pattern was already applied"));
2855LogicalResult OperationLegalizer::legalizePatternResult(
2856 Operation *op,
const Pattern &pattern,
const RewriterState &curState,
2859 [[maybe_unused]]
auto &impl = rewriter.getImpl();
2860 assert(impl.pendingRootUpdates.empty() &&
"dangling root updates");
2862#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2863 if (impl.config.allowPatternRollback) {
2865 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2866 auto replacedRoot = [&] {
2867 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2869 auto updatedRootInPlace = [&] {
2870 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2872 if (!replacedRoot() && !updatedRootInPlace())
2873 llvm::report_fatal_error(
"expected pattern to replace the root operation "
2874 "or modify it in place");
2879 if (
failed(legalizePatternRootUpdates(modifiedOps)) ||
2880 failed(legalizePatternCreatedOperations(newOps))) {
2884 LLVM_DEBUG(
logSuccess(impl.logger,
"pattern applied successfully"));
2888LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2890 for (Operation *op : newOps) {
2891 if (
failed(legalize(op))) {
2892 LLVM_DEBUG(
logFailure(rewriter.getImpl().logger,
2893 "failed to legalize generated operation '{0}'({1})",
2901LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2903 for (Operation *op : modifiedOps) {
2904 if (
failed(legalize(op))) {
2907 "failed to legalize operation updated in-place '{0}'",
2919void OperationLegalizer::buildLegalizationGraph(
2920 LegalizationPatterns &anyOpLegalizerPatterns,
2931 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2932 std::optional<OperationName> root = pattern.
getRootKind();
2938 anyOpLegalizerPatterns.push_back(&pattern);
2943 if (
target.getOpAction(*root) == LegalizationAction::Legal)
2948 invalidPatterns[*root].insert(&pattern);
2950 parentOps[op].insert(*root);
2953 patternWorklist.insert(&pattern);
2961 if (!anyOpLegalizerPatterns.empty()) {
2962 for (
const Pattern *pattern : patternWorklist)
2963 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2967 while (!patternWorklist.empty()) {
2968 auto *pattern = patternWorklist.pop_back_val();
2972 std::optional<LegalizationAction> action = target.getOpAction(op);
2973 return !legalizerPatterns.count(op) &&
2974 (!action || action == LegalizationAction::Illegal);
2980 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2981 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
2985 for (
auto op : parentOps[*pattern->
getRootKind()])
2986 patternWorklist.set_union(invalidPatterns[op]);
2990void OperationLegalizer::computeLegalizationGraphBenefit(
2991 LegalizationPatterns &anyOpLegalizerPatterns,
2997 for (
auto &opIt : legalizerPatterns)
2998 if (!minOpPatternDepth.count(opIt.first))
2999 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
3005 if (!anyOpLegalizerPatterns.empty())
3006 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
3012 applicator.applyCostModel([&](
const Pattern &pattern) {
3013 ArrayRef<const Pattern *> orderedPatternList;
3014 if (std::optional<OperationName> rootName = pattern.
getRootKind())
3015 orderedPatternList = legalizerPatterns[*rootName];
3017 orderedPatternList = anyOpLegalizerPatterns;
3020 auto *it = llvm::find(orderedPatternList, &pattern);
3021 if (it == orderedPatternList.end())
3025 return PatternBenefit(std::distance(it, orderedPatternList.end()));
3029unsigned OperationLegalizer::computeOpLegalizationDepth(
3033 auto depthIt = minOpPatternDepth.find(op);
3034 if (depthIt != minOpPatternDepth.end())
3035 return depthIt->second;
3039 auto opPatternsIt = legalizerPatterns.find(op);
3040 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3045 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
3049 unsigned minDepth = applyCostModelToPatterns(
3050 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3051 minOpPatternDepth[op] = minDepth;
3055unsigned OperationLegalizer::applyCostModelToPatterns(
3059 unsigned minDepth = std::numeric_limits<unsigned>::max();
3062 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3063 patternsByDepth.reserve(
patterns.size());
3064 for (
const Pattern *pattern :
patterns) {
3067 unsigned generatedOpDepth = computeOpLegalizationDepth(
3068 generatedOp, minOpPatternDepth, legalizerPatterns);
3069 depth = std::max(depth, generatedOpDepth + 1);
3071 patternsByDepth.emplace_back(pattern, depth);
3074 minDepth = std::min(minDepth, depth);
3079 if (patternsByDepth.size() == 1)
3083 llvm::stable_sort(patternsByDepth,
3084 [](
const std::pair<const Pattern *, unsigned> &
lhs,
3085 const std::pair<const Pattern *, unsigned> &
rhs) {
3088 if (
lhs.second !=
rhs.second)
3089 return lhs.second <
rhs.second;
3092 auto lhsBenefit =
lhs.first->getBenefit();
3093 auto rhsBenefit =
rhs.first->getBenefit();
3094 return lhsBenefit > rhsBenefit;
3099 for (
auto &patternIt : patternsByDepth)
3100 patterns.push_back(patternIt.first);
3114template <
typename RangeT>
3117 function_ref<
bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3126 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3127 if (castOp.getInputs().empty())
3130 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3133 if (inputCastOp.getOutputs() != castOp.getInputs())
3139 while (!worklist.empty()) {
3140 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3144 UnrealizedConversionCastOp nextCast = castOp;
3146 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3147 if (llvm::any_of(nextCast.getInputs(), [&](
Value v) {
3148 return v.getDefiningOp() == castOp;
3156 castOp.replaceAllUsesWith(nextCast.getInputs());
3159 nextCast = getInputCast(nextCast);
3169 auto markOpLive = [&](
Operation *rootOp) {
3171 worklist.push_back(rootOp);
3172 while (!worklist.empty()) {
3173 Operation *op = worklist.pop_back_val();
3174 if (liveOps.insert(op).second) {
3177 if (
auto castOp = v.
getDefiningOp<UnrealizedConversionCastOp>())
3178 if (isCastOpOfInterestFn(castOp))
3179 worklist.push_back(castOp);
3185 for (UnrealizedConversionCastOp op : castOps) {
3188 if (liveOps.contains(op.getOperation()))
3192 if (llvm::any_of(op->getUsers(), [&](
Operation *user) {
3193 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3194 return !castOp || !isCastOpOfInterestFn(castOp);
3200 for (UnrealizedConversionCastOp op : castOps) {
3201 if (liveOps.contains(op)) {
3203 if (remainingCastOps)
3204 remainingCastOps->push_back(op);
3215 ArrayRef<UnrealizedConversionCastOp> castOps,
3216 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3218 DenseSet<UnrealizedConversionCastOp> castOpSet;
3219 for (UnrealizedConversionCastOp op : castOps)
3220 castOpSet.insert(op);
3225 const DenseSet<UnrealizedConversionCastOp> &castOps,
3226 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3228 llvm::make_range(castOps.begin(), castOps.end()),
3229 [&](UnrealizedConversionCastOp castOp) {
3230 return castOps.contains(castOp);
3242 [&](UnrealizedConversionCastOp castOp) {
3243 return castOps.contains(castOp);
3260 const ConversionConfig &
config,
3261 OpConversionMode mode)
3271 template <
typename Fn>
3273 bool isRecursiveLegalization =
false);
3275 bool isRecursiveLegalization =
false) {
3277 ops, [&]() {}, isRecursiveLegalization);
3285 LogicalResult convert(
Operation *op,
bool isRecursiveLegalization =
false);
3291 ConversionPatternRewriter rewriter;
3294 OperationLegalizer opLegalizer;
3297 OpConversionMode mode;
3302 bool isRecursiveLegalization) {
3303 const ConversionConfig &
config = rewriter.getConfig();
3306 if (failed(opLegalizer.legalize(op))) {
3309 if (mode == OpConversionMode::Full) {
3310 if (!isRecursiveLegalization)
3318 if (mode == OpConversionMode::Partial) {
3319 if (opLegalizer.isIllegal(op)) {
3320 if (!isRecursiveLegalization)
3322 <<
"' that was explicitly marked illegal";
3325 if (
config.unlegalizedOps && !isRecursiveLegalization)
3326 config.unlegalizedOps->insert(op);
3328 }
else if (mode == OpConversionMode::Analysis) {
3332 if (
config.legalizableOps && !isRecursiveLegalization)
3333 config.legalizableOps->insert(op);
3340 UnrealizedConversionCastOp op,
3341 const UnresolvedMaterializationInfo &info) {
3342 assert(!op.use_empty() &&
3343 "expected that dead materializations have already been DCE'd");
3350 switch (info.getMaterializationKind()) {
3351 case MaterializationKind::Target:
3352 newMaterialization = converter->materializeTargetConversion(
3353 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3354 info.getOriginalType());
3356 case MaterializationKind::Source:
3357 assert(op->getNumResults() == 1 &&
"expected single result");
3358 Value sourceMat = converter->materializeSourceConversion(
3359 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3361 newMaterialization.push_back(sourceMat);
3364 if (!newMaterialization.empty()) {
3366 ValueRange newMaterializationRange(newMaterialization);
3367 assert(
TypeRange(newMaterializationRange) == op.getResultTypes() &&
3368 "materialization callback produced value of incorrect type");
3370 rewriter.
replaceOp(op, newMaterialization);
3376 <<
"failed to legalize unresolved materialization "
3378 << inputOperands.
getTypes() <<
") to ("
3379 << op.getResultTypes()
3380 <<
") that remained live after conversion";
3381 diag.attachNote(op->getUsers().begin()->getLoc())
3382 <<
"see existing live user here: " << *op->getUsers().begin();
3386template <
typename Fn>
3389 bool isRecursiveLegalization) {
3397 toConvert.push_back(op);
3400 auto legalityInfo =
target.isLegal(op);
3401 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3407 if (failed(
convert(op, isRecursiveLegalization))) {
3416LogicalResult ConversionPatternRewriter::legalize(
Operation *op) {
3417 return impl->opConverter.legalizeOperations(op,
3421LogicalResult ConversionPatternRewriter::legalize(
Region *r) {
3437 std::optional<TypeConverter::SignatureConversion> conversion =
3438 converter->convertBlockSignature(&r->front());
3441 applySignatureConversion(&r->front(), *conversion, converter);
3446 return impl->opConverter.legalizeOperations(ops,
3455 if (rewriterImpl.
config.allowPatternRollback) {
3479 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3483 if (rewriter.getConfig().buildMaterializations) {
3487 rewriter.getConfig().listener);
3488 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3489 auto it = materializations.find(castOp);
3490 assert(it != materializations.end() &&
"inconsistent state");
3504void TypeConverter::SignatureConversion::addInputs(
unsigned origInputNo,
3506 assert(!types.empty() &&
"expected valid types");
3507 remapInput(origInputNo, argTypes.size(), types.size());
3511void TypeConverter::SignatureConversion::addInputs(
ArrayRef<Type> types) {
3512 assert(!types.empty() &&
3513 "1->0 type remappings don't need to be added explicitly");
3514 argTypes.append(types.begin(), types.end());
3517void TypeConverter::SignatureConversion::remapInput(
unsigned origInputNo,
3518 unsigned newInputNo,
3519 unsigned newInputCount) {
3520 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3521 assert(newInputCount != 0 &&
"expected valid input count");
3522 remappedInputs[origInputNo] =
3523 InputMapping{newInputNo, newInputCount, {}};
3526void TypeConverter::SignatureConversion::remapInput(
3527 unsigned origInputNo, ArrayRef<Value> replacements) {
3528 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3529 remappedInputs[origInputNo] = InputMapping{
3531 SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3542TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3543 SmallVectorImpl<Type> &results)
const {
3544 assert(typeOrValue &&
"expected non-null type");
3545 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3546 : cast<Type>(typeOrValue);
3548 std::shared_lock<
decltype(cacheMutex)> cacheReadLock(cacheMutex,
3551 cacheReadLock.lock();
3552 auto existingIt = cachedDirectConversions.find(t);
3553 if (existingIt != cachedDirectConversions.end()) {
3554 if (existingIt->second)
3555 results.push_back(existingIt->second);
3556 return success(existingIt->second !=
nullptr);
3558 auto multiIt = cachedMultiConversions.find(t);
3559 if (multiIt != cachedMultiConversions.end()) {
3560 results.append(multiIt->second.begin(), multiIt->second.end());
3566 size_t currentCount = results.size();
3570 auto isCacheable = [&](
int index) {
3571 int numberOfConversionsUntilContextAware =
3572 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3573 return index < numberOfConversionsUntilContextAware;
3576 std::unique_lock<
decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3579 for (
auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3580 const ConversionCallbackFn &converter = indexedConverter.value();
3581 std::optional<LogicalResult>
result = converter(typeOrValue, results);
3583 assert(results.size() == currentCount &&
3584 "failed type conversion should not change results");
3587 if (!isCacheable(indexedConverter.index()))
3590 cacheWriteLock.lock();
3591 if (!succeeded(*
result)) {
3592 assert(results.size() == currentCount &&
3593 "failed type conversion should not change results");
3594 cachedDirectConversions.try_emplace(t,
nullptr);
3597 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3598 if (newTypes.size() == 1)
3599 cachedDirectConversions.try_emplace(t, newTypes.front());
3601 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3607LogicalResult TypeConverter::convertType(Type t,
3608 SmallVectorImpl<Type> &results)
const {
3609 return convertTypeImpl(t, results);
3612LogicalResult TypeConverter::convertType(Value v,
3613 SmallVectorImpl<Type> &results)
const {
3614 return convertTypeImpl(v, results);
3617Type TypeConverter::convertType(Type t)
const {
3619 SmallVector<Type, 1> results;
3620 if (
failed(convertType(t, results)))
3624 return results.size() == 1 ? results.front() :
nullptr;
3627Type TypeConverter::convertType(Value v)
const {
3629 SmallVector<Type, 1> results;
3630 if (
failed(convertType(v, results)))
3634 return results.size() == 1 ? results.front() :
nullptr;
3638TypeConverter::convertTypes(
TypeRange types,
3639 SmallVectorImpl<Type> &results)
const {
3640 for (Type type : types)
3641 if (
failed(convertType(type, results)))
3647TypeConverter::convertTypes(
ValueRange values,
3648 SmallVectorImpl<Type> &results)
const {
3649 for (Value value : values)
3650 if (
failed(convertType(value, results)))
3655bool TypeConverter::isLegal(Type type)
const {
3656 return convertType(type) == type;
3659bool TypeConverter::isLegal(Value value)
const {
3660 return convertType(value) == value.
getType();
3663bool TypeConverter::isLegal(Operation *op)
const {
3667bool TypeConverter::isLegal(Region *region)
const {
3668 return llvm::all_of(
3672bool TypeConverter::isSignatureLegal(FunctionType ty)
const {
3673 if (!isLegal(ty.getInputs()))
3675 if (!isLegal(ty.getResults()))
3681TypeConverter::convertSignatureArg(
unsigned inputNo, Type type,
3682 SignatureConversion &
result)
const {
3684 SmallVector<Type, 1> convertedTypes;
3685 if (
failed(convertType(type, convertedTypes)))
3689 if (convertedTypes.empty())
3693 result.addInputs(inputNo, convertedTypes);
3697TypeConverter::convertSignatureArgs(
TypeRange types,
3698 SignatureConversion &
result,
3699 unsigned origInputOffset)
const {
3700 for (
unsigned i = 0, e = types.size(); i != e; ++i)
3701 if (
failed(convertSignatureArg(origInputOffset + i, types[i],
result)))
3706TypeConverter::convertSignatureArg(
unsigned inputNo, Value value,
3707 SignatureConversion &
result)
const {
3709 SmallVector<Type, 1> convertedTypes;
3710 if (
failed(convertType(value, convertedTypes)))
3714 if (convertedTypes.empty())
3718 result.addInputs(inputNo, convertedTypes);
3722TypeConverter::convertSignatureArgs(
ValueRange values,
3723 SignatureConversion &
result,
3724 unsigned origInputOffset)
const {
3725 for (
unsigned i = 0, e = values.size(); i != e; ++i)
3726 if (
failed(convertSignatureArg(origInputOffset + i, values[i],
result)))
3731Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
3732 Location loc, Type resultType,
3734 for (
const SourceMaterializationCallbackFn &fn :
3735 llvm::reverse(sourceMaterializations))
3736 if (Value
result = fn(builder, resultType, inputs, loc))
3741Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
3742 Location loc, Type resultType,
3744 Type originalType)
const {
3745 SmallVector<Value>
result = materializeTargetConversion(
3746 builder, loc,
TypeRange(resultType), inputs, originalType);
3749 assert(
result.size() == 1 &&
"expected single result");
3753SmallVector<Value> TypeConverter::materializeTargetConversion(
3755 Type originalType)
const {
3756 for (
const TargetMaterializationCallbackFn &fn :
3757 llvm::reverse(targetMaterializations)) {
3758 SmallVector<Value>
result =
3759 fn(builder, resultTypes, inputs, loc, originalType);
3763 "callback produced incorrect number of values or values with "
3770std::optional<TypeConverter::SignatureConversion>
3771TypeConverter::convertBlockSignature(
Block *block)
const {
3774 return std::nullopt;
3781TypeConverter::AttributeConversionResult
3782TypeConverter::AttributeConversionResult::result(Attribute attr) {
3783 return AttributeConversionResult(attr, resultTag);
3786TypeConverter::AttributeConversionResult
3787TypeConverter::AttributeConversionResult::na() {
3788 return AttributeConversionResult(
nullptr, naTag);
3791TypeConverter::AttributeConversionResult
3792TypeConverter::AttributeConversionResult::abort() {
3793 return AttributeConversionResult(
nullptr, abortTag);
3796bool TypeConverter::AttributeConversionResult::hasResult()
const {
3797 return impl.getInt() == resultTag;
3800bool TypeConverter::AttributeConversionResult::isNa()
const {
3801 return impl.getInt() == naTag;
3804bool TypeConverter::AttributeConversionResult::isAbort()
const {
3805 return impl.getInt() == abortTag;
3808Attribute TypeConverter::AttributeConversionResult::getResult()
const {
3809 assert(hasResult() &&
"Cannot get result from N/A or abort");
3810 return impl.getPointer();
3813std::optional<Attribute>
3814TypeConverter::convertTypeAttribute(Type type, Attribute attr)
const {
3815 for (
const TypeAttributeConversionCallbackFn &fn :
3816 llvm::reverse(typeAttributeConversions)) {
3817 AttributeConversionResult res = fn(type, attr);
3818 if (res.hasResult())
3819 return res.getResult();
3821 return std::nullopt;
3823 return std::nullopt;
3832 ConversionPatternRewriter &rewriter) {
3833 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3838 TypeConverter::SignatureConversion
result(type.getNumInputs());
3840 if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
result)) ||
3841 failed(typeConverter.convertTypes(type.getResults(), newResults)))
3843 if (!funcOp.getFunctionBody().empty())
3844 rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(),
result,
3848 auto newType = FunctionType::get(rewriter.getContext(),
3849 result.getConvertedTypes(), newResults);
3851 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3860struct FunctionOpInterfaceSignatureConversion :
public ConversionPattern {
3861 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3863 const TypeConverter &converter,
3864 PatternBenefit benefit)
3865 : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3868 matchAndRewrite(Operation *op, ArrayRef<Value> ,
3869 ConversionPatternRewriter &rewriter)
const override {
3870 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3875struct AnyFunctionOpInterfaceSignatureConversion
3876 :
public OpInterfaceConversionPattern<FunctionOpInterface> {
3877 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3880 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> ,
3881 ConversionPatternRewriter &rewriter)
const override {
3887FailureOr<Operation *>
3888mlir::convertOpResultTypes(Operation *op,
ValueRange operands,
3889 const TypeConverter &converter,
3890 ConversionPatternRewriter &rewriter) {
3891 assert(op &&
"Invalid op");
3892 Location loc = op->
getLoc();
3893 if (converter.isLegal(op))
3894 return rewriter.notifyMatchFailure(loc,
"op already legal");
3896 OperationState newOp(loc, op->
getName());
3897 newOp.addOperands(operands);
3899 SmallVector<Type> newResultTypes;
3901 return rewriter.notifyMatchFailure(loc,
"couldn't convert return types");
3903 newOp.addTypes(newResultTypes);
3904 newOp.addAttributes(op->
getAttrs());
3905 return rewriter.create(newOp);
3908void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3909 StringRef functionLikeOpName, RewritePatternSet &
patterns,
3910 const TypeConverter &converter, PatternBenefit benefit) {
3911 patterns.add<FunctionOpInterfaceSignatureConversion>(
3912 functionLikeOpName,
patterns.getContext(), converter, benefit);
3915void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3916 RewritePatternSet &
patterns,
const TypeConverter &converter,
3917 PatternBenefit benefit) {
3918 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3919 converter,
patterns.getContext(), benefit);
3926void ConversionTarget::setOpAction(OperationName op,
3927 LegalizationAction action) {
3928 legalOperations[op].action = action;
3931void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3932 LegalizationAction action) {
3933 for (StringRef dialect : dialectNames)
3934 legalDialects[dialect] = action;
3937auto ConversionTarget::getOpAction(OperationName op)
const
3938 -> std::optional<LegalizationAction> {
3939 std::optional<LegalizationInfo> info = getOpInfo(op);
3940 return info ? info->action : std::optional<LegalizationAction>();
3943auto ConversionTarget::isLegal(Operation *op)
const
3944 -> std::optional<LegalOpDetails> {
3945 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3947 return std::nullopt;
3950 auto isOpLegal = [&] {
3952 if (info->action == LegalizationAction::Dynamic) {
3953 std::optional<bool>
result = info->legalityFn(op);
3959 return info->action == LegalizationAction::Legal;
3962 return std::nullopt;
3965 LegalOpDetails legalityDetails;
3966 if (info->isRecursivelyLegal) {
3967 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
3968 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3969 legalityDetails.isRecursivelyLegal =
3970 legalityFnIt->second(op).value_or(
true);
3972 legalityDetails.isRecursivelyLegal =
true;
3975 return legalityDetails;
3978bool ConversionTarget::isIllegal(Operation *op)
const {
3979 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3983 if (info->action == LegalizationAction::Dynamic) {
3984 std::optional<bool>
result = info->legalityFn(op);
3991 return info->action == LegalizationAction::Illegal;
3995 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
3996 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
4000 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
4002 if (std::optional<bool>
result = newCl(op))
4010void ConversionTarget::setLegalityCallback(
4011 OperationName name,
const DynamicLegalityCallbackFn &callback) {
4012 assert(callback &&
"expected valid legality callback");
4013 auto *infoIt = legalOperations.find(name);
4014 assert(infoIt != legalOperations.end() &&
4015 infoIt->second.action == LegalizationAction::Dynamic &&
4016 "expected operation to already be marked as dynamically legal");
4017 infoIt->second.legalityFn =
4021void ConversionTarget::markOpRecursivelyLegal(
4022 OperationName name,
const DynamicLegalityCallbackFn &callback) {
4023 auto *infoIt = legalOperations.find(name);
4024 assert(infoIt != legalOperations.end() &&
4025 infoIt->second.action != LegalizationAction::Illegal &&
4026 "expected operation to already be marked as legal");
4027 infoIt->second.isRecursivelyLegal =
true;
4030 std::move(opRecursiveLegalityFns[name]), callback);
4032 opRecursiveLegalityFns.erase(name);
4035void ConversionTarget::setLegalityCallback(
4036 ArrayRef<StringRef> dialects,
const DynamicLegalityCallbackFn &callback) {
4037 assert(callback &&
"expected valid legality callback");
4038 for (StringRef dialect : dialects)
4040 std::move(dialectLegalityFns[dialect]), callback);
4043void ConversionTarget::setLegalityCallback(
4044 const DynamicLegalityCallbackFn &callback) {
4045 assert(callback &&
"expected valid legality callback");
4049auto ConversionTarget::getOpInfo(OperationName op)
const
4050 -> std::optional<LegalizationInfo> {
4052 const auto *it = legalOperations.find(op);
4053 if (it != legalOperations.end())
4056 auto dialectIt = legalDialects.find(op.getDialectNamespace());
4057 if (dialectIt != legalDialects.end()) {
4058 DynamicLegalityCallbackFn callback;
4059 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4060 if (dialectFn != dialectLegalityFns.end())
4061 callback = dialectFn->second;
4062 return LegalizationInfo{dialectIt->second,
false,
4066 if (unknownLegalityFn)
4067 return LegalizationInfo{LegalizationAction::Dynamic,
4068 false, unknownLegalityFn};
4069 return std::nullopt;
4072#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4077void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
4078 auto &rewriterImpl =
4079 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4083void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
4084 auto &rewriterImpl =
4085 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4091static FailureOr<SmallVector<Value>>
4092pdllConvertValues(ConversionPatternRewriter &rewriter,
ValueRange values) {
4093 SmallVector<Value> mappedValues;
4094 if (
failed(rewriter.getRemappedValues(values, mappedValues)))
4096 return std::move(mappedValues);
4099void mlir::registerConversionPDLFunctions(RewritePatternSet &
patterns) {
4100 patterns.getPDLPatterns().registerRewriteFunction(
4102 [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
4103 auto results = pdllConvertValues(
4104 static_cast<ConversionPatternRewriter &
>(rewriter), value);
4107 return results->front();
4109 patterns.getPDLPatterns().registerRewriteFunction(
4110 "convertValues", [](PatternRewriter &rewriter,
ValueRange values) {
4111 return pdllConvertValues(
4112 static_cast<ConversionPatternRewriter &
>(rewriter), values);
4114 patterns.getPDLPatterns().registerRewriteFunction(
4116 [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
4117 auto &rewriterImpl =
4118 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4119 if (
const TypeConverter *converter =
4121 if (Type newType = converter->convertType(type))
4127 patterns.getPDLPatterns().registerRewriteFunction(
4129 [](PatternRewriter &rewriter,
4130 TypeRange types) -> FailureOr<SmallVector<Type>> {
4131 auto &rewriterImpl =
4132 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4135 return SmallVector<Type>(types);
4137 SmallVector<Type> remappedTypes;
4138 if (
failed(converter->convertTypes(types, remappedTypes)))
4140 return std::move(remappedTypes);
4155 static constexpr StringLiteral
tag =
"apply-conversion";
4156 static constexpr StringLiteral
desc =
4157 "Encapsulate the application of a dialect conversion";
4166 OpConversionMode mode) {
4170 LogicalResult status =
success();
4186LogicalResult mlir::applyPartialConversion(
4187 ArrayRef<Operation *> ops,
const ConversionTarget &
target,
4188 const FrozenRewritePatternSet &
patterns, ConversionConfig
config) {
4190 OpConversionMode::Partial);
4193mlir::applyPartialConversion(Operation *op,
const ConversionTarget &
target,
4194 const FrozenRewritePatternSet &
patterns,
4195 ConversionConfig
config) {
4203LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
4204 const ConversionTarget &
target,
4205 const FrozenRewritePatternSet &
patterns,
4206 ConversionConfig
config) {
4209LogicalResult mlir::applyFullConversion(Operation *op,
4210 const ConversionTarget &
target,
4211 const FrozenRewritePatternSet &
patterns,
4212 ConversionConfig
config) {
4230 "expected top-level op to be isolated from above");
4233 "expected ops to have a common ancestor");
4242 for (
Operation *op : ops.drop_front()) {
4246 assert(commonAncestor &&
4247 "expected to find a common isolated from above ancestor");
4251 return commonAncestor;
4254LogicalResult mlir::applyAnalysisConversion(
4255 ArrayRef<Operation *> ops, ConversionTarget &
target,
4256 const FrozenRewritePatternSet &
patterns, ConversionConfig
config) {
4258 if (
config.legalizableOps)
4259 assert(
config.legalizableOps->empty() &&
"expected empty set");
4265 Operation *clonedAncestor = commonAncestor->
clone(mapping);
4269 inverseOperationMap[it.second] = it.first;
4272 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4273 ops, [&](Operation *op) {
return mapping.
lookup(op); });
4275 OpConversionMode::Analysis);
4279 if (
config.legalizableOps) {
4281 for (Operation *op : *
config.legalizableOps)
4282 originalLegalizableOps.insert(inverseOperationMap[op]);
4283 *
config.legalizableOps = std::move(originalLegalizableOps);
4287 clonedAncestor->
erase();
4292mlir::applyAnalysisConversion(Operation *op, ConversionTarget &
target,
4293 const FrozenRewritePatternSet &
patterns,
4294 ConversionConfig
config) {
static void setInsertionPointAfter(OpBuilder &b, Value value)
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl, function_ref< bool(OpOperand &)> functor=nullptr)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value > > &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
static std::string diag(const llvm::Value &value)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
This is the type of Action that is dispatched when a conversion is applied.
tracing::ActionImpl< ApplyConversionAction > Base
static constexpr StringLiteral desc
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
static constexpr StringLiteral tag
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
StringAttr getStringAttr(const Twine &bytes)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
auto lookup(T from) const
Lookup a mapped value within the map.
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Block::iterator getPoint() const
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
void destroyOpProperties(OpaqueProperties properties) const
This hooks destroy the op properties.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
CRTP Implementation of an action.
ArrayRef< IRUnit > irUnits
Set of IR units (operations, regions, blocks, values) that are associated with this action.
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
static void reconcileUnrealizedCasts(const DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
llvm::SetVector< T, Vector, Set, N > SetVector
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
This iterator enumerates elements according to their dominance relationship.
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
const ConversionTarget & getTarget()
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, bool isRecursiveLegalization=false)
LogicalResult convert(Operation *op, bool isRecursiveLegalization=false)
Converts a single operation.
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, Fn onFailure, bool isRecursiveLegalization=false)
Legalizes the given operations (and their nested operations) to the conversion target.
LogicalResult applyConversion(ArrayRef< Operation * > ops)
Applies the conversion to the given operations (and their nested operations).
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
bool wasErased(void *ptr) const
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
llvm::impl::raw_ldbg_ostream os
A raw output stream used to prefix the debug log.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
void replaceValueUses(Value from, ValueRange to, const TypeConverter *converter, function_ref< bool(OpOperand &)> functor=nullptr)
Replace the uses of the given value with the given values.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config, OperationConverter &opConverter)
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
ConversionValueMapping mapping
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
void replaceOp(Operation *op, SmallVector< SmallVector< Value > > &&newValues)
Replace the results of the given operation with the given values and erase the operation.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
OperationConverter & opConverter
The operation converter to use for recursive legalization.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.