10 #include "mlir/Config/mlir-config.h"
19 #include "llvm/ADT/SmallPtrSet.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/FormatVariadic.h"
22 #include "llvm/Support/SaveAndRestore.h"
23 #include "llvm/Support/ScopedPrinter.h"
29 #define DEBUG_TYPE "dialect-conversion"
32 template <
typename... Args>
33 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
36 os.startLine() <<
"} -> SUCCESS";
38 os.getOStream() <<
" : "
39 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
40 os.getOStream() <<
"\n";
45 template <
typename... Args>
46 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
49 os.startLine() <<
"} -> FAILURE : "
50 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
60 if (
OpResult inputRes = dyn_cast<OpResult>(value))
61 insertPt = ++inputRes.getOwner()->getIterator();
68 assert(!vals.empty() &&
"expected at least one value");
71 for (
Value v : vals.drop_front()) {
85 assert(dom &&
"unable to find valid insertion point");
103 struct ValueVectorMapInfo {
106 static ::llvm::hash_code getHashValue(
const ValueVector &val) {
107 return ::llvm::hash_combine_range(val);
116 struct ConversionValueMapping {
/// Returns true if `value` is currently contained in the `mappedTo` set.
119 bool isMappedTo(
Value value)
const {
return mappedTo.contains(value); }
/// Type trait that is true when `T`, after `std::decay_t` strips references
/// and cv-qualifiers, is exactly `ValueVector`. Used to select between the
/// `map` overloads below.
138 template <
typename T>
139 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
142 template <
typename OldVal,
typename NewVal>
143 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
144 map(OldVal &&oldVal, NewVal &&newVal) {
148 assert(next != oldVal &&
"inserting cyclic mapping");
149 auto it = mapping.find(next);
150 if (it == mapping.end())
155 mappedTo.insert_range(newVal);
157 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
161 template <
typename OldVal,
typename NewVal>
162 std::enable_if_t<!IsValueVector<OldVal>::value ||
163 !IsValueVector<NewVal>::value>
164 map(OldVal &&oldVal, NewVal &&newVal) {
165 if constexpr (IsValueVector<OldVal>{}) {
166 map(std::forward<OldVal>(oldVal),
ValueVector{newVal});
167 }
else if constexpr (IsValueVector<NewVal>{}) {
168 map(
ValueVector{oldVal}, std::forward<NewVal>(newVal));
174 void map(
Value oldVal, SmallVector<Value> &&newVal) {
/// Removes the entry keyed by `value` from `mapping`. Only `mapping` is
/// touched by this call.
179 void erase(
const ValueVector &value) { mapping.erase(value); }
191 ConversionValueMapping::lookupOrDefault(
Value from,
200 desiredValue = current;
204 for (
Value v : current) {
205 auto it = mapping.find({v});
206 if (it != mapping.end()) {
207 llvm::append_range(next, it->second);
212 if (next != current) {
214 current = std::move(next);
226 auto it = mapping.find(current);
227 if (it == mapping.end()) {
231 current = it->second;
237 return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
242 ValueVector result = lookupOrDefault(from, desiredTypes);
255 struct RewriterState {
256 RewriterState(
unsigned numRewrites,
unsigned numIgnoredOperations,
257 unsigned numReplacedOps)
258 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
259 numReplacedOps(numReplacedOps) {}
262 unsigned numRewrites;
265 unsigned numIgnoredOperations;
268 unsigned numReplacedOps;
280 notifyIRErased(listener, op);
289 notifyIRErased(listener, b);
320 UnresolvedMaterialization
323 virtual ~IRRewrite() =
default;
326 virtual void rollback() = 0;
/// Returns the `kind` value stored in this rewrite; the `classof`
/// implementations of the subclasses compare against it.
345 Kind getKind()
const {
return kind; }
/// Type-test hook for the base class: unconditionally returns true, so any
/// `IRRewrite` pointer is accepted.
347 static bool classof(
const IRRewrite *
rewrite) {
return true; }
351 :
kind(
kind), rewriterImpl(rewriterImpl) {}
360 class BlockRewrite :
public IRRewrite {
/// Returns the `block` pointer held by this rewrite.
363 Block *getBlock()
const {
return block; }
365 static bool classof(
const IRRewrite *
rewrite) {
366 return rewrite->getKind() >= Kind::CreateBlock &&
367 rewrite->getKind() <= Kind::ReplaceBlockArg;
373 : IRRewrite(
kind, rewriterImpl), block(block) {}
382 class CreateBlockRewrite :
public BlockRewrite {
385 : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
387 static bool classof(
const IRRewrite *
rewrite) {
388 return rewrite->getKind() == Kind::CreateBlock;
397 void rollback()
override {
401 while (!blockOps.empty())
402 blockOps.remove(blockOps.begin());
415 class EraseBlockRewrite :
public BlockRewrite {
418 : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
419 region(block->
getParent()), insertBeforeBlock(block->getNextNode()) {}
421 static bool classof(
const IRRewrite *
rewrite) {
422 return rewrite->getKind() == Kind::EraseBlock;
425 ~EraseBlockRewrite()
override {
427 "rewrite was neither rolled back nor committed/cleaned up");
430 void rollback()
override {
433 assert(block &&
"expected block");
434 auto &blockList = region->getBlocks();
438 blockList.insert(before, block);
443 assert(block &&
"expected block");
447 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
448 notifyIRErased(listener, *block);
453 for (
auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
455 assert(block->
empty() &&
"expected empty block");
469 Block *insertBeforeBlock;
475 class InlineBlockRewrite :
public BlockRewrite {
479 : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
480 sourceBlock(sourceBlock),
481 firstInlinedInst(sourceBlock->
empty() ?
nullptr
482 : &sourceBlock->
front()),
483 lastInlinedInst(sourceBlock->
empty() ?
nullptr : &sourceBlock->
back()) {
489 assert(!getConfig().listener &&
490 "InlineBlockRewrite not supported if listener is attached");
493 static bool classof(
const IRRewrite *
rewrite) {
494 return rewrite->getKind() == Kind::InlineBlock;
497 void rollback()
override {
500 if (firstInlinedInst) {
501 assert(lastInlinedInst &&
"expected operation");
521 class MoveBlockRewrite :
public BlockRewrite {
525 : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
526 insertBeforeBlock(insertBeforeBlock) {}
528 static bool classof(
const IRRewrite *
rewrite) {
529 return rewrite->getKind() == Kind::MoveBlock;
542 void rollback()
override {
555 Block *insertBeforeBlock;
559 class BlockTypeConversionRewrite :
public BlockRewrite {
563 : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
564 newBlock(newBlock) {}
566 static bool classof(
const IRRewrite *
rewrite) {
567 return rewrite->getKind() == Kind::BlockTypeConversion;
/// Returns the inherited `block` member, which the constructor initializes
/// from its `origBlock` argument.
570 Block *getOrigBlock()
const {
return block; }
/// Returns the `newBlock` pointer stored by the constructor.
572 Block *getNewBlock()
const {
return newBlock; }
576 void rollback()
override;
586 class ReplaceBlockArgRewrite :
public BlockRewrite {
591 : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
592 converter(converter) {}
594 static bool classof(
const IRRewrite *
rewrite) {
595 return rewrite->getKind() == Kind::ReplaceBlockArg;
600 void rollback()
override;
610 class OperationRewrite :
public IRRewrite {
/// Returns the `op` pointer held by this rewrite.
613 Operation *getOperation()
const {
return op; }
615 static bool classof(
const IRRewrite *
rewrite) {
616 return rewrite->getKind() >= Kind::MoveOperation &&
617 rewrite->getKind() <= Kind::UnresolvedMaterialization;
623 : IRRewrite(
kind, rewriterImpl), op(op) {}
630 class MoveOperationRewrite :
public OperationRewrite {
634 : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
635 insertBeforeOp(insertBeforeOp) {}
637 static bool classof(
const IRRewrite *
rewrite) {
638 return rewrite->getKind() == Kind::MoveOperation;
652 void rollback()
override {
670 class ModifyOperationRewrite :
public OperationRewrite {
674 : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
682 name.initOpProperties(propCopy, prop);
686 static bool classof(
const IRRewrite *
rewrite) {
687 return rewrite->getKind() == Kind::ModifyOperation;
690 ~ModifyOperationRewrite()
override {
691 assert(!propertiesStorage &&
692 "rewrite was neither committed nor rolled back");
698 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
701 if (propertiesStorage) {
705 name.destroyOpProperties(propCopy);
706 operator delete(propertiesStorage);
707 propertiesStorage =
nullptr;
711 void rollback()
override {
717 if (propertiesStorage) {
720 name.destroyOpProperties(propCopy);
721 operator delete(propertiesStorage);
722 propertiesStorage =
nullptr;
729 DictionaryAttr attrs;
732 void *propertiesStorage =
nullptr;
739 class ReplaceOperationRewrite :
public OperationRewrite {
743 : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
744 converter(converter) {}
746 static bool classof(
const IRRewrite *
rewrite) {
747 return rewrite->getKind() == Kind::ReplaceOperation;
752 void rollback()
override;
762 class CreateOperationRewrite :
public OperationRewrite {
766 : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
768 static bool classof(
const IRRewrite *
rewrite) {
769 return rewrite->getKind() == Kind::CreateOperation;
778 void rollback()
override;
782 enum MaterializationKind {
795 class UnresolvedMaterializationRewrite :
public OperationRewrite {
798 UnrealizedConversionCastOp op,
800 MaterializationKind
kind,
Type originalType,
803 static bool classof(
const IRRewrite *
rewrite) {
804 return rewrite->getKind() == Kind::UnresolvedMaterialization;
807 void rollback()
override;
809 UnrealizedConversionCastOp getOperation()
const {
810 return cast<UnrealizedConversionCastOp>(op);
815 return converterAndKind.getPointer();
819 MaterializationKind getMaterializationKind()
const {
820 return converterAndKind.getInt();
/// Returns the stored `originalType`. The constructor asserts that this type
/// is set only when the materialization kind is
/// `MaterializationKind::Target`.
824 Type getOriginalType()
const {
return originalType; }
829 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
842 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
845 template <
typename RewriteTy,
typename R>
846 static bool hasRewrite(R &&rewrites,
Operation *op) {
847 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
848 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
849 return rewriteTy && rewriteTy->getOperation() == op;
855 template <
typename RewriteTy,
typename R>
856 static bool hasRewrite(R &&rewrites,
Block *block) {
857 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
858 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
859 return rewriteTy && rewriteTy->getBlock() == block;
879 RewriterState getCurrentState();
883 void applyRewrites();
888 void resetState(RewriterState state, StringRef patternName =
"");
892 template <
typename RewriteTy,
typename... Args>
895 std::make_unique<RewriteTy>(*
this, std::forward<Args>(args)...));
901 void undoRewrites(
unsigned numRewritesToKeep = 0, StringRef patternName =
"");
907 LogicalResult remapValues(StringRef valueDiagTag,
908 std::optional<Location> inputLoc,
935 Block *applySignatureConversion(
955 void eraseBlock(
Block *block);
979 UnrealizedConversionCastOp *castOp =
nullptr);
986 Value findOrBuildReplacementValue(
Value value,
994 void notifyOperationInserted(
Operation *op,
998 void notifyBlockInserted(
Block *block,
Region *previous,
1017 std::function<
void(
Operation *)> opErasedCallback =
nullptr)
1019 opErasedCallback(opErasedCallback) {}
1031 if (wasErased(block))
1033 assert(block->
empty() &&
"expected empty block");
/// Returns true if `ptr` is contained in the `erased` set.
1038 bool wasErased(
void *ptr)
const {
return erased.contains(ptr); }
1042 if (opErasedCallback)
1043 opErasedCallback(op);
1053 std::function<void(
Operation *)> opErasedCallback;
1114 llvm::ScopedPrinter logger{llvm::dbgs()};
1121 return rewriterImpl.config;
1124 void BlockTypeConversionRewrite::commit(
RewriterBase &rewriter) {
1128 if (
auto *listener =
1129 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
1130 for (
Operation *op : getNewBlock()->getUsers())
1131 listener->notifyOperationModified(op);
1134 void BlockTypeConversionRewrite::rollback() {
1135 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1138 void ReplaceBlockArgRewrite::commit(
RewriterBase &rewriter) {
1139 Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
1143 if (isa<BlockArgument>(repl)) {
1151 Operation *replOp = cast<OpResult>(repl).getOwner();
/// Undoes the block-argument replacement by erasing the single-element
/// `{arg}` key from the rewriter implementation's value mapping.
1159 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
1161 void ReplaceOperationRewrite::commit(
RewriterBase &rewriter) {
1163 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener());
1166 SmallVector<Value> replacements =
1168 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1173 listener->notifyOperationReplaced(op, replacements);
1176 for (
auto [result, newValue] :
1177 llvm::zip_equal(op->
getResults(), replacements))
1183 if (getConfig().unlegalizedOps)
1184 getConfig().unlegalizedOps->erase(op);
1188 notifyIRErased(listener, *op);
1195 void ReplaceOperationRewrite::rollback() {
1197 rewriterImpl.mapping.erase({result});
1200 void ReplaceOperationRewrite::cleanup(
RewriterBase &rewriter) {
1204 void CreateOperationRewrite::rollback() {
1206 while (!region.getBlocks().empty())
1207 region.getBlocks().remove(region.getBlocks().begin());
1213 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1217 : OperationRewrite(
Kind::UnresolvedMaterialization, rewriterImpl, op),
1218 converterAndKind(converter,
kind), originalType(originalType),
1219 mappedValues(std::move(mappedValues)) {
1220 assert((!originalType || kind == MaterializationKind::Target) &&
1221 "original type is valid only for target materializations");
1225 void UnresolvedMaterializationRewrite::rollback() {
1226 if (!mappedValues.empty())
1227 rewriterImpl.
mapping.erase(mappedValues);
1237 for (
size_t i = 0; i <
rewrites.size(); ++i)
1243 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1247 rewrite->cleanup(eraseRewriter);
1259 StringRef patternName) {
1264 while (
ignoredOps.size() != state.numIgnoredOperations)
1267 while (
replacedOps.size() != state.numReplacedOps)
1272 StringRef patternName) {
1274 llvm::reverse(llvm::drop_begin(
rewrites, numRewritesToKeep))) {
1276 !isa<UnresolvedMaterializationRewrite>(
rewrite)) {
1278 llvm::report_fatal_error(
"pattern '" + patternName +
1279 "' rollback of IR modifications requested");
1283 rewrites.resize(numRewritesToKeep);
1287 StringRef valueDiagTag, std::optional<Location> inputLoc,
1290 remapped.reserve(llvm::size(values));
1293 Value operand = it.value();
1301 remapped.push_back(
mapping.lookupOrDefault(operand));
1309 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1310 << it.index() <<
", type was " << origType;
1315 if (legalTypes.empty()) {
1316 remapped.push_back({});
1325 remapped.push_back(std::move(repl));
1330 repl =
mapping.lookupOrDefault(operand);
1333 repl, repl, legalTypes,
1335 remapped.push_back(castValues);
1359 if (region->
empty())
1364 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1366 std::optional<TypeConverter::SignatureConversion> conversion =
1376 if (entryConversion)
1379 std::optional<TypeConverter::SignatureConversion> conversion =
1391 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1393 if (hasRewrite<BlockTypeConversionRewrite>(
rewrites, block))
1394 llvm::report_fatal_error(
"block was already converted");
1408 for (
unsigned i = 0; i < origArgCount; ++i) {
1410 if (!inputMap || inputMap->replacedWithValues())
1413 for (
unsigned j = 0;
j < inputMap->size; ++
j)
1414 newLocs[inputMap->inputNo +
j] = origLoc;
1421 convertedTypes, newLocs);
1431 appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->
end());
1434 while (!block->
empty())
1441 for (
unsigned i = 0; i != origArgCount; ++i) {
1445 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1452 MaterializationKind::Source,
1456 origArgType,
Type(), converter)
1462 if (inputMap->replacedWithValues()) {
1464 assert(inputMap->size == 0 &&
1465 "invalid to provide a replacement value when the argument isn't "
1474 newBlock->
getArguments().slice(inputMap->inputNo, inputMap->size);
1478 appendRewrite<BlockTypeConversionRewrite>(block, newBlock);
1497 UnrealizedConversionCastOp *castOp) {
1498 assert((!originalType ||
kind == MaterializationKind::Target) &&
1499 "original type is valid only for target materializations");
1500 assert(
TypeRange(inputs) != outputTypes &&
1501 "materialization is not necessary");
1505 OpBuilder builder(outputTypes.front().getContext());
1508 builder.
create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
1509 if (!valuesToMap.empty())
1510 mapping.map(valuesToMap, convertOp.getResults());
1512 *castOp = convertOp;
1513 appendRewrite<UnresolvedMaterializationRewrite>(
1514 convertOp, converter,
kind, originalType, std::move(valuesToMap));
1515 return convertOp.getResults();
1525 return repl.front();
1532 [&](
Operation *op) { return replacedOps.contains(op); }) &&
1539 repl =
mapping.lookupOrNull(value);
1573 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"'(" << op
1577 "attempting to insert into a block within a replaced/erased op");
1579 if (!previous.
isSet()) {
1581 appendRewrite<CreateOperationRewrite>(op);
1588 appendRewrite<MoveOperationRewrite>(op, previous.
getBlock(), prevOp);
1594 assert(!
ignoredOps.contains(op) &&
"operation was already replaced");
1598 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1603 "attempting to replace/erase an unresolved materialization");
1607 for (
auto [repl, result] : llvm::zip_equal(newValues, op->
getResults())) {
1622 mapping.map(
static_cast<Value>(result), std::move(repl));
1632 appendRewrite<ReplaceBlockArgRewrite>(from.
getOwner(), from, converter);
1638 "attempting to erase a block within a replaced/erased op");
1639 appendRewrite<EraseBlockRewrite>(block);
1654 "attempting to insert into a region within a replaced/erased op");
1659 logger.startLine() <<
"** Insert Block into : '" << parent->
getName()
1660 <<
"'(" << parent <<
")\n";
1663 <<
"** Insert Block into detached Region (nullptr parent op)'\n";
1671 appendRewrite<CreateBlockRewrite>(block);
1674 Block *prevBlock = previousIt == previous->
end() ? nullptr : &*previousIt;
1675 appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1681 appendRewrite<InlineBlockRewrite>(dest, source, before);
1688 reasonCallback(
diag);
1689 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
1699 ConversionPatternRewriter::ConversionPatternRewriter(
1703 setListener(
impl.get());
1709 assert(op && newOp &&
"expected non-null op");
1715 "incorrect # of replacement values");
1717 impl->logger.startLine()
1718 <<
"** Replace : '" << op->
getName() <<
"'(" << op <<
")\n";
1724 impl->replaceOp(op, std::move(newVals));
1730 "incorrect # of replacement values");
1732 impl->logger.startLine()
1733 <<
"** Replace : '" << op->
getName() <<
"'(" << op <<
")\n";
1735 impl->replaceOp(op, std::move(newValues));
1740 impl->logger.startLine()
1741 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
1744 impl->replaceOp(op, std::move(nullRepls));
1748 impl->eraseBlock(block);
1755 "attempting to apply a signature conversion to a block within a "
1756 "replaced/erased op");
1757 return impl->applySignatureConversion(*
this, block, converter, conversion);
1764 "attempting to apply a signature conversion to a block within a "
1765 "replaced/erased op");
1766 return impl->convertRegionTypes(*
this, region, converter, entryConversion);
1772 impl->logger.startLine() <<
"** Replace Argument : '" << from <<
"'";
1774 impl->logger.getOStream() <<
" (in region of '" << parentOp->getName()
1775 <<
"' (" << parentOp <<
")\n";
1777 impl->logger.getOStream() <<
" (unlinked block)\n";
1780 impl->replaceUsesOfBlockArgument(from, to,
impl->currentTypeConverter);
1785 if (failed(
impl->remapValues(
"value", std::nullopt, *
this, key,
1788 assert(remappedValues.front().size() == 1 &&
"1:N conversion not supported");
1789 return remappedValues.front().front();
1798 if (failed(
impl->remapValues(
"value", std::nullopt, *
this, keys,
1801 for (
const auto &values : remapped) {
1802 assert(values.size() == 1 &&
"1:N conversion not supported");
1803 results.push_back(values.front());
1813 "incorrect # of argument replacement values");
1815 "attempting to inline a block from a replaced/erased op");
1817 "attempting to inline a block into a replaced/erased op");
1818 auto opIgnored = [&](
Operation *op) {
return impl->isOpIgnored(op); };
1821 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
1822 "expected 'source' to have no predecessors");
1831 bool fastPath = !
impl->config.listener;
1834 impl->inlineBlockBefore(source, dest, before);
1837 for (
auto it : llvm::zip(source->
getArguments(), argValues))
1838 replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1845 while (!source->
empty())
1846 moveOpBefore(&source->
front(), dest, before);
1854 assert(!
impl->wasOpReplaced(op) &&
1855 "attempting to modify a replaced/erased op");
1857 impl->pendingRootUpdates.insert(op);
1859 impl->appendRewrite<ModifyOperationRewrite>(op);
1863 assert(!
impl->wasOpReplaced(op) &&
1864 "attempting to modify a replaced/erased op");
1866 impl->patternModifiedOps.insert(op);
1871 assert(
impl->pendingRootUpdates.erase(op) &&
1872 "operation did not have a pending in-place update");
1878 assert(
impl->pendingRootUpdates.erase(op) &&
1879 "operation did not have a pending in-place update");
1882 auto it = llvm::find_if(
1883 llvm::reverse(
impl->rewrites), [&](std::unique_ptr<IRRewrite> &
rewrite) {
1884 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1885 return modifyRewrite && modifyRewrite->getOperation() == op;
1887 assert(it !=
impl->rewrites.rend() &&
"no root update started on op");
1889 int updateIdx = std::prev(
impl->rewrites.rend()) - it;
1890 impl->rewrites.erase(
impl->rewrites.begin() + updateIdx);
1904 oneToOneOperands.reserve(operands.size());
1906 if (operand.size() != 1)
1907 llvm::report_fatal_error(
"pattern '" + getDebugName() +
1908 "' does not support 1:N conversion");
1909 oneToOneOperands.push_back(operand.front());
1911 return oneToOneOperands;
1918 auto &rewriterImpl = dialectRewriter.getImpl();
1922 getTypeConverter());
1931 llvm::to_vector_of<ValueRange>(remapped);
1932 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
1944 class OperationLegalizer {
1964 LogicalResult legalizeWithFold(
Operation *op,
1969 LogicalResult legalizeWithPattern(
Operation *op,
1986 legalizePatternBlockRewrites(
Operation *op,
2008 void buildLegalizationGraph(
2009 LegalizationPatterns &anyOpLegalizerPatterns,
2020 void computeLegalizationGraphBenefit(
2021 LegalizationPatterns &anyOpLegalizerPatterns,
2026 unsigned computeOpLegalizationDepth(
2033 unsigned applyCostModelToPatterns(
2059 LegalizationPatterns anyOpLegalizerPatterns;
2061 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2062 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2065 bool OperationLegalizer::isIllegal(
Operation *op)
const {
2066 return target.isIllegal(op);
2070 OperationLegalizer::legalize(
Operation *op,
2073 const char *logLineComment =
2074 "//===-------------------------------------------===//\n";
2079 logger.getOStream() <<
"\n";
2080 logger.startLine() << logLineComment;
2081 logger.startLine() <<
"Legalizing operation : '" << op->
getName() <<
"'("
2087 op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
2088 logger.getOStream() <<
"\n\n";
2093 if (
auto legalityInfo = target.isLegal(op)) {
2096 logger,
"operation marked legal by the target{0}",
2097 legalityInfo->isRecursivelyLegal
2098 ?
"; NOTE: operation is recursively legal; skipping internals"
2100 logger.startLine() << logLineComment;
2105 if (legalityInfo->isRecursivelyLegal) {
2118 logSuccess(logger,
"operation marked 'ignored' during conversion");
2119 logger.startLine() << logLineComment;
2127 if (succeeded(legalizeWithFold(op, rewriter))) {
2130 logger.startLine() << logLineComment;
2136 if (succeeded(legalizeWithPattern(op, rewriter))) {
2139 logger.startLine() << logLineComment;
2145 logFailure(logger,
"no matched legalization pattern");
2146 logger.startLine() << logLineComment;
2153 template <
typename T>
2155 T result = std::move(obj);
2161 OperationLegalizer::legalizeWithFold(
Operation *op,
2163 auto &rewriterImpl = rewriter.
getImpl();
2165 rewriterImpl.
logger.startLine() <<
"* Fold {\n";
2166 rewriterImpl.
logger.indent();
2171 SmallVector<Value, 2> replacementValues;
2172 SmallVector<Operation *, 2> newOps;
2174 if (failed(rewriter.
tryFold(op, replacementValues, &newOps))) {
2181 if (replacementValues.empty())
2182 return legalize(op, rewriter);
2186 if (failed(legalize(newOp, rewriter))) {
2188 "failed to legalize generated constant '{0}'",
2198 rewriter.
replaceOp(op, replacementValues);
2205 OperationLegalizer::legalizeWithPattern(
Operation *op,
2207 auto &rewriterImpl = rewriter.
getImpl();
2210 auto canApply = [&](
const Pattern &pattern) {
2211 bool canApply = canApplyPattern(op, pattern, rewriter);
2212 if (canApply &&
config.listener)
2213 config.listener->notifyPatternBegin(pattern, op);
2219 auto onFailure = [&](
const Pattern &pattern) {
2228 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
2235 config.listener->notifyPatternEnd(pattern, failure());
2236 rewriterImpl.
resetState(curState, pattern.getDebugName());
2237 appliedPatterns.erase(&pattern);
2242 auto onSuccess = [&](
const Pattern &pattern) {
2249 auto result = legalizePatternResult(op, pattern, rewriter, newOps,
2250 modifiedOps, insertedBlocks);
2251 appliedPatterns.erase(&pattern);
2252 if (failed(result)) {
2255 << pattern.getDebugName()
2256 <<
"' produced IR that could not be legalized";
2257 rewriterImpl.
resetState(curState, pattern.getDebugName());
2260 config.listener->notifyPatternEnd(pattern, result);
2265 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2269 bool OperationLegalizer::canApplyPattern(
Operation *op,
const Pattern &pattern,
2272 auto &os = rewriter.
getImpl().logger;
2273 os.getOStream() <<
"\n";
2274 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
2276 os.getOStream() <<
")' {\n";
2283 !appliedPatterns.insert(&pattern).second) {
2291 LogicalResult OperationLegalizer::legalizePatternResult(
2297 assert(
impl.pendingRootUpdates.empty() &&
"dangling root updates");
2299 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2301 auto newRewrites = llvm::drop_begin(
impl.rewrites, curState.numRewrites);
2302 auto replacedRoot = [&] {
2303 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2305 auto updatedRootInPlace = [&] {
2306 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2308 if (!replacedRoot() && !updatedRootInPlace())
2309 llvm::report_fatal_error(
"expected pattern to replace the root operation");
2313 if (failed(legalizePatternBlockRewrites(op, rewriter,
impl, insertedBlocks,
2315 failed(legalizePatternRootUpdates(rewriter,
impl, modifiedOps)) ||
2316 failed(legalizePatternCreatedOperations(rewriter,
impl, newOps))) {
2320 LLVM_DEBUG(
logSuccess(
impl.logger,
"pattern applied successfully"));
2324 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2333 for (
Block *block : insertedBlocks) {
2341 if (
auto *converter =
impl.regionToConverter.lookup(block->
getParent())) {
2342 std::optional<TypeConverter::SignatureConversion> conversion =
2345 LLVM_DEBUG(
logFailure(
impl.logger,
"failed to convert types of moved "
2349 impl.applySignatureConversion(rewriter, block, converter, *conversion);
2357 if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
2358 if (failed(legalize(parentOp, rewriter))) {
2360 impl.logger,
"operation '{0}'({1}) became illegal after rewrite",
2361 parentOp->
getName(), parentOp));
2369 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2373 if (failed(legalize(op, rewriter))) {
2375 "failed to legalize generated operation '{0}'({1})",
2383 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2387 if (failed(legalize(op, rewriter))) {
2389 impl.logger,
"failed to legalize operation updated in-place '{0}'",
2401 void OperationLegalizer::buildLegalizationGraph(
2402 LegalizationPatterns &anyOpLegalizerPatterns,
2413 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2414 std::optional<OperationName> root = pattern.
getRootKind();
2420 anyOpLegalizerPatterns.push_back(&pattern);
2425 if (target.getOpAction(*root) == LegalizationAction::Legal)
2430 invalidPatterns[*root].insert(&pattern);
2432 parentOps[op].insert(*root);
2435 patternWorklist.insert(&pattern);
2443 if (!anyOpLegalizerPatterns.empty()) {
2444 for (
const Pattern *pattern : patternWorklist)
2445 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2449 while (!patternWorklist.empty()) {
2450 auto *pattern = patternWorklist.pop_back_val();
2454 std::optional<LegalizationAction> action = target.getOpAction(op);
2455 return !legalizerPatterns.count(op) &&
2456 (!action || action == LegalizationAction::Illegal);
2462 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2463 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
2467 for (
auto op : parentOps[*pattern->
getRootKind()])
2468 patternWorklist.set_union(invalidPatterns[op]);
2472 void OperationLegalizer::computeLegalizationGraphBenefit(
2473 LegalizationPatterns &anyOpLegalizerPatterns,
2479 for (
auto &opIt : legalizerPatterns)
2480 if (!minOpPatternDepth.count(opIt.first))
2481 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2487 if (!anyOpLegalizerPatterns.empty())
2488 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2494 applicator.applyCostModel([&](
const Pattern &pattern) {
2496 if (std::optional<OperationName> rootName = pattern.
getRootKind())
2497 orderedPatternList = legalizerPatterns[*rootName];
2499 orderedPatternList = anyOpLegalizerPatterns;
2502 auto *it = llvm::find(orderedPatternList, &pattern);
2503 if (it == orderedPatternList.end())
2507 return PatternBenefit(std::distance(it, orderedPatternList.end()));
2511 unsigned OperationLegalizer::computeOpLegalizationDepth(
2515 auto depthIt = minOpPatternDepth.find(op);
2516 if (depthIt != minOpPatternDepth.end())
2517 return depthIt->second;
2521 auto opPatternsIt = legalizerPatterns.find(op);
2522 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2531 unsigned minDepth = applyCostModelToPatterns(
2532 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2533 minOpPatternDepth[op] = minDepth;
2537 unsigned OperationLegalizer::applyCostModelToPatterns(
2544 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2545 patternsByDepth.reserve(
patterns.size());
2549 unsigned generatedOpDepth = computeOpLegalizationDepth(
2550 generatedOp, minOpPatternDepth, legalizerPatterns);
2551 depth =
std::max(depth, generatedOpDepth + 1);
2553 patternsByDepth.emplace_back(pattern, depth);
2556 minDepth =
std::min(minDepth, depth);
2561 if (patternsByDepth.size() == 1)
2565 llvm::stable_sort(patternsByDepth,
2566 [](
const std::pair<const Pattern *, unsigned> &lhs,
2567 const std::pair<const Pattern *, unsigned> &rhs) {
2570 if (lhs.second != rhs.second)
2571 return lhs.second < rhs.second;
2574 auto lhsBenefit = lhs.first->getBenefit();
2575 auto rhsBenefit = rhs.first->getBenefit();
2576 return lhsBenefit > rhsBenefit;
2581 for (
auto &patternIt : patternsByDepth)
2582 patterns.push_back(patternIt.first);
2590 enum OpConversionMode {
2613 OpConversionMode mode)
2628 OperationLegalizer opLegalizer;
2631 OpConversionMode mode;
2638 if (failed(opLegalizer.legalize(op, rewriter))) {
2641 if (mode == OpConversionMode::Full)
2643 <<
"failed to legalize operation '" << op->
getName() <<
"'";
2647 if (mode == OpConversionMode::Partial) {
2648 if (opLegalizer.isIllegal(op))
2650 <<
"failed to legalize operation '" << op->
getName()
2651 <<
"' that was explicitly marked illegal";
2655 }
else if (mode == OpConversionMode::Analysis) {
2665 static LogicalResult
2667 UnresolvedMaterializationRewrite *
rewrite) {
2668 UnrealizedConversionCastOp op =
rewrite->getOperation();
2669 assert(!op.use_empty() &&
2670 "expected that dead materializations have already been DCE'd");
2676 SmallVector<Value> newMaterialization;
2677 switch (
rewrite->getMaterializationKind()) {
2678 case MaterializationKind::Target:
2680 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
2683 case MaterializationKind::Source:
2684 assert(op->getNumResults() == 1 &&
"expected single result");
2686 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
2688 newMaterialization.push_back(sourceMat);
2691 if (!newMaterialization.empty()) {
2693 ValueRange newMaterializationRange(newMaterialization);
2694 assert(
TypeRange(newMaterializationRange) == op.getResultTypes() &&
2695 "materialization callback produced value of incorrect type");
2697 rewriter.
replaceOp(op, newMaterialization);
2703 <<
"failed to legalize unresolved materialization "
2705 << inputOperands.
getTypes() <<
") to ("
2706 << op.getResultTypes()
2707 <<
") that remained live after conversion";
2708 diag.attachNote(op->getUsers().begin()->getLoc())
2709 <<
"see existing live user here: " << *op->getUsers().begin();
2720 for (
auto *op : ops) {
2723 toConvert.push_back(op);
2726 auto legalityInfo = target.
isLegal(op);
2727 if (legalityInfo && legalityInfo->isRecursivelyLegal)
2737 for (
auto *op : toConvert) {
2738 if (failed(convert(rewriter, op))) {
2759 for (
auto it : materializations)
2760 allCastOps.push_back(it.first);
2771 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2772 auto it = materializations.find(castOp);
2773 assert(it != materializations.end() &&
"inconsistent state");
2795 auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2796 for (
Value v : castOp.getInputs())
2797 if (
auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2798 worklist.insert(inputCastOp);
2805 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2806 if (castOp.getInputs().empty())
2809 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2812 if (inputCastOp.getOutputs() != castOp.getInputs())
2818 while (!worklist.empty()) {
2819 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2820 if (castOp->use_empty()) {
2823 enqueueOperands(castOp);
2824 if (remainingCastOps)
2825 erasedOps.insert(castOp.getOperation());
2832 UnrealizedConversionCastOp nextCast = castOp;
2834 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2838 enqueueOperands(castOp);
2839 castOp.replaceAllUsesWith(nextCast.getInputs());
2840 if (remainingCastOps)
2841 erasedOps.insert(castOp.getOperation());
2845 nextCast = getInputCast(nextCast);
2849 if (remainingCastOps)
2850 for (UnrealizedConversionCastOp op : castOps)
2851 if (!erasedOps.contains(op.getOperation()))
2852 remainingCastOps->push_back(op);
2861 assert(!types.empty() &&
"expected valid types");
2862 remapInput(origInputNo, argTypes.size(), types.size());
2867 assert(!types.empty() &&
2868 "1->0 type remappings don't need to be added explicitly");
2869 argTypes.append(types.begin(), types.end());
2873 unsigned newInputNo,
2874 unsigned newInputCount) {
2875 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
2876 assert(newInputCount != 0 &&
"expected valid input count");
2877 remappedInputs[origInputNo] =
2878 InputMapping{newInputNo, newInputCount, {}};
2883 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
2891 assert(t &&
"expected non-null type");
2894 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2897 cacheReadLock.lock();
2898 auto existingIt = cachedDirectConversions.find(t);
2899 if (existingIt != cachedDirectConversions.end()) {
2900 if (existingIt->second)
2901 results.push_back(existingIt->second);
2902 return success(existingIt->second !=
nullptr);
2904 auto multiIt = cachedMultiConversions.find(t);
2905 if (multiIt != cachedMultiConversions.end()) {
2906 results.append(multiIt->second.begin(), multiIt->second.end());
2912 size_t currentCount = results.size();
2914 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2917 for (
const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2918 if (std::optional<LogicalResult> result = converter(t, results)) {
2920 cacheWriteLock.lock();
2921 if (!succeeded(*result)) {
2922 assert(results.size() == currentCount &&
2923 "failed type conversion should not change results");
2924 cachedDirectConversions.try_emplace(t,
nullptr);
2927 auto newTypes =
ArrayRef<Type>(results).drop_front(currentCount);
2928 if (newTypes.size() == 1)
2929 cachedDirectConversions.try_emplace(t, newTypes.front());
2931 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2934 assert(results.size() == currentCount &&
2935 "failed type conversion should not change results");
2948 return results.size() == 1 ? results.front() :
nullptr;
2954 for (
Type type : types)
2968 return llvm::all_of(*region, [
this](
Block &block) {
2974 return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2986 if (convertedTypes.empty())
2990 result.
addInputs(inputNo, convertedTypes);
2996 unsigned origInputOffset)
const {
2997 for (
unsigned i = 0, e = types.size(); i != e; ++i)
3006 for (
const SourceMaterializationCallbackFn &fn :
3007 llvm::reverse(sourceMaterializations))
3008 if (
Value result = fn(builder, resultType, inputs, loc))
3016 Type originalType)
const {
3018 builder, loc,
TypeRange(resultType), inputs, originalType);
3021 assert(result.size() == 1 &&
"expected single result");
3022 return result.front();
3027 Type originalType)
const {
3028 for (
const TargetMaterializationCallbackFn &fn :
3029 llvm::reverse(targetMaterializations)) {
3031 fn(builder, resultTypes, inputs, loc, originalType);
3035 "callback produced incorrect number of values or values with "
3042 std::optional<TypeConverter::SignatureConversion>
3046 return std::nullopt;
3069 return impl.getInt() == resultTag;
3073 return impl.getInt() == naTag;
3077 return impl.getInt() == abortTag;
3081 assert(hasResult() &&
"Cannot get result from N/A or abort");
3082 return impl.getPointer();
3085 std::optional<Attribute>
3087 for (
const TypeAttributeConversionCallbackFn &fn :
3088 llvm::reverse(typeAttributeConversions)) {
3093 return std::nullopt;
3095 return std::nullopt;
3105 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3111 SmallVector<Type, 1> newResults;
3113 failed(typeConverter.
convertTypes(type.getResults(), newResults)) ||
3115 typeConverter, &result)))
3132 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3140 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3145 struct AnyFunctionOpInterfaceSignatureConversion
3157 FailureOr<Operation *>
3161 assert(op &&
"Invalid op");
3175 return rewriter.
create(newOp);
3181 patterns.add<FunctionOpInterfaceSignatureConversion>(
3182 functionLikeOpName,
patterns.getContext(), converter);
3187 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3197 legalOperations[op].action = action;
3202 for (StringRef dialect : dialectNames)
3203 legalDialects[dialect] = action;
3207 -> std::optional<LegalizationAction> {
3208 std::optional<LegalizationInfo> info = getOpInfo(op);
3209 return info ? info->action : std::optional<LegalizationAction>();
3213 -> std::optional<LegalOpDetails> {
3214 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3216 return std::nullopt;
3219 auto isOpLegal = [&] {
3221 if (info->action == LegalizationAction::Dynamic) {
3222 std::optional<bool> result = info->legalityFn(op);
3228 return info->action == LegalizationAction::Legal;
3231 return std::nullopt;
3235 if (info->isRecursivelyLegal) {
3236 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
3237 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3239 legalityFnIt->second(op).value_or(
true);
3244 return legalityDetails;
3248 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3252 if (info->action == LegalizationAction::Dynamic) {
3253 std::optional<bool> result = info->legalityFn(op);
3260 return info->action == LegalizationAction::Illegal;
3269 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3271 if (std::optional<bool> result = newCl(op))
3279 void ConversionTarget::setLegalityCallback(
3280 OperationName name,
const DynamicLegalityCallbackFn &callback) {
3281 assert(callback &&
"expected valid legality callback");
3282 auto *infoIt = legalOperations.find(name);
3283 assert(infoIt != legalOperations.end() &&
3284 infoIt->second.action == LegalizationAction::Dynamic &&
3285 "expected operation to already be marked as dynamically legal");
3286 infoIt->second.legalityFn =
3292 auto *infoIt = legalOperations.find(name);
3293 assert(infoIt != legalOperations.end() &&
3294 infoIt->second.action != LegalizationAction::Illegal &&
3295 "expected operation to already be marked as legal");
3296 infoIt->second.isRecursivelyLegal =
true;
3299 std::move(opRecursiveLegalityFns[name]), callback);
3301 opRecursiveLegalityFns.erase(name);
3304 void ConversionTarget::setLegalityCallback(
3306 assert(callback &&
"expected valid legality callback");
3307 for (StringRef dialect : dialects)
3309 std::move(dialectLegalityFns[dialect]), callback);
3312 void ConversionTarget::setLegalityCallback(
3313 const DynamicLegalityCallbackFn &callback) {
3314 assert(callback &&
"expected valid legality callback");
3319 -> std::optional<LegalizationInfo> {
3321 const auto *it = legalOperations.find(op);
3322 if (it != legalOperations.end())
3325 auto dialectIt = legalDialects.find(op.getDialectNamespace());
3326 if (dialectIt != legalDialects.end()) {
3327 DynamicLegalityCallbackFn callback;
3328 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3329 if (dialectFn != dialectLegalityFns.end())
3330 callback = dialectFn->second;
3331 return LegalizationInfo{dialectIt->second,
false,
3335 if (unknownLegalityFn)
3336 return LegalizationInfo{LegalizationAction::Dynamic,
3337 false, unknownLegalityFn};
3338 return std::nullopt;
3341 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3347 auto &rewriterImpl =
3353 auto &rewriterImpl =
3360 static FailureOr<SmallVector<Value>>
3362 SmallVector<Value> mappedValues;
3365 return std::move(mappedValues);
3369 patterns.getPDLPatterns().registerRewriteFunction(
3374 if (failed(results))
3376 return results->front();
3378 patterns.getPDLPatterns().registerRewriteFunction(
3383 patterns.getPDLPatterns().registerRewriteFunction(
3386 auto &rewriterImpl =
3396 patterns.getPDLPatterns().registerRewriteFunction(
3399 TypeRange types) -> FailureOr<SmallVector<Type>> {
3400 auto &rewriterImpl =
3407 if (failed(converter->
convertTypes(types, remappedTypes)))
3409 return std::move(remappedTypes);
3426 OpConversionMode::Partial);
3445 OpConversionMode::Full);
3469 "expected top-level op to be isolated from above");
3472 "expected ops to have a common ancestor");
3481 for (
Operation *op : ops.drop_front()) {
3485 assert(commonAncestor &&
3486 "expected to find a common isolated from above ancestor");
3490 return commonAncestor;
3497 if (
config.legalizableOps)
3498 assert(
config.legalizableOps->empty() &&
"expected empty set");
3508 inverseOperationMap[it.second] = it.first;
3514 OpConversionMode::Analysis);
3515 LogicalResult status = opConverter.convertOperations(opsToConvert);
3519 if (
config.legalizableOps) {
3522 originalLegalizableOps.insert(inverseOperationMap[op]);
3523 *
config.legalizableOps = std::move(originalLegalizableOps);
3527 clonedAncestor->
erase();
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnresolvedMaterializationRewrite *rewrite)
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used without a dominance violation.
union mlir::linalg::@1215::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
void erase()
Unlink this Block from its parent region and delete it.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the uses are located.
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callback.
OpListType & getOperations()
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the currently executing pattern.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={}) override
PatternRewriter hook for inlining the ops of a block into another block.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for cancelling a pending in-place modification of the given operation.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently executing pattern.
void finalizeOpModification(Operation *op) override
PatternRewriter hook for signalling the end of an in-place modification of the given operation.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to)
Replace all the uses of the block argument from with to.
~ConversionPatternRewriter() override
Base class for the conversion patterns.
SmallVector< Value > getOneToOneAdaptorOperands(ArrayRef< ValueRange > operands) const
Given an array of value ranges, which are the inputs to a 1:N adaptor, try to extract the single valu...
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information is returned.
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively legal.
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true if the operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
auto lookup(T from) const
Lookup a mapped value within the map.
user_range getUsers() const
Returns a range of all users.
void dropAllUses()
Drop all uses of this object from their respective owners.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Block::iterator getPoint() const
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
type_range getTypes() const
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
operand_iterator operand_begin()
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Block * getBlock()
Returns the operation block that contains this operation.
operand_iterator operand_end()
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
result_range getResults()
int getPropertiesStorageSize() const
Returns the properties storage size.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
succ_iterator successor_end()
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
succ_iterator successor_begin()
unsigned getNumResults()
Return the number of results held by this operation.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic for the actual matching.
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
The general result of a type attribute conversion callback, allowing for early termination.
Attribute getResult() const
static AttributeConversionResult abort()
static AttributeConversionResult na()
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, ArrayRef< Value > replacements)
Remap an input of the original signature to replacements values.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute 'attr' present within the type 'type' using the registered conversion functions.
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of some kind.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for each argument.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the given type converter.
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOps (if provided).
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction code.
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
bool allowPatternRollback
If set to "true", pattern rollback is allowed.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target materializations through the...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
bool wasErased(void *ptr) const
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to, const TypeConverter *converter)
Replace the given block argument with the given values.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks within that region.
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites remain.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationRewrite * > unresolvedMaterializations
A mapping of all unresolved materializations (UnrealizedConversionCastOp) to the corresponding rewrite objects.
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
ConversionValueMapping mapping
void applyRewrites()
Apply all requested operation rewrites.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
MLIRContext * context
MLIR context.
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Block * > patternInsertedBlocks
A set of blocks that were inserted (newly-created blocks or moved blocks) by the current pattern.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, UnrealizedConversionCastOp *castOp=nullptr)
Build an unresolved materialization operation given a range of output types and a list of input operands.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
void replaceOp(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the results of the given operation with the given values and erase the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.