10#include "mlir/Config/mlir-config.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/FormatVariadic.h"
25#include "llvm/Support/SaveAndRestore.h"
26#include "llvm/Support/ScopedPrinter.h"
33#define DEBUG_TYPE "dialect-conversion"
36template <
typename... Args>
37static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
40 os.startLine() <<
"} -> SUCCESS";
42 os.getOStream() <<
" : "
43 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
44 os.getOStream() <<
"\n";
49template <
typename... Args>
50static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
53 os.startLine() <<
"} -> FAILURE : "
54 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
64 if (
OpResult inputRes = dyn_cast<OpResult>(value))
65 insertPt = ++inputRes.getOwner()->getIterator();
72 assert(!vals.empty() &&
"expected at least one value");
75 for (
Value v : vals.drop_front()) {
89 assert(dom &&
"unable to find valid insertion point");
97enum OpConversionMode {
123struct ValueVectorMapInfo {
126 static ::llvm::hash_code getHashValue(
const ValueVector &val) {
127 return ::llvm::hash_combine_range(val);
136struct ConversionValueMapping {
139 bool isMappedTo(
Value value)
const {
return mappedTo.contains(value); }
144 template <
typename T>
145 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
148 template <
typename OldVal,
typename NewVal>
149 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
150 map(OldVal &&oldVal, NewVal &&newVal) {
154 assert(next != oldVal &&
"inserting cyclic mapping");
155 auto it = mapping.find(next);
156 if (it == mapping.end())
161 mappedTo.insert_range(newVal);
163 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
167 template <
typename OldVal,
typename NewVal>
168 std::enable_if_t<!IsValueVector<OldVal>::value ||
169 !IsValueVector<NewVal>::value>
170 map(OldVal &&oldVal, NewVal &&newVal) {
171 if constexpr (IsValueVector<OldVal>{}) {
172 map(std::forward<OldVal>(oldVal),
ValueVector{newVal});
173 }
else if constexpr (IsValueVector<NewVal>{}) {
174 map(
ValueVector{oldVal}, std::forward<NewVal>(newVal));
185 void erase(
const ValueVector &value) { mapping.erase(value); }
205 assert(!values.empty() &&
"expected non-empty value vector");
206 Operation *op = values.front().getDefiningOp();
207 for (
Value v : llvm::drop_begin(values)) {
208 if (v.getDefiningOp() != op)
218 assert(!values.empty() &&
"expected non-empty value vector");
224 auto it = mapping.find(from);
225 if (it == mapping.end()) {
238struct RewriterState {
239 RewriterState(
unsigned numRewrites,
unsigned numIgnoredOperations,
240 unsigned numReplacedOps)
241 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
242 numReplacedOps(numReplacedOps) {}
245 unsigned numRewrites;
248 unsigned numIgnoredOperations;
251 unsigned numReplacedOps;
// Forward declaration of the Operation overload. The Block overload defined
// next calls it for every operation in the block, and the Operation overload
// in turn calls notifyIRErased on nested blocks (presumably the Block
// overload), so the two functions are mutually recursive.
258static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
261static void notifyIRErased(RewriterBase::Listener *listener,
Block &
b) {
262 for (Operation &op :
b)
263 notifyIRErased(listener, op);
269static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
272 notifyIRErased(listener,
b);
302 UnresolvedMaterialization,
307 virtual ~IRRewrite() =
default;
310 virtual void rollback() = 0;
324 virtual void commit(RewriterBase &rewriter) {}
327 virtual void cleanup(RewriterBase &rewriter) {}
329 Kind getKind()
const {
return kind; }
331 static bool classof(
const IRRewrite *
rewrite) {
return true; }
334 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
335 : kind(kind), rewriterImpl(rewriterImpl) {}
337 const ConversionConfig &getConfig()
const;
340 ConversionPatternRewriterImpl &rewriterImpl;
344class BlockRewrite :
public IRRewrite {
347 Block *getBlock()
const {
return block; }
349 static bool classof(
const IRRewrite *
rewrite) {
350 return rewrite->getKind() >= Kind::CreateBlock &&
351 rewrite->getKind() <= Kind::BlockTypeConversion;
355 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
357 : IRRewrite(kind, rewriterImpl), block(block) {}
364class ValueRewrite :
public IRRewrite {
367 Value getValue()
const {
return value; }
369 static bool classof(
const IRRewrite *
rewrite) {
370 return rewrite->getKind() == Kind::ReplaceValue;
374 ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
376 : IRRewrite(kind, rewriterImpl), value(value) {}
385class CreateBlockRewrite :
public BlockRewrite {
387 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block)
388 : BlockRewrite(
Kind::CreateBlock, rewriterImpl, block) {}
390 static bool classof(
const IRRewrite *
rewrite) {
391 return rewrite->getKind() == Kind::CreateBlock;
394 void commit(RewriterBase &rewriter)
override {
400 void rollback()
override {
403 auto &blockOps = block->getOperations();
404 while (!blockOps.empty())
405 blockOps.remove(blockOps.begin());
406 block->dropAllUses();
407 if (block->getParent())
418class EraseBlockRewrite :
public BlockRewrite {
420 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block)
421 : BlockRewrite(
Kind::EraseBlock, rewriterImpl, block),
422 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
424 static bool classof(
const IRRewrite *
rewrite) {
425 return rewrite->getKind() == Kind::EraseBlock;
428 ~EraseBlockRewrite()
override {
430 "rewrite was neither rolled back nor committed/cleaned up");
433 void rollback()
override {
436 assert(block &&
"expected block");
441 blockList.insert(before, block);
445 void commit(RewriterBase &rewriter)
override {
446 assert(block &&
"expected block");
450 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
451 notifyIRErased(listener, *block);
454 void cleanup(RewriterBase &rewriter)
override {
456 for (
auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
458 assert(block->empty() &&
"expected empty block");
461 block->dropAllDefinedValueUses();
472 Block *insertBeforeBlock;
478class InlineBlockRewrite :
public BlockRewrite {
480 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block,
482 : BlockRewrite(
Kind::InlineBlock, rewriterImpl, block),
483 sourceBlock(sourceBlock),
484 firstInlinedInst(sourceBlock->empty() ?
nullptr
485 : &sourceBlock->front()),
486 lastInlinedInst(sourceBlock->empty() ?
nullptr : &sourceBlock->back()) {
492 assert(!getConfig().listener &&
493 "InlineBlockRewrite not supported if listener is attached");
496 static bool classof(
const IRRewrite *
rewrite) {
497 return rewrite->getKind() == Kind::InlineBlock;
500 void rollback()
override {
503 if (firstInlinedInst) {
504 assert(lastInlinedInst &&
"expected operation");
517 Operation *firstInlinedInst;
520 Operation *lastInlinedInst;
524class MoveBlockRewrite :
public BlockRewrite {
526 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block,
528 : BlockRewrite(
Kind::MoveBlock, rewriterImpl, block),
529 region(previousRegion),
530 insertBeforeBlock(previousIt == previousRegion->end() ?
nullptr
533 static bool classof(
const IRRewrite *
rewrite) {
534 return rewrite->getKind() == Kind::MoveBlock;
537 void commit(RewriterBase &rewriter)
override {
547 void rollback()
override {
560 Block *insertBeforeBlock;
564class BlockTypeConversionRewrite :
public BlockRewrite {
566 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
568 : BlockRewrite(
Kind::BlockTypeConversion, rewriterImpl, origBlock),
569 newBlock(newBlock) {}
571 static bool classof(
const IRRewrite *
rewrite) {
572 return rewrite->getKind() == Kind::BlockTypeConversion;
575 Block *getOrigBlock()
const {
return block; }
577 Block *getNewBlock()
const {
return newBlock; }
579 void commit(RewriterBase &rewriter)
override;
581 void rollback()
override;
591class ReplaceValueRewrite :
public ValueRewrite {
593 ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
594 const TypeConverter *converter)
595 : ValueRewrite(
Kind::ReplaceValue, rewriterImpl, value),
596 converter(converter) {}
598 static bool classof(
const IRRewrite *
rewrite) {
599 return rewrite->getKind() == Kind::ReplaceValue;
602 void commit(RewriterBase &rewriter)
override;
604 void rollback()
override;
608 const TypeConverter *converter;
612class OperationRewrite :
public IRRewrite {
615 Operation *getOperation()
const {
return op; }
617 static bool classof(
const IRRewrite *
rewrite) {
618 return rewrite->getKind() >= Kind::MoveOperation &&
619 rewrite->getKind() <= Kind::UnresolvedMaterialization;
623 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
625 : IRRewrite(kind, rewriterImpl), op(op) {}
632class MoveOperationRewrite :
public OperationRewrite {
634 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
635 Operation *op, OpBuilder::InsertPoint previous)
636 : OperationRewrite(
Kind::MoveOperation, rewriterImpl, op),
637 block(previous.getBlock()),
638 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
640 : &*previous.getPoint()) {}
642 static bool classof(
const IRRewrite *
rewrite) {
643 return rewrite->getKind() == Kind::MoveOperation;
646 void commit(RewriterBase &rewriter)
override {
652 op, OpBuilder::InsertPoint(block,
657 void rollback()
override {
670 Operation *insertBeforeOp;
675class ModifyOperationRewrite :
public OperationRewrite {
677 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
679 : OperationRewrite(
Kind::ModifyOperation, rewriterImpl, op),
680 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
681 operands(op->operand_begin(), op->operand_end()),
682 successors(op->successor_begin(), op->successor_end()) {
685 propertiesStorage = operator new(op->getPropertiesStorageSize());
686 OpaqueProperties propCopy(propertiesStorage);
687 name.initOpProperties(propCopy, prop);
691 static bool classof(
const IRRewrite *
rewrite) {
692 return rewrite->getKind() == Kind::ModifyOperation;
695 ~ModifyOperationRewrite()
override {
696 assert(!propertiesStorage &&
697 "rewrite was neither committed nor rolled back");
700 void commit(RewriterBase &rewriter)
override {
703 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
706 if (propertiesStorage) {
707 OpaqueProperties propCopy(propertiesStorage);
711 operator delete(propertiesStorage);
712 propertiesStorage =
nullptr;
716 void rollback()
override {
720 for (
const auto &it : llvm::enumerate(successors))
722 if (propertiesStorage) {
723 OpaqueProperties propCopy(propertiesStorage);
726 operator delete(propertiesStorage);
727 propertiesStorage =
nullptr;
734 DictionaryAttr attrs;
735 SmallVector<Value, 8> operands;
736 SmallVector<Block *, 2> successors;
737 void *propertiesStorage =
nullptr;
744class ReplaceOperationRewrite :
public OperationRewrite {
746 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
747 Operation *op,
const TypeConverter *converter)
748 : OperationRewrite(
Kind::ReplaceOperation, rewriterImpl, op),
749 converter(converter) {}
751 static bool classof(
const IRRewrite *
rewrite) {
752 return rewrite->getKind() == Kind::ReplaceOperation;
755 void commit(RewriterBase &rewriter)
override;
757 void rollback()
override;
759 void cleanup(RewriterBase &rewriter)
override;
764 const TypeConverter *converter;
767class CreateOperationRewrite :
public OperationRewrite {
769 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
771 : OperationRewrite(
Kind::CreateOperation, rewriterImpl, op) {}
773 static bool classof(
const IRRewrite *
rewrite) {
774 return rewrite->getKind() == Kind::CreateOperation;
777 void commit(RewriterBase &rewriter)
override {
783 void rollback()
override;
787enum MaterializationKind {
798class UnresolvedMaterializationInfo {
800 UnresolvedMaterializationInfo() =
default;
801 UnresolvedMaterializationInfo(
const TypeConverter *converter,
802 MaterializationKind kind, Type originalType)
803 : converterAndKind(converter, kind), originalType(originalType) {}
806 const TypeConverter *getConverter()
const {
807 return converterAndKind.getPointer();
811 MaterializationKind getMaterializationKind()
const {
812 return converterAndKind.getInt();
816 Type getOriginalType()
const {
return originalType; }
821 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
832class UnresolvedMaterializationRewrite :
public OperationRewrite {
834 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
835 UnrealizedConversionCastOp op,
837 : OperationRewrite(
Kind::UnresolvedMaterialization, rewriterImpl, op),
838 mappedValues(std::move(mappedValues)) {}
840 static bool classof(
const IRRewrite *
rewrite) {
841 return rewrite->getKind() == Kind::UnresolvedMaterialization;
844 void rollback()
override;
846 UnrealizedConversionCastOp getOperation()
const {
847 return cast<UnrealizedConversionCastOp>(op);
857#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
860template <
typename RewriteTy,
typename R>
861static bool hasRewrite(R &&rewrites, Operation *op) {
862 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
863 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
864 return rewriteTy && rewriteTy->getOperation() == op;
870template <
typename RewriteTy,
typename R>
871static bool hasRewrite(R &&rewrites,
Block *block) {
872 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
873 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
874 return rewriteTy && rewriteTy->getBlock() == block;
886 const ConversionConfig &
config,
896 RewriterState getCurrentState();
900 void applyRewrites();
905 void resetState(RewriterState state, StringRef patternName =
"");
909 template <
typename RewriteTy,
typename... Args>
911 assert(
config.allowPatternRollback &&
"appending rewrites is not allowed");
913 std::make_unique<RewriteTy>(*
this, std::forward<Args>(args)...));
919 void undoRewrites(
unsigned numRewritesToKeep = 0, StringRef patternName =
"");
925 LogicalResult remapValues(StringRef valueDiagTag,
926 std::optional<Location> inputLoc,
ValueRange values,
943 bool skipPureTypeConversions =
false)
const;
957 TypeConverter::SignatureConversion *entryConversion);
965 Block *applySignatureConversion(
967 TypeConverter::SignatureConversion &signatureConversion);
984 void eraseBlock(
Block *block);
1022 Value findOrBuildReplacementValue(
Value value,
1030 void notifyOperationInserted(
Operation *op,
1034 void notifyBlockInserted(
Block *block,
Region *previous,
1053 std::function<
void(
Operation *)> opErasedCallback =
nullptr)
1055 opErasedCallback(std::move(opErasedCallback)) {}
1069 assert(block->empty() &&
"expected empty block");
1070 block->dropAllDefinedValueUses();
1078 if (opErasedCallback)
1079 opErasedCallback(op);
1187const ConversionConfig &IRRewrite::getConfig()
const {
1188 return rewriterImpl.
config;
1191void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1195 if (
auto *listener =
1196 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
1197 for (Operation *op : getNewBlock()->getUsers())
1201void BlockTypeConversionRewrite::rollback() {
1202 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1208 if (isa<BlockArgument>(repl)) {
1245void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1252void ReplaceValueRewrite::rollback() {
1253 rewriterImpl.
mapping.erase({value});
1259void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1261 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener());
1264 SmallVector<Value> replacements =
1266 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1274 for (
auto [
result, newValue] :
1275 llvm::zip_equal(op->
getResults(), replacements))
1281 if (getConfig().unlegalizedOps)
1282 getConfig().unlegalizedOps->erase(op);
1286 notifyIRErased(listener, *op);
1293void ReplaceOperationRewrite::rollback() {
1298void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1302void CreateOperationRewrite::rollback() {
1304 while (!region.getBlocks().empty())
1305 region.getBlocks().remove(region.getBlocks().begin());
1311void UnresolvedMaterializationRewrite::rollback() {
1312 if (!mappedValues.empty())
1313 rewriterImpl.
mapping.erase(mappedValues);
1324 for (
size_t i = 0; i <
rewrites.size(); ++i)
1330 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1331 unresolvedMaterializations.erase(castOp);
1334 rewrite->cleanup(eraseRewriter);
1342 Value from,
TypeRange desiredTypes,
bool skipPureTypeConversions)
const {
1345 assert(!values.empty() &&
"expected non-empty value vector");
1349 if (
config.allowPatternRollback)
1350 return mapping.lookup(values);
1357 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1362 if (castOp.getOutputs() != values)
1364 return castOp.getInputs();
1373 for (
Value v : values) {
1376 llvm::append_range(next, r);
1381 if (next != values) {
1410 if (skipPureTypeConversions) {
1413 match &= !pureConversion;
1416 if (!pureConversion)
1417 lastNonMaterialization = current;
1420 desiredValue = current;
1426 current = std::move(next);
1431 if (!desiredTypes.empty())
1432 return desiredValue;
1433 if (skipPureTypeConversions)
1434 return lastNonMaterialization;
1453 StringRef patternName) {
1458 while (
ignoredOps.size() != state.numIgnoredOperations)
1461 while (
replacedOps.size() != state.numReplacedOps)
1466 StringRef patternName) {
1468 llvm::reverse(llvm::drop_begin(
rewrites, numRewritesToKeep)))
1470 rewrites.resize(numRewritesToKeep);
1474 StringRef valueDiagTag, std::optional<Location> inputLoc,
ValueRange values,
1476 remapped.reserve(llvm::size(values));
1478 for (
const auto &it : llvm::enumerate(values)) {
1479 Value operand = it.value();
1498 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1499 << it.index() <<
", type was " << origType;
1504 if (legalTypes.empty()) {
1505 remapped.push_back({});
1514 remapped.push_back(std::move(repl));
1523 repl, repl, legalTypes,
1525 remapped.push_back(castValues);
1546 TypeConverter::SignatureConversion *entryConversion) {
1548 if (region->
empty())
1553 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1555 std::optional<TypeConverter::SignatureConversion> conversion =
1556 converter.convertBlockSignature(&block);
1565 if (entryConversion)
1568 std::optional<TypeConverter::SignatureConversion> conversion =
1569 converter.convertBlockSignature(®ion->
front());
1577 TypeConverter::SignatureConversion &signatureConversion) {
1578#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1580 if (hasRewrite<BlockTypeConversionRewrite>(
rewrites, block))
1581 llvm::report_fatal_error(
"block was already converted");
1588 auto convertedTypes = signatureConversion.getConvertedTypes();
1595 for (
unsigned i = 0; i < origArgCount; ++i) {
1596 auto inputMap = signatureConversion.getInputMapping(i);
1597 if (!inputMap || inputMap->replacedWithValues())
1600 for (
unsigned j = 0;
j < inputMap->size; ++
j)
1601 newLocs[inputMap->inputNo +
j] = origLoc;
1608 convertedTypes, newLocs);
1616 bool fastPath = !
config.listener;
1618 if (
config.allowPatternRollback)
1622 while (!block->
empty())
1629 for (
unsigned i = 0; i != origArgCount; ++i) {
1633 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1634 signatureConversion.getInputMapping(i);
1642 MaterializationKind::Source,
1646 origArgType,
Type(), converter,
1653 if (inputMap->replacedWithValues()) {
1655 assert(inputMap->size == 0 &&
1656 "invalid to provide a replacement value when the argument isn't "
1664 newBlock->
getArguments().slice(inputMap->inputNo, inputMap->size);
1668 if (
config.allowPatternRollback)
1689 assert((!originalType || kind == MaterializationKind::Target) &&
1690 "original type is valid only for target materializations");
1691 assert(
TypeRange(inputs) != outputTypes &&
1692 "materialization is not necessary");
1696 OpBuilder builder(outputTypes.front().getContext());
1698 UnrealizedConversionCastOp convertOp =
1699 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1700 if (
config.attachDebugMaterializationKind) {
1702 kind == MaterializationKind::Source ?
"source" :
"target";
1703 convertOp->setAttr(
"__kind__", builder.
getStringAttr(kindStr));
1710 UnresolvedMaterializationInfo(converter, kind, originalType);
1711 if (
config.allowPatternRollback) {
1712 if (!valuesToMap.empty())
1713 mapping.map(valuesToMap, convertOp.getResults());
1715 std::move(valuesToMap));
1719 return convertOp.getResults();
1724 assert(
config.allowPatternRollback &&
1725 "this code path is valid only in rollback mode");
1732 return repl.front();
1739 [&](
Operation *op) { return replacedOps.contains(op); }) &&
1764 MaterializationKind::Source, ip, value.
getLoc(),
1780 bool wasDetached = !previous.
isSet();
1782 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"' (" << op
1785 logger.getOStream() <<
" (was detached)";
1786 logger.getOStream() <<
"\n";
1792 "attempting to insert into a block within a replaced/erased op");
1796 config.listener->notifyOperationInserted(op, previous);
1805 if (
config.allowPatternRollback) {
1819 if (
config.allowPatternRollback)
1829 assert(!
impl.config.allowPatternRollback &&
1830 "this code path is valid only in 'no rollback' mode");
1832 for (
auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1835 repls.push_back(
Value());
1842 Value srcMat =
impl.buildUnresolvedMaterialization(
1847 repls.push_back(srcMat);
1853 repls.push_back(to[0]);
1862 Value srcMat =
impl.buildUnresolvedMaterialization(
1865 Type(), converter)[0];
1866 repls.push_back(srcMat);
1875 "incorrect number of replacement values");
1877 logger.startLine() <<
"** Replace : '" << op->
getName() <<
"'(" << op
1885 for (
auto [
result, repls] :
1886 llvm::zip_equal(op->
getResults(), newValues)) {
1888 auto logProlog = [&, repls = repls]() {
1889 logger.startLine() <<
" Note: Replacing op result of type "
1890 << resultType <<
" with value(s) of type (";
1891 llvm::interleaveComma(repls,
logger.getOStream(), [&](
Value v) {
1892 logger.getOStream() << v.getType();
1894 logger.getOStream() <<
")";
1900 logger.getOStream() <<
", but the type converter failed to legalize "
1901 "the original type.\n";
1906 logger.getOStream() <<
", but the legalized type(s) is/are (";
1907 llvm::interleaveComma(convertedTypes,
logger.getOStream(),
1908 [&](
Type t) { logger.getOStream() << t; });
1909 logger.getOStream() <<
")\n";
1915 if (!
config.allowPatternRollback) {
1924 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1930 if (
config.unlegalizedOps)
1931 config.unlegalizedOps->erase(op);
1939 assert(!
ignoredOps.contains(op) &&
"operation was already replaced");
1943 "attempting to replace a value that was already replaced");
1948 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1953 "attempting to replace/erase an unresolved materialization");
1967 if (!
config.allowPatternRollback) {
1972 Value repl = repls.front();
1989 "attempting to replace a value that was already replaced");
1991 "attempting to replace a op result that was already replaced");
2000 if (!
config.allowPatternRollback) {
2007 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
2013 if (
config.unlegalizedOps)
2014 config.unlegalizedOps->erase(op);
2023 "attempting to erase a block within a replaced/erased op");
2039 bool wasDetached = !previous;
2045 logger.startLine() <<
"** Insert Block into : '" << parent->
getName()
2046 <<
"' (" << parent <<
")";
2049 <<
"** Insert Block into detached Region (nullptr parent op)";
2052 logger.getOStream() <<
" (was detached)";
2053 logger.getOStream() <<
"\n";
2059 "attempting to insert into a region within a replaced/erased op");
2064 config.listener->notifyBlockInserted(block, previous, previousIt);
2068 if (
config.allowPatternRollback) {
2082 if (
config.allowPatternRollback)
2096 reasonCallback(
diag);
2097 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
2098 if (
config.notifyCallback)
2107ConversionPatternRewriter::ConversionPatternRewriter(
2111 *this,
config, opConverter)) {
2112 setListener(
impl.get());
2115ConversionPatternRewriter::~ConversionPatternRewriter() =
default;
2117const ConversionConfig &ConversionPatternRewriter::getConfig()
const {
2118 return impl->config;
2122 assert(op && newOp &&
"expected non-null op");
2126void ConversionPatternRewriter::replaceOp(Operation *op,
ValueRange newValues) {
2128 "incorrect # of replacement values");
2132 if (getInsertionPoint() == op->getIterator())
2135 SmallVector<SmallVector<Value>> newVals =
2136 llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2137 return v ? SmallVector<Value>{v} : SmallVector<Value>();
2139 impl->replaceOp(op, std::move(newVals));
2142void ConversionPatternRewriter::replaceOpWithMultiple(
2143 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2145 "incorrect # of replacement values");
2149 if (getInsertionPoint() == op->getIterator())
2152 impl->replaceOp(op, std::move(newValues));
2155void ConversionPatternRewriter::eraseOp(Operation *op) {
2157 impl->logger.startLine()
2158 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
2163 if (getInsertionPoint() == op->getIterator())
2166 SmallVector<SmallVector<Value>> nullRepls(op->
getNumResults(), {});
2167 impl->replaceOp(op, std::move(nullRepls));
2170void ConversionPatternRewriter::eraseBlock(
Block *block) {
2171 impl->eraseBlock(block);
2174Block *ConversionPatternRewriter::applySignatureConversion(
2175 Block *block, TypeConverter::SignatureConversion &conversion,
2176 const TypeConverter *converter) {
2177 assert(!impl->wasOpReplaced(block->
getParentOp()) &&
2178 "attempting to apply a signature conversion to a block within a "
2179 "replaced/erased op");
2180 return impl->applySignatureConversion(block, converter, conversion);
2183FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
2184 Region *region,
const TypeConverter &converter,
2185 TypeConverter::SignatureConversion *entryConversion) {
2186 assert(!impl->wasOpReplaced(region->
getParentOp()) &&
2187 "attempting to apply a signature conversion to a block within a "
2188 "replaced/erased op");
2189 return impl->convertRegionTypes(region, converter, entryConversion);
2192void ConversionPatternRewriter::replaceAllUsesWith(Value from,
ValueRange to) {
2194 impl->logger.startLine() <<
"** Replace Value : '" << from <<
"'";
2195 if (
auto blockArg = dyn_cast<BlockArgument>(from)) {
2196 if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
2197 impl->logger.getOStream() <<
" (in region of '" << parentOp->getName()
2198 <<
"' (" << parentOp <<
")\n";
2200 impl->logger.getOStream() <<
" (unlinked block)\n";
2204 impl->replaceAllUsesWith(from, to, impl->currentTypeConverter);
2207Value ConversionPatternRewriter::getRemappedValue(Value key) {
2208 SmallVector<ValueVector> remappedValues;
2209 if (
failed(impl->remapValues(
"value", std::nullopt, key,
2212 assert(remappedValues.front().size() == 1 &&
"1:N conversion not supported");
2213 return remappedValues.front().front();
2217ConversionPatternRewriter::getRemappedValues(
ValueRange keys,
2218 SmallVectorImpl<Value> &results) {
2221 SmallVector<ValueVector> remapped;
2222 if (
failed(impl->remapValues(
"value", std::nullopt, keys,
2225 for (
const auto &values : remapped) {
2226 assert(values.size() == 1 &&
"1:N conversion not supported");
2227 results.push_back(values.front());
2232LogicalResult ConversionPatternRewriter::legalize(Region *r) {
2240 SmallVector<Operation *> ops;
2242 for (Operation &op :
b)
2247 if (
const TypeConverter *converter = impl->currentTypeConverter) {
2248 std::optional<TypeConverter::SignatureConversion> conversion =
2249 converter->convertBlockSignature(&r->front());
2252 applySignatureConversion(&r->front(), *conversion, converter);
2256 for (Operation *op : ops)
2257 if (
failed(legalize(op)))
2263void ConversionPatternRewriter::inlineBlockBefore(
Block *source,
Block *dest,
2268 "incorrect # of argument replacement values");
2269 assert(!impl->wasOpReplaced(source->
getParentOp()) &&
2270 "attempting to inline a block from a replaced/erased op");
2271 assert(!impl->wasOpReplaced(dest->
getParentOp()) &&
2272 "attempting to inline a block into a replaced/erased op");
2273 auto opIgnored = [&](Operation *op) {
return impl->isOpIgnored(op); };
2276 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
2277 "expected 'source' to have no predecessors");
2286 bool fastPath = !getConfig().listener;
2288 if (fastPath && impl->config.allowPatternRollback)
2289 impl->inlineBlockBefore(source, dest, before);
2292 for (
auto it : llvm::zip(source->
getArguments(), argValues))
2293 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2300 while (!source->
empty())
2301 moveOpBefore(&source->
front(), dest, before);
2306 if (getInsertionBlock() == source)
2307 setInsertionPoint(dest, getInsertionPoint());
2313void ConversionPatternRewriter::startOpModification(Operation *op) {
2314 if (!impl->config.allowPatternRollback) {
2319 assert(!impl->wasOpReplaced(op) &&
2320 "attempting to modify a replaced/erased op");
2322 impl->pendingRootUpdates.insert(op);
2324 impl->appendRewrite<ModifyOperationRewrite>(op);
2327void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
2328 impl->patternModifiedOps.insert(op);
2329 if (!impl->config.allowPatternRollback) {
2331 if (getConfig().listener)
2332 getConfig().listener->notifyOperationModified(op);
2339 assert(!impl->wasOpReplaced(op) &&
2340 "attempting to modify a replaced/erased op");
2341 assert(impl->pendingRootUpdates.erase(op) &&
2342 "operation did not have a pending in-place update");
2346void ConversionPatternRewriter::cancelOpModification(Operation *op) {
2347 if (!impl->config.allowPatternRollback) {
2352 assert(impl->pendingRootUpdates.erase(op) &&
2353 "operation did not have a pending in-place update");
2356 auto it = llvm::find_if(
2357 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &
rewrite) {
2358 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2359 return modifyRewrite && modifyRewrite->getOperation() == op;
2361 assert(it != impl->rewrites.rend() &&
"no root update started on op");
2363 int updateIdx = std::prev(impl->rewrites.rend()) - it;
2364 impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2367detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
2375FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
2376 ArrayRef<ValueRange> operands)
const {
2377 SmallVector<Value> oneToOneOperands;
2378 oneToOneOperands.reserve(operands.size());
2380 if (operand.size() != 1)
2383 oneToOneOperands.push_back(operand.front());
2385 return std::move(oneToOneOperands);
2389ConversionPattern::matchAndRewrite(Operation *op,
2390 PatternRewriter &rewriter)
const {
2391 auto &dialectRewriter =
static_cast<ConversionPatternRewriter &
>(rewriter);
2392 auto &rewriterImpl = dialectRewriter.getImpl();
2396 getTypeConverter());
2399 SmallVector<ValueVector> remapped;
2404 SmallVector<ValueRange> remappedAsRange =
2405 llvm::to_vector_of<ValueRange>(remapped);
2406 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2415using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2418class OperationLegalizer {
2420 using LegalizationAction = ConversionTarget::LegalizationAction;
2422 OperationLegalizer(ConversionPatternRewriter &rewriter,
2423 const ConversionTarget &targetInfo,
2424 const FrozenRewritePatternSet &
patterns);
2427 bool isIllegal(Operation *op)
const;
2431 LogicalResult legalize(Operation *op);
2434 const ConversionTarget &getTarget() {
return target; }
2438 LogicalResult legalizeWithFold(Operation *op);
2442 LogicalResult legalizeWithPattern(Operation *op);
2446 bool canApplyPattern(Operation *op,
const Pattern &pattern);
2450 legalizePatternResult(Operation *op,
const Pattern &pattern,
2451 const RewriterState &curState,
2468 void buildLegalizationGraph(
2469 LegalizationPatterns &anyOpLegalizerPatterns,
2480 void computeLegalizationGraphBenefit(
2481 LegalizationPatterns &anyOpLegalizerPatterns,
2486 unsigned computeOpLegalizationDepth(
2493 unsigned applyCostModelToPatterns(
2499 SmallPtrSet<const Pattern *, 8> appliedPatterns;
2502 ConversionPatternRewriter &rewriter;
2505 const ConversionTarget &
target;
2508 PatternApplicator applicator;
2512OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2513 const ConversionTarget &targetInfo,
2514 const FrozenRewritePatternSet &
patterns)
2519 LegalizationPatterns anyOpLegalizerPatterns;
2521 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2522 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2525bool OperationLegalizer::isIllegal(Operation *op)
const {
2526 return target.isIllegal(op);
2529LogicalResult OperationLegalizer::legalize(Operation *op) {
2531 const char *logLineComment =
2532 "//===-------------------------------------------===//\n";
2534 auto &logger = rewriter.getImpl().logger;
2538 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2541 logger.getOStream() <<
"\n";
2542 logger.startLine() << logLineComment;
2543 logger.startLine() <<
"Legalizing operation : ";
2548 logger.getOStream() <<
"'" << op->
getName() <<
"' ";
2549 logger.getOStream() <<
"(" << op <<
") {\n";
2554 logger.startLine() << OpWithFlags(op,
2555 OpPrintingFlags().printGenericOpForm())
2562 logSuccess(logger,
"operation marked 'ignored' during conversion");
2563 logger.startLine() << logLineComment;
2569 if (
auto legalityInfo =
target.isLegal(op)) {
2572 logger,
"operation marked legal by the target{0}",
2573 legalityInfo->isRecursivelyLegal
2574 ?
"; NOTE: operation is recursively legal; skipping internals"
2576 logger.startLine() << logLineComment;
2581 if (legalityInfo->isRecursivelyLegal) {
2582 op->
walk([&](Operation *nested) {
2584 rewriter.getImpl().ignoredOps.
insert(nested);
2593 const ConversionConfig &
config = rewriter.getConfig();
2594 if (
config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2595 if (succeeded(legalizeWithFold(op))) {
2598 logger.startLine() << logLineComment;
2605 if (succeeded(legalizeWithPattern(op))) {
2608 logger.startLine() << logLineComment;
2615 if (
config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2616 if (succeeded(legalizeWithFold(op))) {
2619 logger.startLine() << logLineComment;
2626 logFailure(logger,
"no matched legalization pattern");
2627 logger.startLine() << logLineComment;
2634template <
typename T>
2636 T
result = std::move(obj);
2641LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2642 auto &rewriterImpl = rewriter.getImpl();
2644 rewriterImpl.
logger.startLine() <<
"* Fold {\n";
2645 rewriterImpl.
logger.indent();
2650 auto cleanup = llvm::make_scope_exit([&]() {
2660 SmallVector<Value, 2> replacementValues;
2661 SmallVector<Operation *, 2> newOps;
2664 if (
failed(rewriter.
tryFold(op, replacementValues, &newOps))) {
2673 if (replacementValues.empty())
2674 return legalize(op);
2677 rewriter.
replaceOp(op, replacementValues);
2680 for (Operation *newOp : newOps) {
2681 if (
failed(legalize(newOp))) {
2683 "failed to legalize generated constant '{0}'",
2685 if (!rewriter.getConfig().allowPatternRollback) {
2687 llvm::report_fatal_error(
2689 "' folder rollback of IR modifications requested");
2707 auto newOpNames = llvm::map_range(
2709 auto modifiedOpNames = llvm::map_range(
2711 llvm::report_fatal_error(
"pattern '" + pattern.
getDebugName() +
2712 "' produced IR that could not be legalized. " +
2713 "new ops: {" + llvm::join(newOpNames,
", ") +
"}, " +
2715 llvm::join(modifiedOpNames,
", ") +
"}");
2718LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2719 auto &rewriterImpl = rewriter.getImpl();
2720 const ConversionConfig &
config = rewriter.getConfig();
2722#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2724 std::optional<OperationFingerPrint> topLevelFingerPrint;
2725 if (!rewriterImpl.
config.allowPatternRollback) {
2732 topLevelFingerPrint = OperationFingerPrint(checkOp);
2738 rewriterImpl.
logger.startLine()
2739 <<
"WARNING: Multi-threadeding is enabled. Some dialect "
2740 "conversion expensive checks are skipped in multithreading "
2749 auto canApply = [&](
const Pattern &pattern) {
2750 bool canApply = canApplyPattern(op, pattern);
2751 if (canApply &&
config.listener)
2752 config.listener->notifyPatternBegin(pattern, op);
2758 auto onFailure = [&](
const Pattern &pattern) {
2760 if (!rewriterImpl.
config.allowPatternRollback) {
2767#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2769 if (checkOp && topLevelFingerPrint) {
2770 OperationFingerPrint fingerPrintAfterPattern(checkOp);
2771 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2772 llvm::report_fatal_error(
"pattern '" + pattern.getDebugName() +
2773 "' returned failure but IR did change");
2781 if (rewriterImpl.
config.notifyCallback) {
2783 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
2790 config.listener->notifyPatternEnd(pattern, failure());
2791 rewriterImpl.
resetState(curState, pattern.getDebugName());
2792 appliedPatterns.erase(&pattern);
2797 auto onSuccess = [&](
const Pattern &pattern) {
2799 if (!rewriterImpl.
config.allowPatternRollback) {
2813 legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
2814 appliedPatterns.erase(&pattern);
2816 if (!rewriterImpl.
config.allowPatternRollback)
2818 rewriterImpl.
resetState(curState, pattern.getDebugName());
2826 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2830bool OperationLegalizer::canApplyPattern(Operation *op,
2831 const Pattern &pattern) {
2833 auto &os = rewriter.getImpl().logger;
2834 os.getOStream() <<
"\n";
2835 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
2837 os.getOStream() <<
")' {\n";
2844 !appliedPatterns.insert(&pattern).second) {
2846 logFailure(rewriter.getImpl().logger,
"pattern was already applied"));
2852LogicalResult OperationLegalizer::legalizePatternResult(
2853 Operation *op,
const Pattern &pattern,
const RewriterState &curState,
2856 [[maybe_unused]]
auto &impl = rewriter.getImpl();
2857 assert(impl.pendingRootUpdates.empty() &&
"dangling root updates");
2859#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2860 if (impl.config.allowPatternRollback) {
2862 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2863 auto replacedRoot = [&] {
2864 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2866 auto updatedRootInPlace = [&] {
2867 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2869 if (!replacedRoot() && !updatedRootInPlace())
2870 llvm::report_fatal_error(
"expected pattern to replace the root operation "
2871 "or modify it in place");
2876 if (
failed(legalizePatternRootUpdates(modifiedOps)) ||
2877 failed(legalizePatternCreatedOperations(newOps))) {
2881 LLVM_DEBUG(
logSuccess(impl.logger,
"pattern applied successfully"));
2885LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2887 for (Operation *op : newOps) {
2888 if (
failed(legalize(op))) {
2889 LLVM_DEBUG(
logFailure(rewriter.getImpl().logger,
2890 "failed to legalize generated operation '{0}'({1})",
2898LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2900 for (Operation *op : modifiedOps) {
2901 if (
failed(legalize(op))) {
2904 "failed to legalize operation updated in-place '{0}'",
2916void OperationLegalizer::buildLegalizationGraph(
2917 LegalizationPatterns &anyOpLegalizerPatterns,
2928 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2929 std::optional<OperationName> root = pattern.
getRootKind();
2935 anyOpLegalizerPatterns.push_back(&pattern);
2940 if (
target.getOpAction(*root) == LegalizationAction::Legal)
2945 invalidPatterns[*root].insert(&pattern);
2947 parentOps[op].insert(*root);
2950 patternWorklist.insert(&pattern);
2958 if (!anyOpLegalizerPatterns.empty()) {
2959 for (
const Pattern *pattern : patternWorklist)
2960 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2964 while (!patternWorklist.empty()) {
2965 auto *pattern = patternWorklist.pop_back_val();
2969 std::optional<LegalizationAction> action = target.getOpAction(op);
2970 return !legalizerPatterns.count(op) &&
2971 (!action || action == LegalizationAction::Illegal);
2977 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2978 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
2982 for (
auto op : parentOps[*pattern->
getRootKind()])
2983 patternWorklist.set_union(invalidPatterns[op]);
2987void OperationLegalizer::computeLegalizationGraphBenefit(
2988 LegalizationPatterns &anyOpLegalizerPatterns,
2994 for (
auto &opIt : legalizerPatterns)
2995 if (!minOpPatternDepth.count(opIt.first))
2996 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
3002 if (!anyOpLegalizerPatterns.empty())
3003 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
3009 applicator.applyCostModel([&](
const Pattern &pattern) {
3010 ArrayRef<const Pattern *> orderedPatternList;
3011 if (std::optional<OperationName> rootName = pattern.
getRootKind())
3012 orderedPatternList = legalizerPatterns[*rootName];
3014 orderedPatternList = anyOpLegalizerPatterns;
3017 auto *it = llvm::find(orderedPatternList, &pattern);
3018 if (it == orderedPatternList.end())
3022 return PatternBenefit(std::distance(it, orderedPatternList.end()));
3026unsigned OperationLegalizer::computeOpLegalizationDepth(
3030 auto depthIt = minOpPatternDepth.find(op);
3031 if (depthIt != minOpPatternDepth.end())
3032 return depthIt->second;
3036 auto opPatternsIt = legalizerPatterns.find(op);
3037 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3042 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
3046 unsigned minDepth = applyCostModelToPatterns(
3047 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3048 minOpPatternDepth[op] = minDepth;
3052unsigned OperationLegalizer::applyCostModelToPatterns(
3056 unsigned minDepth = std::numeric_limits<unsigned>::max();
3059 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3060 patternsByDepth.reserve(
patterns.size());
3061 for (
const Pattern *pattern :
patterns) {
3064 unsigned generatedOpDepth = computeOpLegalizationDepth(
3065 generatedOp, minOpPatternDepth, legalizerPatterns);
3066 depth = std::max(depth, generatedOpDepth + 1);
3068 patternsByDepth.emplace_back(pattern, depth);
3071 minDepth = std::min(minDepth, depth);
3076 if (patternsByDepth.size() == 1)
3080 llvm::stable_sort(patternsByDepth,
3081 [](
const std::pair<const Pattern *, unsigned> &
lhs,
3082 const std::pair<const Pattern *, unsigned> &
rhs) {
3085 if (
lhs.second !=
rhs.second)
3086 return lhs.second <
rhs.second;
3089 auto lhsBenefit =
lhs.first->getBenefit();
3090 auto rhsBenefit =
rhs.first->getBenefit();
3091 return lhsBenefit > rhsBenefit;
3096 for (
auto &patternIt : patternsByDepth)
3097 patterns.push_back(patternIt.first);
3111template <
typename RangeT>
3114 function_ref<
bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3123 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3124 if (castOp.getInputs().empty())
3127 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3130 if (inputCastOp.getOutputs() != castOp.getInputs())
3136 while (!worklist.empty()) {
3137 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3141 UnrealizedConversionCastOp nextCast = castOp;
3143 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3144 if (llvm::any_of(nextCast.getInputs(), [&](
Value v) {
3145 return v.getDefiningOp() == castOp;
3153 castOp.replaceAllUsesWith(nextCast.getInputs());
3156 nextCast = getInputCast(nextCast);
3166 auto markOpLive = [&](
Operation *rootOp) {
3168 worklist.push_back(rootOp);
3169 while (!worklist.empty()) {
3170 Operation *op = worklist.pop_back_val();
3171 if (liveOps.insert(op).second) {
3174 if (
auto castOp = v.
getDefiningOp<UnrealizedConversionCastOp>())
3175 if (isCastOpOfInterestFn(castOp))
3176 worklist.push_back(castOp);
3182 for (UnrealizedConversionCastOp op : castOps) {
3185 if (liveOps.contains(op.getOperation()))
3189 if (llvm::any_of(op->getUsers(), [&](
Operation *user) {
3190 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3191 return !castOp || !isCastOpOfInterestFn(castOp);
3197 for (UnrealizedConversionCastOp op : castOps) {
3198 if (liveOps.contains(op)) {
3200 if (remainingCastOps)
3201 remainingCastOps->push_back(op);
3212 ArrayRef<UnrealizedConversionCastOp> castOps,
3213 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3215 DenseSet<UnrealizedConversionCastOp> castOpSet;
3216 for (UnrealizedConversionCastOp op : castOps)
3217 castOpSet.insert(op);
3222 const DenseSet<UnrealizedConversionCastOp> &castOps,
3223 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3225 llvm::make_range(castOps.begin(), castOps.end()),
3226 [&](UnrealizedConversionCastOp castOp) {
3227 return castOps.contains(castOp);
3239 [&](UnrealizedConversionCastOp castOp) {
3240 return castOps.contains(castOp);
3257 const ConversionConfig &
config,
3258 OpConversionMode mode)
3270 LogicalResult convert(
Operation *op,
bool isRecursiveLegalization =
false);
3274 ConversionPatternRewriter rewriter;
3277 OperationLegalizer opLegalizer;
3280 OpConversionMode mode;
3284LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
3285 return impl->opConverter.convert(op,
true);
3289 bool isRecursiveLegalization) {
3290 const ConversionConfig &
config = rewriter.getConfig();
3293 if (failed(opLegalizer.legalize(op))) {
3296 if (mode == OpConversionMode::Full) {
3297 if (!isRecursiveLegalization)
3305 if (mode == OpConversionMode::Partial) {
3306 if (opLegalizer.isIllegal(op)) {
3307 if (!isRecursiveLegalization)
3309 <<
"' that was explicitly marked illegal";
3312 if (
config.unlegalizedOps && !isRecursiveLegalization)
3313 config.unlegalizedOps->insert(op);
3315 }
else if (mode == OpConversionMode::Analysis) {
3319 if (
config.legalizableOps && !isRecursiveLegalization)
3320 config.legalizableOps->insert(op);
3327 UnrealizedConversionCastOp op,
3328 const UnresolvedMaterializationInfo &info) {
3329 assert(!op.use_empty() &&
3330 "expected that dead materializations have already been DCE'd");
3337 switch (info.getMaterializationKind()) {
3338 case MaterializationKind::Target:
3339 newMaterialization = converter->materializeTargetConversion(
3340 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3341 info.getOriginalType());
3343 case MaterializationKind::Source:
3344 assert(op->getNumResults() == 1 &&
"expected single result");
3345 Value sourceMat = converter->materializeSourceConversion(
3346 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3348 newMaterialization.push_back(sourceMat);
3351 if (!newMaterialization.empty()) {
3353 ValueRange newMaterializationRange(newMaterialization);
3354 assert(
TypeRange(newMaterializationRange) == op.getResultTypes() &&
3355 "materialization callback produced value of incorrect type");
3357 rewriter.
replaceOp(op, newMaterialization);
3363 <<
"failed to legalize unresolved materialization "
3365 << inputOperands.
getTypes() <<
") to ("
3366 << op.getResultTypes()
3367 <<
") that remained live after conversion";
3368 diag.attachNote(op->getUsers().begin()->getLoc())
3369 <<
"see existing live user here: " << *op->getUsers().begin();
3378 for (
auto *op : ops) {
3381 toConvert.push_back(op);
3384 auto legalityInfo =
target.isLegal(op);
3385 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3394 for (
auto *op : toConvert) {
3397 if (rewriterImpl.
config.allowPatternRollback) {
3421 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3425 if (rewriter.getConfig().buildMaterializations) {
3429 rewriter.getConfig().listener);
3430 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3431 auto it = materializations.find(castOp);
3432 assert(it != materializations.end() &&
"inconsistent state");
3446void TypeConverter::SignatureConversion::addInputs(
unsigned origInputNo,
3448 assert(!types.empty() &&
"expected valid types");
3449 remapInput(origInputNo, argTypes.size(), types.size());
3453void TypeConverter::SignatureConversion::addInputs(
ArrayRef<Type> types) {
3454 assert(!types.empty() &&
3455 "1->0 type remappings don't need to be added explicitly");
3456 argTypes.append(types.begin(), types.end());
3459void TypeConverter::SignatureConversion::remapInput(
unsigned origInputNo,
3460 unsigned newInputNo,
3461 unsigned newInputCount) {
3462 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3463 assert(newInputCount != 0 &&
"expected valid input count");
3464 remappedInputs[origInputNo] =
3465 InputMapping{newInputNo, newInputCount, {}};
3468void TypeConverter::SignatureConversion::remapInput(
3469 unsigned origInputNo, ArrayRef<Value> replacements) {
3470 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
3471 remappedInputs[origInputNo] = InputMapping{
3473 SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3484TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3485 SmallVectorImpl<Type> &results)
const {
3486 assert(typeOrValue &&
"expected non-null type");
3487 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3488 : cast<Type>(typeOrValue);
3490 std::shared_lock<
decltype(cacheMutex)> cacheReadLock(cacheMutex,
3493 cacheReadLock.lock();
3494 auto existingIt = cachedDirectConversions.find(t);
3495 if (existingIt != cachedDirectConversions.end()) {
3496 if (existingIt->second)
3497 results.push_back(existingIt->second);
3498 return success(existingIt->second !=
nullptr);
3500 auto multiIt = cachedMultiConversions.find(t);
3501 if (multiIt != cachedMultiConversions.end()) {
3502 results.append(multiIt->second.begin(), multiIt->second.end());
3508 size_t currentCount = results.size();
3512 auto isCacheable = [&](
int index) {
3513 int numberOfConversionsUntilContextAware =
3514 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3515 return index < numberOfConversionsUntilContextAware;
3518 std::unique_lock<
decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3521 for (
auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3522 const ConversionCallbackFn &converter = indexedConverter.value();
3523 std::optional<LogicalResult>
result = converter(typeOrValue, results);
3525 assert(results.size() == currentCount &&
3526 "failed type conversion should not change results");
3529 if (!isCacheable(indexedConverter.index()))
3532 cacheWriteLock.lock();
3533 if (!succeeded(*
result)) {
3534 assert(results.size() == currentCount &&
3535 "failed type conversion should not change results");
3536 cachedDirectConversions.try_emplace(t,
nullptr);
3539 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3540 if (newTypes.size() == 1)
3541 cachedDirectConversions.try_emplace(t, newTypes.front());
3543 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3549LogicalResult TypeConverter::convertType(Type t,
3550 SmallVectorImpl<Type> &results)
const {
3551 return convertTypeImpl(t, results);
3554LogicalResult TypeConverter::convertType(Value v,
3555 SmallVectorImpl<Type> &results)
const {
3556 return convertTypeImpl(v, results);
3559Type TypeConverter::convertType(Type t)
const {
3561 SmallVector<Type, 1> results;
3562 if (
failed(convertType(t, results)))
3566 return results.size() == 1 ? results.front() :
nullptr;
3569Type TypeConverter::convertType(Value v)
const {
3571 SmallVector<Type, 1> results;
3572 if (
failed(convertType(v, results)))
3576 return results.size() == 1 ? results.front() :
nullptr;
3580TypeConverter::convertTypes(
TypeRange types,
3581 SmallVectorImpl<Type> &results)
const {
3582 for (Type type : types)
3583 if (
failed(convertType(type, results)))
3589TypeConverter::convertTypes(
ValueRange values,
3590 SmallVectorImpl<Type> &results)
const {
3591 for (Value value : values)
3592 if (
failed(convertType(value, results)))
3597bool TypeConverter::isLegal(Type type)
const {
3598 return convertType(type) == type;
3601bool TypeConverter::isLegal(Value value)
const {
3602 return convertType(value) == value.
getType();
3605bool TypeConverter::isLegal(Operation *op)
const {
3609bool TypeConverter::isLegal(Region *region)
const {
3610 return llvm::all_of(
3614bool TypeConverter::isSignatureLegal(FunctionType ty)
const {
3615 if (!isLegal(ty.getInputs()))
3617 if (!isLegal(ty.getResults()))
3623TypeConverter::convertSignatureArg(
unsigned inputNo, Type type,
3624 SignatureConversion &
result)
const {
3626 SmallVector<Type, 1> convertedTypes;
3627 if (
failed(convertType(type, convertedTypes)))
3631 if (convertedTypes.empty())
3635 result.addInputs(inputNo, convertedTypes);
3639TypeConverter::convertSignatureArgs(
TypeRange types,
3640 SignatureConversion &
result,
3641 unsigned origInputOffset)
const {
3642 for (
unsigned i = 0, e = types.size(); i != e; ++i)
3643 if (
failed(convertSignatureArg(origInputOffset + i, types[i],
result)))
3648TypeConverter::convertSignatureArg(
unsigned inputNo, Value value,
3649 SignatureConversion &
result)
const {
3651 SmallVector<Type, 1> convertedTypes;
3652 if (
failed(convertType(value, convertedTypes)))
3656 if (convertedTypes.empty())
3660 result.addInputs(inputNo, convertedTypes);
3664TypeConverter::convertSignatureArgs(
ValueRange values,
3665 SignatureConversion &
result,
3666 unsigned origInputOffset)
const {
3667 for (
unsigned i = 0, e = values.size(); i != e; ++i)
3668 if (
failed(convertSignatureArg(origInputOffset + i, values[i],
result)))
3673Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
3674 Location loc, Type resultType,
3676 for (
const SourceMaterializationCallbackFn &fn :
3677 llvm::reverse(sourceMaterializations))
3678 if (Value
result = fn(builder, resultType, inputs, loc))
3683Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
3684 Location loc, Type resultType,
3686 Type originalType)
const {
3687 SmallVector<Value>
result = materializeTargetConversion(
3688 builder, loc,
TypeRange(resultType), inputs, originalType);
3691 assert(
result.size() == 1 &&
"expected single result");
3695SmallVector<Value> TypeConverter::materializeTargetConversion(
3697 Type originalType)
const {
3698 for (
const TargetMaterializationCallbackFn &fn :
3699 llvm::reverse(targetMaterializations)) {
3700 SmallVector<Value>
result =
3701 fn(builder, resultTypes, inputs, loc, originalType);
3705 "callback produced incorrect number of values or values with "
3712std::optional<TypeConverter::SignatureConversion>
3713TypeConverter::convertBlockSignature(
Block *block)
const {
3716 return std::nullopt;
3723TypeConverter::AttributeConversionResult
3724TypeConverter::AttributeConversionResult::result(Attribute attr) {
3725 return AttributeConversionResult(attr, resultTag);
3728TypeConverter::AttributeConversionResult
3729TypeConverter::AttributeConversionResult::na() {
3730 return AttributeConversionResult(
nullptr, naTag);
3733TypeConverter::AttributeConversionResult
3734TypeConverter::AttributeConversionResult::abort() {
3735 return AttributeConversionResult(
nullptr, abortTag);
3738bool TypeConverter::AttributeConversionResult::hasResult()
const {
3739 return impl.getInt() == resultTag;
3742bool TypeConverter::AttributeConversionResult::isNa()
const {
3743 return impl.getInt() == naTag;
3746bool TypeConverter::AttributeConversionResult::isAbort()
const {
3747 return impl.getInt() == abortTag;
3750Attribute TypeConverter::AttributeConversionResult::getResult()
const {
3751 assert(hasResult() &&
"Cannot get result from N/A or abort");
3752 return impl.getPointer();
3755std::optional<Attribute>
3756TypeConverter::convertTypeAttribute(Type type, Attribute attr)
const {
3757 for (
const TypeAttributeConversionCallbackFn &fn :
3758 llvm::reverse(typeAttributeConversions)) {
3759 AttributeConversionResult res = fn(type, attr);
3760 if (res.hasResult())
3761 return res.getResult();
3763 return std::nullopt;
3765 return std::nullopt;
3774 ConversionPatternRewriter &rewriter) {
3775 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3780 TypeConverter::SignatureConversion
result(type.getNumInputs());
3782 if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
result)) ||
3783 failed(typeConverter.convertTypes(type.getResults(), newResults)))
3785 if (!funcOp.getFunctionBody().empty())
3786 rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(),
result,
3790 auto newType = FunctionType::get(rewriter.getContext(),
3791 result.getConvertedTypes(), newResults);
3793 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3802struct FunctionOpInterfaceSignatureConversion :
public ConversionPattern {
3803 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3805 const TypeConverter &converter,
3806 PatternBenefit benefit)
3807 : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3810 matchAndRewrite(Operation *op, ArrayRef<Value> ,
3811 ConversionPatternRewriter &rewriter)
const override {
3812 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3817struct AnyFunctionOpInterfaceSignatureConversion
3818 :
public OpInterfaceConversionPattern<FunctionOpInterface> {
3819 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3822 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> ,
3823 ConversionPatternRewriter &rewriter)
const override {
3829FailureOr<Operation *>
3830mlir::convertOpResultTypes(Operation *op,
ValueRange operands,
3831 const TypeConverter &converter,
3832 ConversionPatternRewriter &rewriter) {
3833 assert(op &&
"Invalid op");
3834 Location loc = op->
getLoc();
3835 if (converter.isLegal(op))
3836 return rewriter.notifyMatchFailure(loc,
"op already legal");
3838 OperationState newOp(loc, op->
getName());
3839 newOp.addOperands(operands);
3841 SmallVector<Type> newResultTypes;
3843 return rewriter.notifyMatchFailure(loc,
"couldn't convert return types");
3845 newOp.addTypes(newResultTypes);
3846 newOp.addAttributes(op->
getAttrs());
3847 return rewriter.create(newOp);
3850void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3851 StringRef functionLikeOpName, RewritePatternSet &
patterns,
3852 const TypeConverter &converter, PatternBenefit benefit) {
3853 patterns.add<FunctionOpInterfaceSignatureConversion>(
3854 functionLikeOpName,
patterns.getContext(), converter, benefit);
3857void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3858 RewritePatternSet &
patterns,
const TypeConverter &converter,
3859 PatternBenefit benefit) {
3860 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3861 converter,
patterns.getContext(), benefit);
3868void ConversionTarget::setOpAction(OperationName op,
3869 LegalizationAction action) {
3870 legalOperations[op].action = action;
3873void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3874 LegalizationAction action) {
3875 for (StringRef dialect : dialectNames)
3876 legalDialects[dialect] = action;
3879auto ConversionTarget::getOpAction(OperationName op)
const
3880 -> std::optional<LegalizationAction> {
3881 std::optional<LegalizationInfo> info = getOpInfo(op);
3882 return info ? info->action : std::optional<LegalizationAction>();
3885auto ConversionTarget::isLegal(Operation *op)
const
3886 -> std::optional<LegalOpDetails> {
3887 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3889 return std::nullopt;
3892 auto isOpLegal = [&] {
3894 if (info->action == LegalizationAction::Dynamic) {
3895 std::optional<bool>
result = info->legalityFn(op);
3901 return info->action == LegalizationAction::Legal;
3904 return std::nullopt;
3907 LegalOpDetails legalityDetails;
3908 if (info->isRecursivelyLegal) {
3909 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
3910 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3911 legalityDetails.isRecursivelyLegal =
3912 legalityFnIt->second(op).value_or(
true);
3914 legalityDetails.isRecursivelyLegal =
true;
3917 return legalityDetails;
3920bool ConversionTarget::isIllegal(Operation *op)
const {
3921 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3925 if (info->action == LegalizationAction::Dynamic) {
3926 std::optional<bool>
result = info->legalityFn(op);
3933 return info->action == LegalizationAction::Illegal;
3937 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
3938 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
3942 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3944 if (std::optional<bool>
result = newCl(op))
3952void ConversionTarget::setLegalityCallback(
3953 OperationName name,
const DynamicLegalityCallbackFn &callback) {
3954 assert(callback &&
"expected valid legality callback");
3955 auto *infoIt = legalOperations.find(name);
3956 assert(infoIt != legalOperations.end() &&
3957 infoIt->second.action == LegalizationAction::Dynamic &&
3958 "expected operation to already be marked as dynamically legal");
3959 infoIt->second.legalityFn =
3963void ConversionTarget::markOpRecursivelyLegal(
3964 OperationName name,
const DynamicLegalityCallbackFn &callback) {
3965 auto *infoIt = legalOperations.find(name);
3966 assert(infoIt != legalOperations.end() &&
3967 infoIt->second.action != LegalizationAction::Illegal &&
3968 "expected operation to already be marked as legal");
3969 infoIt->second.isRecursivelyLegal =
true;
3972 std::move(opRecursiveLegalityFns[name]), callback);
3974 opRecursiveLegalityFns.erase(name);
3977void ConversionTarget::setLegalityCallback(
3978 ArrayRef<StringRef> dialects,
const DynamicLegalityCallbackFn &callback) {
3979 assert(callback &&
"expected valid legality callback");
3980 for (StringRef dialect : dialects)
3982 std::move(dialectLegalityFns[dialect]), callback);
3985void ConversionTarget::setLegalityCallback(
3986 const DynamicLegalityCallbackFn &callback) {
3987 assert(callback &&
"expected valid legality callback");
3991auto ConversionTarget::getOpInfo(OperationName op)
const
3992 -> std::optional<LegalizationInfo> {
3994 const auto *it = legalOperations.find(op);
3995 if (it != legalOperations.end())
3998 auto dialectIt = legalDialects.find(op.getDialectNamespace());
3999 if (dialectIt != legalDialects.end()) {
4000 DynamicLegalityCallbackFn callback;
4001 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4002 if (dialectFn != dialectLegalityFns.end())
4003 callback = dialectFn->second;
4004 return LegalizationInfo{dialectIt->second,
false,
4008 if (unknownLegalityFn)
4009 return LegalizationInfo{LegalizationAction::Dynamic,
4010 false, unknownLegalityFn};
4011 return std::nullopt;
4014#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4019void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
4020 auto &rewriterImpl =
4021 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4025void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
4026 auto &rewriterImpl =
4027 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4033static FailureOr<SmallVector<Value>>
4034pdllConvertValues(ConversionPatternRewriter &rewriter,
ValueRange values) {
4035 SmallVector<Value> mappedValues;
4036 if (
failed(rewriter.getRemappedValues(values, mappedValues)))
4038 return std::move(mappedValues);
4041void mlir::registerConversionPDLFunctions(RewritePatternSet &
patterns) {
4042 patterns.getPDLPatterns().registerRewriteFunction(
4044 [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
4045 auto results = pdllConvertValues(
4046 static_cast<ConversionPatternRewriter &
>(rewriter), value);
4049 return results->front();
4051 patterns.getPDLPatterns().registerRewriteFunction(
4052 "convertValues", [](PatternRewriter &rewriter,
ValueRange values) {
4053 return pdllConvertValues(
4054 static_cast<ConversionPatternRewriter &
>(rewriter), values);
4056 patterns.getPDLPatterns().registerRewriteFunction(
4058 [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
4059 auto &rewriterImpl =
4060 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4061 if (
const TypeConverter *converter =
4063 if (Type newType = converter->convertType(type))
4069 patterns.getPDLPatterns().registerRewriteFunction(
4071 [](PatternRewriter &rewriter,
4072 TypeRange types) -> FailureOr<SmallVector<Type>> {
4073 auto &rewriterImpl =
4074 static_cast<ConversionPatternRewriter &
>(rewriter).getImpl();
4077 return SmallVector<Type>(types);
4079 SmallVector<Type> remappedTypes;
4080 if (
failed(converter->convertTypes(types, remappedTypes)))
4082 return std::move(remappedTypes);
4097 static constexpr StringLiteral
tag =
"apply-conversion";
4098 static constexpr StringLiteral
desc =
4099 "Encapsulate the application of a dialect conversion";
4108 OpConversionMode mode) {
4112 LogicalResult status =
success();
4128LogicalResult mlir::applyPartialConversion(
4129 ArrayRef<Operation *> ops,
const ConversionTarget &
target,
4130 const FrozenRewritePatternSet &
patterns, ConversionConfig
config) {
4132 OpConversionMode::Partial);
4135mlir::applyPartialConversion(Operation *op,
const ConversionTarget &
target,
4136 const FrozenRewritePatternSet &
patterns,
4137 ConversionConfig
config) {
4145LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
4146 const ConversionTarget &
target,
4147 const FrozenRewritePatternSet &
patterns,
4148 ConversionConfig
config) {
4151LogicalResult mlir::applyFullConversion(Operation *op,
4152 const ConversionTarget &
target,
4153 const FrozenRewritePatternSet &
patterns,
4154 ConversionConfig
config) {
4172 "expected top-level op to be isolated from above");
4175 "expected ops to have a common ancestor");
4184 for (
Operation *op : ops.drop_front()) {
4188 assert(commonAncestor &&
4189 "expected to find a common isolated from above ancestor");
4193 return commonAncestor;
4196LogicalResult mlir::applyAnalysisConversion(
4197 ArrayRef<Operation *> ops, ConversionTarget &
target,
4198 const FrozenRewritePatternSet &
patterns, ConversionConfig
config) {
4200 if (
config.legalizableOps)
4201 assert(
config.legalizableOps->empty() &&
"expected empty set");
4207 Operation *clonedAncestor = commonAncestor->
clone(mapping);
4211 inverseOperationMap[it.second] = it.first;
4214 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4215 ops, [&](Operation *op) {
return mapping.
lookup(op); });
4217 OpConversionMode::Analysis);
4221 if (
config.legalizableOps) {
4223 for (Operation *op : *
config.legalizableOps)
4224 originalLegalizableOps.insert(inverseOperationMap[op]);
4225 *
config.legalizableOps = std::move(originalLegalizableOps);
4229 clonedAncestor->
erase();
4234mlir::applyAnalysisConversion(Operation *op, ConversionTarget &
target,
4235 const FrozenRewritePatternSet &
patterns,
4236 ConversionConfig
config) {
static void setInsertionPointAfter(OpBuilder &b, Value value)
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl)
Replace all uses of from with repl.
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value > > &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
static std::string diag(const llvm::Value &value)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
This is the type of Action that is dispatched when a conversion is applied.
tracing::ActionImpl< ApplyConversionAction > Base
static constexpr StringLiteral desc
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
static constexpr StringLiteral tag
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
StringAttr getStringAttr(const Twine &bytes)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
auto lookup(T from) const
Lookup a mapped value within the map.
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Block::iterator getPoint() const
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
void destroyOpProperties(OpaqueProperties properties) const
This hooks destroy the op properties.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperandRange operand_range
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
CRTP Implementation of an action.
ArrayRef< IRUnit > irUnits
Set of IR units (operations, regions, blocks, values) that are associated with this action.
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
static void reconcileUnrealizedCasts(const DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
llvm::SetVector< T, Vector, Set, N > SetVector
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
This iterator enumerates elements according to their dominance relationship.
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
LogicalResult convert(Operation *op, bool isRecursiveLegalization=false)
Converts a single operation.
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
bool wasErased(void *ptr) const
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
llvm::impl::raw_ldbg_ostream os
A raw output stream used to prefix the debug log.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config, OperationConverter &opConverter)
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
ConversionValueMapping mapping
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
void replaceOp(Operation *op, SmallVector< SmallVector< Value > > &&newValues)
Replace the results of the given operation with the given values and erase the operation.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void replaceAllUsesWith(Value from, ValueRange to, const TypeConverter *converter)
Replace the uses of the given value with the given values.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
OperationConverter & opConverter
The operation converter to use for recursive legalization.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.