10 #include "mlir/Config/mlir-config.h"
19 #include "llvm/ADT/ScopeExit.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/FormatVariadic.h"
24 #include "llvm/Support/SaveAndRestore.h"
25 #include "llvm/Support/ScopedPrinter.h"
31 #define DEBUG_TYPE "dialect-conversion"
34 template <
typename... Args>
35 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
38 os.startLine() <<
"} -> SUCCESS";
40 os.getOStream() <<
" : "
41 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
42 os.getOStream() <<
"\n";
47 template <
typename... Args>
48 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
51 os.startLine() <<
"} -> FAILURE : "
52 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
62 if (
OpResult inputRes = dyn_cast<OpResult>(value))
63 insertPt = ++inputRes.getOwner()->getIterator();
70 assert(!vals.empty() &&
"expected at least one value");
73 for (
Value v : vals.drop_front()) {
87 assert(dom &&
"unable to find valid insertion point");
105 struct ValueVectorMapInfo {
108 static ::llvm::hash_code getHashValue(
const ValueVector &val) {
109 return ::llvm::hash_combine_range(val.begin(), val.end());
118 struct ConversionValueMapping {
/// Reports whether `mappedTo` contains `value`. Does not modify the mapping.
121 bool isMappedTo(
Value value)
const {
return mappedTo.contains(value); }
/// Type trait whose `value` is true exactly when `T`, after `std::decay_t`
/// strips references and cv-qualifiers, is `ValueVector`.
140 template <
typename T>
141 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
144 template <
typename OldVal,
typename NewVal>
145 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
146 map(OldVal &&oldVal, NewVal &&newVal) {
150 assert(next != oldVal &&
"inserting cyclic mapping");
151 auto it = mapping.find(next);
152 if (it == mapping.end())
157 for (
Value v : newVal)
160 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
164 template <
typename OldVal,
typename NewVal>
165 std::enable_if_t<!IsValueVector<OldVal>::value ||
166 !IsValueVector<NewVal>::value>
167 map(OldVal &&oldVal, NewVal &&newVal) {
168 if constexpr (IsValueVector<OldVal>{}) {
169 map(std::forward<OldVal>(oldVal),
ValueVector{newVal});
170 }
else if constexpr (IsValueVector<NewVal>{}) {
171 map(
ValueVector{oldVal}, std::forward<NewVal>(newVal));
/// Erases the entry keyed by `value` from `mapping`.
178 void erase(
const ValueVector &value) { mapping.erase(value); }
190 ConversionValueMapping::lookupOrDefault(
Value from,
199 desiredValue = current;
203 for (
Value v : current) {
204 auto it = mapping.find({v});
205 if (it != mapping.end()) {
206 llvm::append_range(next, it->second);
211 if (next != current) {
213 current = std::move(next);
225 auto it = mapping.find(current);
226 if (it == mapping.end()) {
230 current = it->second;
236 return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
241 ValueVector result = lookupOrDefault(from, desiredTypes);
254 struct RewriterState {
/// Stores the three counts in the identically named members; the constructor
/// body itself is empty.
255 RewriterState(
unsigned numRewrites,
unsigned numIgnoredOperations,
256 unsigned numReplacedOps)
257 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
258 numReplacedOps(numReplacedOps) {}
261 unsigned numRewrites;
264 unsigned numIgnoredOperations;
267 unsigned numReplacedOps;
299 UnresolvedMaterialization
302 virtual ~IRRewrite() =
default;
305 virtual void rollback() = 0;
/// Returns the `kind` stored in this rewrite.
324 Kind getKind()
const {
return kind; }
/// Always returns true; the `rewrite` argument is not inspected.
326 static bool classof(
const IRRewrite *
rewrite) {
return true; }
330 : kind(kind), rewriterImpl(rewriterImpl) {}
339 class BlockRewrite :
public IRRewrite {
/// Returns the `block` pointer held by this rewrite.
342 Block *getBlock()
const {
return block; }
344 static bool classof(
const IRRewrite *
rewrite) {
345 return rewrite->getKind() >= Kind::CreateBlock &&
346 rewrite->getKind() <= Kind::ReplaceBlockArg;
352 : IRRewrite(kind, rewriterImpl), block(block) {}
361 class CreateBlockRewrite :
public BlockRewrite {
364 : BlockRewrite(
Kind::CreateBlock, rewriterImpl, block) {}
366 static bool classof(
const IRRewrite *
rewrite) {
367 return rewrite->getKind() == Kind::CreateBlock;
373 listener->notifyBlockInserted(block, {}, {});
376 void rollback()
override {
379 auto &blockOps = block->getOperations();
380 while (!blockOps.empty())
381 blockOps.remove(blockOps.begin());
382 block->dropAllUses();
383 if (block->getParent())
394 class EraseBlockRewrite :
public BlockRewrite {
397 : BlockRewrite(
Kind::EraseBlock, rewriterImpl, block),
398 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
400 static bool classof(
const IRRewrite *
rewrite) {
401 return rewrite->getKind() == Kind::EraseBlock;
404 ~EraseBlockRewrite()
override {
406 "rewrite was neither rolled back nor committed/cleaned up");
409 void rollback()
override {
412 assert(block &&
"expected block");
413 auto &blockList = region->getBlocks();
417 blockList.insert(before, block);
423 assert(block &&
"expected block");
424 assert(block->empty() &&
"expected empty block");
428 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
429 listener->notifyBlockErased(block);
434 block->dropAllDefinedValueUses();
445 Block *insertBeforeBlock;
451 class InlineBlockRewrite :
public BlockRewrite {
455 : BlockRewrite(
Kind::InlineBlock, rewriterImpl, block),
456 sourceBlock(sourceBlock),
457 firstInlinedInst(sourceBlock->empty() ? nullptr
458 : &sourceBlock->front()),
459 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
465 assert(!getConfig().listener &&
466 "InlineBlockRewrite not supported if listener is attached");
469 static bool classof(
const IRRewrite *
rewrite) {
470 return rewrite->getKind() == Kind::InlineBlock;
473 void rollback()
override {
476 if (firstInlinedInst) {
477 assert(lastInlinedInst &&
"expected operation");
497 class MoveBlockRewrite :
public BlockRewrite {
501 : BlockRewrite(
Kind::MoveBlock, rewriterImpl, block), region(region),
502 insertBeforeBlock(insertBeforeBlock) {}
504 static bool classof(
const IRRewrite *
rewrite) {
505 return rewrite->getKind() == Kind::MoveBlock;
513 listener->notifyBlockInserted(block, region,
518 void rollback()
override {
531 Block *insertBeforeBlock;
535 class BlockTypeConversionRewrite :
public BlockRewrite {
539 : BlockRewrite(
Kind::BlockTypeConversion, rewriterImpl, origBlock),
540 newBlock(newBlock) {}
542 static bool classof(
const IRRewrite *
rewrite) {
543 return rewrite->getKind() == Kind::BlockTypeConversion;
/// Returns the inherited `block` member, which this class treats as the
/// original block.
546 Block *getOrigBlock()
const {
return block; }
/// Returns the `newBlock` pointer stored by this rewrite.
548 Block *getNewBlock()
const {
return newBlock; }
552 void rollback()
override;
562 class ReplaceBlockArgRewrite :
public BlockRewrite {
567 : BlockRewrite(
Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
568 converter(converter) {}
570 static bool classof(
const IRRewrite *
rewrite) {
571 return rewrite->getKind() == Kind::ReplaceBlockArg;
576 void rollback()
override;
586 class OperationRewrite :
public IRRewrite {
/// Returns the `op` pointer held by this rewrite.
589 Operation *getOperation()
const {
return op; }
591 static bool classof(
const IRRewrite *
rewrite) {
592 return rewrite->getKind() >= Kind::MoveOperation &&
593 rewrite->getKind() <= Kind::UnresolvedMaterialization;
599 : IRRewrite(kind, rewriterImpl), op(op) {}
606 class MoveOperationRewrite :
public OperationRewrite {
610 : OperationRewrite(
Kind::MoveOperation, rewriterImpl, op), block(block),
611 insertBeforeOp(insertBeforeOp) {}
613 static bool classof(
const IRRewrite *
rewrite) {
614 return rewrite->getKind() == Kind::MoveOperation;
622 listener->notifyOperationInserted(
628 void rollback()
override {
632 block->
getOperations().splice(before, op->getBlock()->getOperations(), op);
646 class ModifyOperationRewrite :
public OperationRewrite {
650 : OperationRewrite(
Kind::ModifyOperation, rewriterImpl, op),
651 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
652 operands(op->operand_begin(), op->operand_end()),
653 successors(op->successor_begin(), op->successor_end()) {
658 name.initOpProperties(propCopy, prop);
662 static bool classof(
const IRRewrite *
rewrite) {
663 return rewrite->getKind() == Kind::ModifyOperation;
666 ~ModifyOperationRewrite()
override {
667 assert(!propertiesStorage &&
668 "rewrite was neither committed nor rolled back");
674 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
675 listener->notifyOperationModified(op);
677 if (propertiesStorage) {
681 name.destroyOpProperties(propCopy);
682 operator delete(propertiesStorage);
683 propertiesStorage =
nullptr;
687 void rollback()
override {
693 if (propertiesStorage) {
696 name.destroyOpProperties(propCopy);
697 operator delete(propertiesStorage);
698 propertiesStorage =
nullptr;
705 DictionaryAttr attrs;
706 SmallVector<Value, 8> operands;
707 SmallVector<Block *, 2> successors;
708 void *propertiesStorage =
nullptr;
715 class ReplaceOperationRewrite :
public OperationRewrite {
719 : OperationRewrite(
Kind::ReplaceOperation, rewriterImpl, op),
720 converter(converter) {}
722 static bool classof(
const IRRewrite *
rewrite) {
723 return rewrite->getKind() == Kind::ReplaceOperation;
728 void rollback()
override;
738 class CreateOperationRewrite :
public OperationRewrite {
742 : OperationRewrite(
Kind::CreateOperation, rewriterImpl, op) {}
744 static bool classof(
const IRRewrite *
rewrite) {
745 return rewrite->getKind() == Kind::CreateOperation;
751 listener->notifyOperationInserted(op, {});
754 void rollback()
override;
758 enum MaterializationKind {
771 class UnresolvedMaterializationRewrite :
public OperationRewrite {
774 UnrealizedConversionCastOp op,
776 MaterializationKind kind,
Type originalType,
779 static bool classof(
const IRRewrite *
rewrite) {
780 return rewrite->getKind() == Kind::UnresolvedMaterialization;
783 void rollback()
override;
785 UnrealizedConversionCastOp getOperation()
const {
786 return cast<UnrealizedConversionCastOp>(op);
791 return converterAndKind.getPointer();
795 MaterializationKind getMaterializationKind()
const {
796 return converterAndKind.getInt();
/// Returns the stored `originalType`.
/// NOTE(review): the constructor asserts that a non-null original type is
/// only passed for `MaterializationKind::Target`, so callers should expect a
/// null `Type` for other kinds — confirm against the full class definition.
800 Type getOriginalType()
const {
return originalType; }
805 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
818 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
821 template <
typename RewriteTy,
typename R>
822 static bool hasRewrite(R &&rewrites,
Operation *op) {
823 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
824 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
825 return rewriteTy && rewriteTy->getOperation() == op;
831 template <
typename RewriteTy,
typename R>
832 static bool hasRewrite(R &&rewrites,
Block *block) {
833 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
834 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
835 return rewriteTy && rewriteTy->getBlock() == block;
855 RewriterState getCurrentState();
859 void applyRewrites();
862 void resetState(RewriterState state);
866 template <
typename RewriteTy,
typename... Args>
869 std::make_unique<RewriteTy>(*
this, std::forward<Args>(args)...));
874 void undoRewrites(
unsigned numRewritesToKeep = 0);
880 LogicalResult remapValues(StringRef valueDiagTag,
881 std::optional<Location> inputLoc,
908 Block *applySignatureConversion(
931 UnrealizedConversionCastOp *castOp =
nullptr);
938 Value findOrBuildReplacementValue(
Value value,
946 void notifyOperationInserted(
Operation *op,
953 void notifyBlockIsBeingErased(
Block *block);
956 void notifyBlockInserted(
Block *block,
Region *previous,
960 void notifyBlockBeingInlined(
Block *block,
Block *srcBlock,
990 if (wasErased(block))
992 assert(block->
empty() &&
"expected empty block");
/// Reports whether `ptr` is recorded in the `erased` container.
997 bool wasErased(
void *ptr)
const {
return erased.contains(ptr); }
1061 llvm::ScopedPrinter logger{llvm::dbgs()};
1068 return rewriterImpl.
config;
1071 void BlockTypeConversionRewrite::commit(
RewriterBase &rewriter) {
1075 if (
auto *listener =
1076 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
1077 for (
Operation *op : getNewBlock()->getUsers())
1078 listener->notifyOperationModified(op);
1081 void BlockTypeConversionRewrite::rollback() {
1082 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1085 void ReplaceBlockArgRewrite::commit(
RewriterBase &rewriter) {
1090 if (isa<BlockArgument>(repl)) {
1098 Operation *replOp = cast<OpResult>(repl).getOwner();
/// Rolls back this rewrite by erasing the entry keyed by `{arg}` from the
/// rewriter implementation's `mapping`.
1106 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.
mapping.erase({arg}); }
1108 void ReplaceOperationRewrite::commit(
RewriterBase &rewriter) {
1110 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener());
1113 SmallVector<Value> replacements =
1115 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1120 listener->notifyOperationReplaced(op, replacements);
1123 for (
auto [result, newValue] :
1124 llvm::zip_equal(op->
getResults(), replacements))
1130 if (getConfig().unlegalizedOps)
1131 getConfig().unlegalizedOps->erase(op);
1137 [&](
Operation *op) { listener->notifyOperationErased(op); });
1145 void ReplaceOperationRewrite::rollback() {
1147 rewriterImpl.
mapping.erase({result});
1150 void ReplaceOperationRewrite::cleanup(
RewriterBase &rewriter) {
1154 void CreateOperationRewrite::rollback() {
1156 while (!region.getBlocks().empty())
1157 region.getBlocks().remove(region.getBlocks().begin());
1163 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1167 : OperationRewrite(
Kind::UnresolvedMaterialization, rewriterImpl, op),
1168 converterAndKind(converter, kind), originalType(originalType),
1169 mappedValues(std::move(mappedValues)) {
1170 assert((!originalType || kind == MaterializationKind::Target) &&
1171 "original type is valid only for target materializations");
1175 void UnresolvedMaterializationRewrite::rollback() {
1176 if (!mappedValues.empty())
1177 rewriterImpl.
mapping.erase(mappedValues);
1187 for (
size_t i = 0; i <
rewrites.size(); ++i)
1207 while (
ignoredOps.size() != state.numIgnoredOperations)
1210 while (
replacedOps.size() != state.numReplacedOps)
1216 llvm::reverse(llvm::drop_begin(
rewrites, numRewritesToKeep)))
1218 rewrites.resize(numRewritesToKeep);
1222 StringRef valueDiagTag, std::optional<Location> inputLoc,
1225 remapped.reserve(llvm::size(values));
1228 Value operand = it.value();
1236 remapped.push_back(
mapping.lookupOrDefault(operand));
1244 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1245 << it.index() <<
", type was " << origType;
1250 if (legalTypes.empty()) {
1251 remapped.push_back({});
1260 remapped.push_back(std::move(repl));
1265 repl =
mapping.lookupOrDefault(operand);
1268 repl, repl, legalTypes,
1270 remapped.push_back(castValues);
1293 if (region->
empty())
1298 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1300 std::optional<TypeConverter::SignatureConversion> conversion =
1310 if (entryConversion)
1313 std::optional<TypeConverter::SignatureConversion> conversion =
1325 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1327 if (hasRewrite<BlockTypeConversionRewrite>(
rewrites, block))
1328 llvm::report_fatal_error(
"block was already converted");
1342 for (
unsigned i = 0; i < origArgCount; ++i) {
1344 if (!inputMap || inputMap->replacementValue)
1347 for (
unsigned j = 0;
j < inputMap->size; ++
j)
1348 newLocs[inputMap->inputNo +
j] = origLoc;
1355 convertedTypes, newLocs);
1365 appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->
end());
1368 while (!block->
empty())
1375 for (
unsigned i = 0; i != origArgCount; ++i) {
1379 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1385 MaterializationKind::Source,
1388 origArgType,
Type(), converter);
1389 appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1393 if (
Value repl = inputMap->replacementValue) {
1395 assert(inputMap->size == 0 &&
1396 "invalid to provide a replacement value when the argument isn't "
1399 appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1405 newBlock->
getArguments().slice(inputMap->inputNo, inputMap->size);
1406 ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
1407 mapping.map(origArg, std::move(replArgVals));
1408 appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1411 appendRewrite<BlockTypeConversionRewrite>(block, newBlock);
1430 UnrealizedConversionCastOp *castOp) {
1431 assert((!originalType || kind == MaterializationKind::Target) &&
1432 "original type is valid only for target materializations");
1433 assert(
TypeRange(inputs) != outputTypes &&
1434 "materialization is not necessary");
1438 OpBuilder builder(outputTypes.front().getContext());
1441 builder.
create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
1442 if (!valuesToMap.empty())
1443 mapping.map(valuesToMap, convertOp.getResults());
1445 *castOp = convertOp;
1446 appendRewrite<UnresolvedMaterializationRewrite>(
1447 convertOp, converter, kind, originalType, std::move(valuesToMap));
1448 return convertOp.getResults();
1458 return repl.front();
1465 [&](
Operation *op) { return replacedOps.contains(op); }) &&
1472 repl =
mapping.lookupOrNull(value);
1505 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"'(" << op
1509 "attempting to insert into a block within a replaced/erased op");
1511 if (!previous.
isSet()) {
1513 appendRewrite<CreateOperationRewrite>(op);
1519 appendRewrite<MoveOperationRewrite>(op, previous.
getBlock(), prevOp);
1525 assert(!
ignoredOps.contains(op) &&
"operation was already replaced");
1529 bool isUnresolvedMaterialization =
false;
1530 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1532 isUnresolvedMaterialization =
true;
1535 for (
auto [repl, result] : llvm::zip_equal(newValues, op->
getResults())) {
1538 if (isUnresolvedMaterialization) {
1558 assert(!isUnresolvedMaterialization &&
1559 "attempting to replace an unresolved materialization");
1574 appendRewrite<EraseBlockRewrite>(block);
1580 "attempting to insert into a region within a replaced/erased op");
1585 logger.startLine() <<
"** Insert Block into : '" << parent->
getName()
1586 <<
"'(" << parent <<
")\n";
1589 <<
"** Insert Block into detached Region (nullptr parent op)'";
1595 appendRewrite<CreateBlockRewrite>(block);
1598 Block *prevBlock = previousIt == previous->
end() ? nullptr : &*previousIt;
1599 appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1604 appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
1611 reasonCallback(
diag);
1612 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
1622 ConversionPatternRewriter::ConversionPatternRewriter(
1626 setListener(
impl.get());
1632 assert(op && newOp &&
"expected non-null op");
1638 "incorrect # of replacement values");
1640 impl->logger.startLine()
1641 <<
"** Replace : '" << op->
getName() <<
"'(" << op <<
")\n";
1644 for (
size_t i = 0; i < newValues.size(); ++i) {
1646 newVals.push_back(newValues.slice(i, 1));
1651 impl->notifyOpReplaced(op, newVals);
1657 "incorrect # of replacement values");
1659 impl->logger.startLine()
1660 <<
"** Replace : '" << op->
getName() <<
"'(" << op <<
")\n";
1662 impl->notifyOpReplaced(op, newValues);
1667 impl->logger.startLine()
1668 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
1671 impl->notifyOpReplaced(op, nullRepls);
1676 "attempting to erase a block within a replaced/erased op");
1686 impl->notifyBlockIsBeingErased(block);
1694 "attempting to apply a signature conversion to a block within a "
1695 "replaced/erased op");
1696 return impl->applySignatureConversion(*
this, block, converter, conversion);
1703 "attempting to apply a signature conversion to a block within a "
1704 "replaced/erased op");
1705 return impl->convertRegionTypes(*
this, region, converter, entryConversion);
1712 impl->logger.startLine() <<
"** Replace Argument : '" << from
1713 <<
"'(in region of '" << parentOp->
getName()
1716 impl->appendRewrite<ReplaceBlockArgRewrite>(from.
getOwner(), from,
1717 impl->currentTypeConverter);
1718 impl->mapping.map(
impl->mapping.lookupOrDefault(from), to);
1723 if (failed(
impl->remapValues(
"value", std::nullopt, *
this, key,
1726 assert(remappedValues.front().size() == 1 &&
"1:N conversion not supported");
1727 return remappedValues.front().front();
1736 if (failed(
impl->remapValues(
"value", std::nullopt, *
this, keys,
1739 for (
const auto &values : remapped) {
1740 assert(values.size() == 1 &&
"1:N conversion not supported");
1741 results.push_back(values.front());
1751 "incorrect # of argument replacement values");
1753 "attempting to inline a block from a replaced/erased op");
1755 "attempting to inline a block into a replaced/erased op");
1756 auto opIgnored = [&](
Operation *op) {
return impl->isOpIgnored(op); };
1759 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
1760 "expected 'source' to have no predecessors");
1769 bool fastPath = !
impl->config.listener;
1772 impl->notifyBlockBeingInlined(dest, source, before);
1775 for (
auto it : llvm::zip(source->
getArguments(), argValues))
1776 replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1783 while (!source->
empty())
1784 moveOpBefore(&source->
front(), dest, before);
1792 assert(!
impl->wasOpReplaced(op) &&
1793 "attempting to modify a replaced/erased op");
1795 impl->pendingRootUpdates.insert(op);
1797 impl->appendRewrite<ModifyOperationRewrite>(op);
1801 assert(!
impl->wasOpReplaced(op) &&
1802 "attempting to modify a replaced/erased op");
1807 assert(
impl->pendingRootUpdates.erase(op) &&
1808 "operation did not have a pending in-place update");
1814 assert(
impl->pendingRootUpdates.erase(op) &&
1815 "operation did not have a pending in-place update");
1818 auto it = llvm::find_if(
1819 llvm::reverse(
impl->rewrites), [&](std::unique_ptr<IRRewrite> &
rewrite) {
1820 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1821 return modifyRewrite && modifyRewrite->getOperation() == op;
1823 assert(it !=
impl->rewrites.rend() &&
"no root update started on op");
1825 int updateIdx = std::prev(
impl->rewrites.rend()) - it;
1826 impl->rewrites.erase(
impl->rewrites.begin() + updateIdx);
1840 oneToOneOperands.reserve(operands.size());
1842 if (operand.size() != 1)
1843 llvm::report_fatal_error(
"pattern '" + getDebugName() +
1844 "' does not support 1:N conversion");
1845 oneToOneOperands.push_back(operand.front());
1847 return oneToOneOperands;
1854 auto &rewriterImpl = dialectRewriter.getImpl();
1858 getTypeConverter());
1867 llvm::to_vector_of<ValueRange>(remapped);
1868 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
1880 class OperationLegalizer {
1900 LogicalResult legalizeWithFold(
Operation *op,
1905 LogicalResult legalizeWithPattern(
Operation *op,
1916 RewriterState &curState);
1920 legalizePatternBlockRewrites(
Operation *op,
1923 RewriterState &state, RewriterState &newState);
1924 LogicalResult legalizePatternCreatedOperations(
1926 RewriterState &state, RewriterState &newState);
1929 RewriterState &state,
1930 RewriterState &newState);
1940 void buildLegalizationGraph(
1941 LegalizationPatterns &anyOpLegalizerPatterns,
1952 void computeLegalizationGraphBenefit(
1953 LegalizationPatterns &anyOpLegalizerPatterns,
1958 unsigned computeOpLegalizationDepth(
1965 unsigned applyCostModelToPatterns(
1991 LegalizationPatterns anyOpLegalizerPatterns;
1993 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1994 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1997 bool OperationLegalizer::isIllegal(
Operation *op)
const {
1998 return target.isIllegal(op);
2002 OperationLegalizer::legalize(
Operation *op,
2005 const char *logLineComment =
2006 "//===-------------------------------------------===//\n";
2011 logger.getOStream() <<
"\n";
2012 logger.startLine() << logLineComment;
2013 logger.startLine() <<
"Legalizing operation : '" << op->
getName() <<
"'("
2019 op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
2020 logger.getOStream() <<
"\n\n";
2025 if (
auto legalityInfo = target.isLegal(op)) {
2028 logger,
"operation marked legal by the target{0}",
2029 legalityInfo->isRecursivelyLegal
2030 ?
"; NOTE: operation is recursively legal; skipping internals"
2032 logger.startLine() << logLineComment;
2037 if (legalityInfo->isRecursivelyLegal) {
2050 logSuccess(logger,
"operation marked 'ignored' during conversion");
2051 logger.startLine() << logLineComment;
2059 if (succeeded(legalizeWithFold(op, rewriter))) {
2062 logger.startLine() << logLineComment;
2068 if (succeeded(legalizeWithPattern(op, rewriter))) {
2071 logger.startLine() << logLineComment;
2077 logFailure(logger,
"no matched legalization pattern");
2078 logger.startLine() << logLineComment;
2084 OperationLegalizer::legalizeWithFold(
Operation *op,
2086 auto &rewriterImpl = rewriter.
getImpl();
2090 rewriterImpl.
logger.startLine() <<
"* Fold {\n";
2091 rewriterImpl.
logger.indent();
2095 SmallVector<Value, 2> replacementValues;
2097 if (failed(rewriter.
tryFold(op, replacementValues))) {
2103 if (replacementValues.empty())
2104 return legalize(op, rewriter);
2107 rewriter.
replaceOp(op, replacementValues);
2110 for (
unsigned i = curState.numRewrites, e = rewriterImpl.
rewrites.size();
2113 dyn_cast<CreateOperationRewrite>(rewriterImpl.
rewrites[i].get());
2116 if (failed(legalize(createOp->getOperation(), rewriter))) {
2118 "failed to legalize generated constant '{0}'",
2119 createOp->getOperation()->getName()));
2130 OperationLegalizer::legalizeWithPattern(
Operation *op,
2132 auto &rewriterImpl = rewriter.
getImpl();
2135 auto canApply = [&](
const Pattern &pattern) {
2136 bool canApply = canApplyPattern(op, pattern, rewriter);
2137 if (canApply &&
config.listener)
2138 config.listener->notifyPatternBegin(pattern, op);
2144 auto onFailure = [&](
const Pattern &pattern) {
2150 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
2157 config.listener->notifyPatternEnd(pattern, failure());
2159 appliedPatterns.erase(&pattern);
2164 auto onSuccess = [&](
const Pattern &pattern) {
2166 auto result = legalizePatternResult(op, pattern, rewriter, curState);
2167 appliedPatterns.erase(&pattern);
2171 config.listener->notifyPatternEnd(pattern, result);
2176 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2180 bool OperationLegalizer::canApplyPattern(
Operation *op,
const Pattern &pattern,
2183 auto &os = rewriter.
getImpl().logger;
2184 os.getOStream() <<
"\n";
2185 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
2187 os.getOStream() <<
")' {\n";
2194 !appliedPatterns.insert(&pattern).second) {
2203 OperationLegalizer::legalizePatternResult(
Operation *op,
const Pattern &pattern,
2205 RewriterState &curState) {
2207 assert(
impl.pendingRootUpdates.empty() &&
"dangling root updates");
2209 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2211 auto newRewrites = llvm::drop_begin(
impl.rewrites, curState.numRewrites);
2212 auto replacedRoot = [&] {
2213 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2215 auto updatedRootInPlace = [&] {
2216 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2218 if (!replacedRoot() && !updatedRootInPlace())
2219 llvm::report_fatal_error(
"expected pattern to replace the root operation");
2223 RewriterState newState =
impl.getCurrentState();
2224 if (failed(legalizePatternBlockRewrites(op, rewriter,
impl, curState,
2226 failed(legalizePatternRootUpdates(rewriter,
impl, curState, newState)) ||
2227 failed(legalizePatternCreatedOperations(rewriter,
impl, curState,
2232 LLVM_DEBUG(
logSuccess(
impl.logger,
"pattern applied successfully"));
2236 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2239 RewriterState &newState) {
2244 for (
int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2245 BlockRewrite *
rewrite = dyn_cast<BlockRewrite>(
impl.rewrites[i].get());
2249 if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2250 ReplaceBlockArgRewrite>(
rewrite))
2259 if (
auto *converter =
impl.regionToConverter.lookup(block->
getParent())) {
2260 std::optional<TypeConverter::SignatureConversion> conversion =
2263 LLVM_DEBUG(
logFailure(
impl.logger,
"failed to convert types of moved "
2267 impl.applySignatureConversion(rewriter, block, converter, *conversion);
2275 if (operationsToIgnore.empty()) {
2276 for (
unsigned i = state.numRewrites, e =
impl.rewrites.size(); i != e;
2279 dyn_cast<CreateOperationRewrite>(
impl.rewrites[i].get());
2282 operationsToIgnore.insert(createOp->getOperation());
2287 if (operationsToIgnore.insert(parentOp).second &&
2288 failed(legalize(parentOp, rewriter))) {
2290 "operation '{0}'({1}) became illegal after rewrite",
2291 parentOp->
getName(), parentOp));
2298 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2300 RewriterState &state, RewriterState &newState) {
2301 for (
int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2302 auto *createOp = dyn_cast<CreateOperationRewrite>(
impl.rewrites[i].get());
2305 Operation *op = createOp->getOperation();
2306 if (failed(legalize(op, rewriter))) {
2308 "failed to legalize generated operation '{0}'({1})",
2316 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2318 RewriterState &state, RewriterState &newState) {
2319 for (
int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2320 auto *
rewrite = dyn_cast<ModifyOperationRewrite>(
impl.rewrites[i].get());
2324 if (failed(legalize(op, rewriter))) {
2326 impl.logger,
"failed to legalize operation updated in-place '{0}'",
2337 void OperationLegalizer::buildLegalizationGraph(
2338 LegalizationPatterns &anyOpLegalizerPatterns,
2349 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2350 std::optional<OperationName> root = pattern.
getRootKind();
2356 anyOpLegalizerPatterns.push_back(&pattern);
2361 if (target.getOpAction(*root) == LegalizationAction::Legal)
2366 invalidPatterns[*root].insert(&pattern);
2368 parentOps[op].insert(*root);
2371 patternWorklist.insert(&pattern);
2379 if (!anyOpLegalizerPatterns.empty()) {
2380 for (
const Pattern *pattern : patternWorklist)
2381 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2385 while (!patternWorklist.empty()) {
2386 auto *pattern = patternWorklist.pop_back_val();
2390 std::optional<LegalizationAction> action = target.getOpAction(op);
2391 return !legalizerPatterns.count(op) &&
2392 (!action || action == LegalizationAction::Illegal);
2398 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2399 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
2403 for (
auto op : parentOps[*pattern->
getRootKind()])
2404 patternWorklist.set_union(invalidPatterns[op]);
2408 void OperationLegalizer::computeLegalizationGraphBenefit(
2409 LegalizationPatterns &anyOpLegalizerPatterns,
2415 for (
auto &opIt : legalizerPatterns)
2416 if (!minOpPatternDepth.count(opIt.first))
2417 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2423 if (!anyOpLegalizerPatterns.empty())
2424 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2430 applicator.applyCostModel([&](
const Pattern &pattern) {
2432 if (std::optional<OperationName> rootName = pattern.
getRootKind())
2433 orderedPatternList = legalizerPatterns[*rootName];
2435 orderedPatternList = anyOpLegalizerPatterns;
2438 auto *it = llvm::find(orderedPatternList, &pattern);
2439 if (it == orderedPatternList.end())
2443 return PatternBenefit(std::distance(it, orderedPatternList.end()));
2447 unsigned OperationLegalizer::computeOpLegalizationDepth(
2451 auto depthIt = minOpPatternDepth.find(op);
2452 if (depthIt != minOpPatternDepth.end())
2453 return depthIt->second;
2457 auto opPatternsIt = legalizerPatterns.find(op);
2458 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2467 unsigned minDepth = applyCostModelToPatterns(
2468 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2469 minOpPatternDepth[op] = minDepth;
2473 unsigned OperationLegalizer::applyCostModelToPatterns(
2480 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2481 patternsByDepth.reserve(
patterns.size());
2485 unsigned generatedOpDepth = computeOpLegalizationDepth(
2486 generatedOp, minOpPatternDepth, legalizerPatterns);
2487 depth =
std::max(depth, generatedOpDepth + 1);
2489 patternsByDepth.emplace_back(pattern, depth);
2492 minDepth =
std::min(minDepth, depth);
2497 if (patternsByDepth.size() == 1)
2501 std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2502 [](
const std::pair<const Pattern *, unsigned> &lhs,
2503 const std::pair<const Pattern *, unsigned> &rhs) {
2506 if (lhs.second != rhs.second)
2507 return lhs.second < rhs.second;
2510 auto lhsBenefit = lhs.first->getBenefit();
2511 auto rhsBenefit = rhs.first->getBenefit();
2512 return lhsBenefit > rhsBenefit;
2517 for (
auto &patternIt : patternsByDepth)
2518 patterns.push_back(patternIt.first);
2526 enum OpConversionMode {
2549 OpConversionMode mode)
2564 OperationLegalizer opLegalizer;
2567 OpConversionMode mode;
2574 if (failed(opLegalizer.legalize(op, rewriter))) {
2577 if (mode == OpConversionMode::Full)
2579 <<
"failed to legalize operation '" << op->
getName() <<
"'";
2583 if (mode == OpConversionMode::Partial) {
2584 if (opLegalizer.isIllegal(op))
2586 <<
"failed to legalize operation '" << op->
getName()
2587 <<
"' that was explicitly marked illegal";
2591 }
else if (mode == OpConversionMode::Analysis) {
2601 static LogicalResult
2603 UnresolvedMaterializationRewrite *
rewrite) {
2604 UnrealizedConversionCastOp op =
rewrite->getOperation();
2605 assert(!op.use_empty() &&
2606 "expected that dead materializations have already been DCE'd");
2612 SmallVector<Value> newMaterialization;
2613 switch (
rewrite->getMaterializationKind()) {
2614 case MaterializationKind::Target:
2616 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
2619 case MaterializationKind::Source:
2620 assert(op->getNumResults() == 1 &&
"expected single result");
2622 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
2624 newMaterialization.push_back(sourceMat);
2627 if (!newMaterialization.empty()) {
2629 ValueRange newMaterializationRange(newMaterialization);
2630 assert(
TypeRange(newMaterializationRange) == op.getResultTypes() &&
2631 "materialization callback produced value of incorrect type");
2633 rewriter.
replaceOp(op, newMaterialization);
2639 <<
"failed to legalize unresolved materialization "
2641 << inputOperands.
getTypes() <<
") to ("
2642 << op.getResultTypes()
2643 <<
") that remained live after conversion";
2644 diag.attachNote(op->getUsers().begin()->getLoc())
2645 <<
"see existing live user here: " << *op->getUsers().begin();
2656 for (
auto *op : ops) {
2659 toConvert.push_back(op);
2662 auto legalityInfo = target.
isLegal(op);
2663 if (legalityInfo && legalityInfo->isRecursivelyLegal)
2673 for (
auto *op : toConvert)
2674 if (failed(convert(rewriter, op)))
2684 for (
auto it : materializations) {
2687 allCastOps.push_back(it.first);
2699 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2700 auto it = materializations.find(castOp);
2701 assert(it != materializations.end() &&
"inconsistent state");
2724 auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2725 for (
Value v : castOp.getInputs())
2726 if (
auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2727 worklist.insert(inputCastOp);
2734 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2735 if (castOp.getInputs().empty())
2738 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2741 if (inputCastOp.getOutputs() != castOp.getInputs())
2747 while (!worklist.empty()) {
2748 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2749 if (castOp->use_empty()) {
2752 enqueueOperands(castOp);
2753 if (remainingCastOps)
2754 erasedOps.insert(castOp.getOperation());
2761 UnrealizedConversionCastOp nextCast = castOp;
2763 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2767 enqueueOperands(castOp);
2768 castOp.replaceAllUsesWith(nextCast.getInputs());
2769 if (remainingCastOps)
2770 erasedOps.insert(castOp.getOperation());
2774 nextCast = getInputCast(nextCast);
2778 if (remainingCastOps)
2779 for (UnrealizedConversionCastOp op : castOps)
2780 if (!erasedOps.contains(op.getOperation()))
2781 remainingCastOps->push_back(op);
2790 assert(!types.empty() &&
"expected valid types");
2791 remapInput(origInputNo, argTypes.size(), types.size());
2796 assert(!types.empty() &&
2797 "1->0 type remappings don't need to be added explicitly");
2798 argTypes.append(types.begin(), types.end());
2802 unsigned newInputNo,
2803 unsigned newInputCount) {
2804 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
2805 assert(newInputCount != 0 &&
"expected valid input count");
2806 remappedInputs[origInputNo] =
2807 InputMapping{newInputNo, newInputCount,
nullptr};
2811 Value replacementValue) {
2812 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
2813 remappedInputs[origInputNo] =
2819 assert(t &&
"expected non-null type");
2822 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2825 cacheReadLock.lock();
2826 auto existingIt = cachedDirectConversions.find(t);
2827 if (existingIt != cachedDirectConversions.end()) {
2828 if (existingIt->second)
2829 results.push_back(existingIt->second);
2830 return success(existingIt->second !=
nullptr);
2832 auto multiIt = cachedMultiConversions.find(t);
2833 if (multiIt != cachedMultiConversions.end()) {
2834 results.append(multiIt->second.begin(), multiIt->second.end());
2840 size_t currentCount = results.size();
2842 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2845 for (
const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2846 if (std::optional<LogicalResult> result = converter(t, results)) {
2848 cacheWriteLock.lock();
2849 if (!succeeded(*result)) {
2850 cachedDirectConversions.try_emplace(t,
nullptr);
2853 auto newTypes =
ArrayRef<Type>(results).drop_front(currentCount);
2854 if (newTypes.size() == 1)
2855 cachedDirectConversions.try_emplace(t, newTypes.front());
2857 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2871 return results.size() == 1 ? results.front() :
nullptr;
2877 for (
Type type : types)
2891 return llvm::all_of(*region, [
this](
Block &block) {
2897 return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2909 if (convertedTypes.empty())
2913 result.
addInputs(inputNo, convertedTypes);
2919 unsigned origInputOffset)
const {
2920 for (
unsigned i = 0, e = types.size(); i != e; ++i)
2930 for (
const MaterializationCallbackFn &fn :
2931 llvm::reverse(argumentMaterializations))
2932 if (
Value result = fn(builder, resultType, inputs, loc))
2940 for (
const MaterializationCallbackFn &fn :
2941 llvm::reverse(sourceMaterializations))
2942 if (
Value result = fn(builder, resultType, inputs, loc))
2950 Type originalType)
const {
2952 builder, loc,
TypeRange(resultType), inputs, originalType);
2955 assert(result.size() == 1 &&
"expected single result");
2956 return result.front();
2961 Type originalType)
const {
2962 for (
const TargetMaterializationCallbackFn &fn :
2963 llvm::reverse(targetMaterializations)) {
2965 fn(builder, resultTypes, inputs, loc, originalType);
2969 "callback produced incorrect number of values or values with "
2976 std::optional<TypeConverter::SignatureConversion>
2980 return std::nullopt;
3003 return impl.getInt() == resultTag;
3007 return impl.getInt() == naTag;
3011 return impl.getInt() == abortTag;
3015 assert(hasResult() &&
"Cannot get result from N/A or abort");
3016 return impl.getPointer();
3019 std::optional<Attribute>
3021 for (
const TypeAttributeConversionCallbackFn &fn :
3022 llvm::reverse(typeAttributeConversions)) {
3027 return std::nullopt;
3029 return std::nullopt;
3039 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3045 SmallVector<Type, 1> newResults;
3047 failed(typeConverter.
convertTypes(type.getResults(), newResults)) ||
3049 typeConverter, &result)))
3066 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3074 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3079 struct AnyFunctionOpInterfaceSignatureConversion
3091 FailureOr<Operation *>
3095 assert(op &&
"Invalid op");
3109 return rewriter.
create(newOp);
3115 patterns.add<FunctionOpInterfaceSignatureConversion>(
3116 functionLikeOpName,
patterns.getContext(), converter);
3121 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3131 legalOperations[op].action = action;
3136 for (StringRef dialect : dialectNames)
3137 legalDialects[dialect] = action;
3141 -> std::optional<LegalizationAction> {
3142 std::optional<LegalizationInfo> info = getOpInfo(op);
3143 return info ? info->action : std::optional<LegalizationAction>();
3147 -> std::optional<LegalOpDetails> {
3148 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3150 return std::nullopt;
3153 auto isOpLegal = [&] {
3155 if (info->action == LegalizationAction::Dynamic) {
3156 std::optional<bool> result = info->legalityFn(op);
3162 return info->action == LegalizationAction::Legal;
3165 return std::nullopt;
3169 if (info->isRecursivelyLegal) {
3170 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
3171 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3173 legalityFnIt->second(op).value_or(
true);
3178 return legalityDetails;
3182 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3186 if (info->action == LegalizationAction::Dynamic) {
3187 std::optional<bool> result = info->legalityFn(op);
3194 return info->action == LegalizationAction::Illegal;
3203 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3205 if (std::optional<bool> result = newCl(op))
3213 void ConversionTarget::setLegalityCallback(
3214 OperationName name,
const DynamicLegalityCallbackFn &callback) {
3215 assert(callback &&
"expected valid legality callback");
3216 auto *infoIt = legalOperations.find(name);
3217 assert(infoIt != legalOperations.end() &&
3218 infoIt->second.action == LegalizationAction::Dynamic &&
3219 "expected operation to already be marked as dynamically legal");
3220 infoIt->second.legalityFn =
3226 auto *infoIt = legalOperations.find(name);
3227 assert(infoIt != legalOperations.end() &&
3228 infoIt->second.action != LegalizationAction::Illegal &&
3229 "expected operation to already be marked as legal");
3230 infoIt->second.isRecursivelyLegal =
true;
3233 std::move(opRecursiveLegalityFns[name]), callback);
3235 opRecursiveLegalityFns.erase(name);
3238 void ConversionTarget::setLegalityCallback(
3240 assert(callback &&
"expected valid legality callback");
3241 for (StringRef dialect : dialects)
3243 std::move(dialectLegalityFns[dialect]), callback);
3246 void ConversionTarget::setLegalityCallback(
3247 const DynamicLegalityCallbackFn &callback) {
3248 assert(callback &&
"expected valid legality callback");
3253 -> std::optional<LegalizationInfo> {
3255 const auto *it = legalOperations.find(op);
3256 if (it != legalOperations.end())
3259 auto dialectIt = legalDialects.find(op.getDialectNamespace());
3260 if (dialectIt != legalDialects.end()) {
3261 DynamicLegalityCallbackFn callback;
3262 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3263 if (dialectFn != dialectLegalityFns.end())
3264 callback = dialectFn->second;
3265 return LegalizationInfo{dialectIt->second,
false,
3269 if (unknownLegalityFn)
3270 return LegalizationInfo{LegalizationAction::Dynamic,
3271 false, unknownLegalityFn};
3272 return std::nullopt;
3275 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3281 auto &rewriterImpl =
3287 auto &rewriterImpl =
3294 static FailureOr<SmallVector<Value>>
3296 SmallVector<Value> mappedValues;
3299 return std::move(mappedValues);
3303 patterns.getPDLPatterns().registerRewriteFunction(
3308 if (failed(results))
3310 return results->front();
3312 patterns.getPDLPatterns().registerRewriteFunction(
3317 patterns.getPDLPatterns().registerRewriteFunction(
3320 auto &rewriterImpl =
3330 patterns.getPDLPatterns().registerRewriteFunction(
3333 TypeRange types) -> FailureOr<SmallVector<Type>> {
3334 auto &rewriterImpl =
3341 if (failed(converter->
convertTypes(types, remappedTypes)))
3343 return std::move(remappedTypes);
3359 OpConversionMode::Partial);
3377 OpConversionMode::Full);
3400 "expected top-level op to be isolated from above");
3403 "expected ops to have a common ancestor");
3412 for (
Operation *op : ops.drop_front()) {
3416 assert(commonAncestor &&
3417 "expected to find a common isolated from above ancestor");
3421 return commonAncestor;
3428 if (
config.legalizableOps)
3429 assert(
config.legalizableOps->empty() &&
"expected empty set");
3439 inverseOperationMap[it.second] = it.first;
3445 OpConversionMode::Analysis);
3446 LogicalResult status = opConverter.convertOperations(opsToConvert);
3450 if (
config.legalizableOps) {
3453 originalLegalizableOps.insert(inverseOperationMap[op]);
3454 *
config.legalizableOps = std::move(originalLegalizableOps);
3458 clonedAncestor->
erase();
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnresolvedMaterializationRewrite *rewrite)
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used without a dominance violation.
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the uses are located.
OpListType & getOperations()
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the currently executing pattern.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithMultiple(Operation *op, ArrayRef< ValueRange > newValues)
Replace the given operation with the new value ranges.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently executing pattern.
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
~ConversionPatternRewriter() override
Base class for the conversion patterns.
SmallVector< Value > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single value of each range to construct the inputs for a 1:1 adaptor.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information is returned.
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for this conversion target.
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively legal.
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e. if A and B are the same operation or A properly dominates B.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
auto lookup(T from) const
Lookup a mapped value within the map.
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to use the other value instead.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Block::iterator getPoint() const
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within results.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through APIs.
This class implements the operand iterators for the Operation class.
type_range getTypes() const
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
result_range getResults()
int getPropertiesStorageSize() const
Returns the properties storage size.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
The general result of a type attribute conversion callback, allowing for early termination.
Attribute getResult() const
static AttributeConversionResult abort()
static AttributeConversionResult na()
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert the attribute `attr` present within the type `type` using the registered conversion functions.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of some kind.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for each argument.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the given type converter.
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOps (if provided).
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a partial conversion on the given operations, and all nested operations. This is one of several entry points for operation conversion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target materializations through the...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
A rewriter that keeps track of erased ops and blocks.
bool wasErased(void *ptr) const
SingleEraseRewriter(MLIRContext *context)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks within that region.
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationRewrite * > unresolvedMaterializations
A mapping of all unresolved materializations (UnrealizedConversionCastOp) to the corresponding rewrite.
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
ConversionValueMapping mapping
void applyRewrites()
Apply all requested operation rewrites.
void undoRewrites(unsigned numRewritesToKeep=0)
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites remain.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
SingleEraseRewriter eraseRewriter
A rewriter that keeps track of ops/blocks that were already erased and skips duplicate op/block erasures.
MLIRContext * context
MLIR context.
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, UnrealizedConversionCastOp *castOp=nullptr)
Build an unresolved materialization operation given a range of output types and a list of input operands.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
void notifyOpReplaced(Operation *op, ArrayRef< ValueRange > newValues)
Notifies that an op is about to be replaced with the given values.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.