10 #include "mlir/Config/mlir-config.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/SaveAndRestore.h"
24 #include "llvm/Support/ScopedPrinter.h"
30 #define DEBUG_TYPE "dialect-conversion"
33 template <
typename... Args>
34 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
37 os.startLine() <<
"} -> SUCCESS";
39 os.getOStream() <<
" : "
40 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
41 os.getOStream() <<
"\n";
46 template <
typename... Args>
47 static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
50 os.startLine() <<
"} -> FAILURE : "
51 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
61 if (
OpResult inputRes = dyn_cast<OpResult>(value))
62 insertPt = ++inputRes.getOwner()->getIterator();
77 struct ConversionValueMapping {
/// Returns true if `value` is contained in `mappedTo`, the set of values that
/// have been recorded as the new (target) value of some mapping.
80 bool isMappedTo(
Value value)
const {
return mappedTo.contains(value); }
/// Looks up the value that `from` is mapped to. When `desiredType` is
/// non-null, the lookup searches for a mapped value of that type. Returns
/// `from` itself when no suitable mapped value is found.
92 Value lookupOrDefault(
Value from,
Type desiredType =
nullptr)
const;
102 for (
Value it = newVal; it; it = mapping.lookupOrNull(it))
103 assert(it != oldVal &&
"inserting cyclic mapping");
105 mapping.map(oldVal, newVal);
106 mappedTo.insert(newVal);
/// Removes the mapping entry keyed by `value`, if any.
/// NOTE: only `mapping` is updated; `mappedTo` is left unchanged.
110 void erase(
Value value) { mapping.erase(value); }
121 Value ConversionValueMapping::lookupOrDefault(
Value from,
122 Type desiredType)
const {
127 if (!desiredType || from.
getType() == desiredType)
130 Value mappedValue = mapping.lookupOrNull(from);
137 return desiredValue ? desiredValue : from;
140 Value ConversionValueMapping::lookupOrNull(
Value from,
Type desiredType)
const {
141 Value result = lookupOrDefault(from, desiredType);
142 if (result == from || (desiredType && result.
getType() != desiredType))
153 struct RewriterState {
154 RewriterState(
unsigned numRewrites,
unsigned numIgnoredOperations,
155 unsigned numReplacedOps)
156 : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
157 numReplacedOps(numReplacedOps) {}
160 unsigned numRewrites;
163 unsigned numIgnoredOperations;
166 unsigned numReplacedOps;
198 UnresolvedMaterialization
/// Virtual because rewrites are created as concrete subclasses and owned
/// through `std::unique_ptr<IRRewrite>`, so they are destroyed polymorphically.
201 virtual ~IRRewrite() =
default;
/// Undoes this rewrite. Every concrete rewrite kind must implement this.
204 virtual void rollback() = 0;
/// Returns the kind tag of this rewrite; the `classof` overloads compare it
/// to implement LLVM-style RTTI (isa/cast/dyn_cast).
223 Kind getKind()
const {
return kind; }
/// LLVM-style RTTI hook for the root of the hierarchy: every rewrite is an
/// IRRewrite, so this unconditionally returns true.
225 static bool classof(
const IRRewrite *
rewrite) {
return true; }
229 : kind(kind), rewriterImpl(rewriterImpl) {}
238 class BlockRewrite :
public IRRewrite {
/// Returns the block that this rewrite operates on.
241 Block *getBlock()
const {
return block; }
243 static bool classof(
const IRRewrite *
rewrite) {
244 return rewrite->getKind() >= Kind::CreateBlock &&
245 rewrite->getKind() <= Kind::ReplaceBlockArg;
251 : IRRewrite(kind, rewriterImpl), block(block) {}
260 class CreateBlockRewrite :
public BlockRewrite {
263 : BlockRewrite(
Kind::CreateBlock, rewriterImpl, block) {}
265 static bool classof(
const IRRewrite *
rewrite) {
266 return rewrite->getKind() == Kind::CreateBlock;
272 listener->notifyBlockInserted(block, {}, {});
275 void rollback()
override {
278 auto &blockOps = block->getOperations();
279 while (!blockOps.empty())
280 blockOps.remove(blockOps.begin());
281 block->dropAllUses();
282 if (block->getParent())
293 class EraseBlockRewrite :
public BlockRewrite {
296 : BlockRewrite(
Kind::EraseBlock, rewriterImpl, block),
297 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
299 static bool classof(
const IRRewrite *
rewrite) {
300 return rewrite->getKind() == Kind::EraseBlock;
303 ~EraseBlockRewrite()
override {
305 "rewrite was neither rolled back nor committed/cleaned up");
308 void rollback()
override {
311 assert(block &&
"expected block");
312 auto &blockList = region->getBlocks();
316 blockList.insert(before, block);
322 assert(block &&
"expected block");
323 assert(block->empty() &&
"expected empty block");
327 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
328 listener->notifyBlockErased(block);
333 block->dropAllDefinedValueUses();
344 Block *insertBeforeBlock;
350 class InlineBlockRewrite :
public BlockRewrite {
354 : BlockRewrite(
Kind::InlineBlock, rewriterImpl, block),
355 sourceBlock(sourceBlock),
356 firstInlinedInst(sourceBlock->empty() ? nullptr
357 : &sourceBlock->front()),
358 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
364 assert(!getConfig().listener &&
365 "InlineBlockRewrite not supported if listener is attached");
368 static bool classof(
const IRRewrite *
rewrite) {
369 return rewrite->getKind() == Kind::InlineBlock;
372 void rollback()
override {
375 if (firstInlinedInst) {
376 assert(lastInlinedInst &&
"expected operation");
396 class MoveBlockRewrite :
public BlockRewrite {
400 : BlockRewrite(
Kind::MoveBlock, rewriterImpl, block), region(region),
401 insertBeforeBlock(insertBeforeBlock) {}
403 static bool classof(
const IRRewrite *
rewrite) {
404 return rewrite->getKind() == Kind::MoveBlock;
412 listener->notifyBlockInserted(block, region,
417 void rollback()
override {
430 Block *insertBeforeBlock;
434 class BlockTypeConversionRewrite :
public BlockRewrite {
438 : BlockRewrite(
Kind::BlockTypeConversion, rewriterImpl, origBlock),
439 newBlock(newBlock) {}
441 static bool classof(
const IRRewrite *
rewrite) {
442 return rewrite->getKind() == Kind::BlockTypeConversion;
/// Returns the original block, i.e. the block passed as `origBlock` to the
/// constructor and stored as the BlockRewrite's `block`.
445 Block *getOrigBlock()
const {
return block; }
/// Returns the replacement block recorded for this block type conversion.
447 Block *getNewBlock()
const {
return newBlock; }
451 void rollback()
override;
461 class ReplaceBlockArgRewrite :
public BlockRewrite {
466 : BlockRewrite(
Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
467 converter(converter) {}
469 static bool classof(
const IRRewrite *
rewrite) {
470 return rewrite->getKind() == Kind::ReplaceBlockArg;
475 void rollback()
override;
485 class OperationRewrite :
public IRRewrite {
/// Returns the operation that this rewrite operates on.
488 Operation *getOperation()
const {
return op; }
490 static bool classof(
const IRRewrite *
rewrite) {
491 return rewrite->getKind() >= Kind::MoveOperation &&
492 rewrite->getKind() <= Kind::UnresolvedMaterialization;
498 : IRRewrite(kind, rewriterImpl), op(op) {}
505 class MoveOperationRewrite :
public OperationRewrite {
509 : OperationRewrite(
Kind::MoveOperation, rewriterImpl, op), block(block),
510 insertBeforeOp(insertBeforeOp) {}
512 static bool classof(
const IRRewrite *
rewrite) {
513 return rewrite->getKind() == Kind::MoveOperation;
521 listener->notifyOperationInserted(
527 void rollback()
override {
531 block->
getOperations().splice(before, op->getBlock()->getOperations(), op);
545 class ModifyOperationRewrite :
public OperationRewrite {
549 : OperationRewrite(
Kind::ModifyOperation, rewriterImpl, op),
550 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
551 operands(op->operand_begin(), op->operand_end()),
552 successors(op->successor_begin(), op->successor_end()) {
557 name.initOpProperties(propCopy, prop);
561 static bool classof(
const IRRewrite *
rewrite) {
562 return rewrite->getKind() == Kind::ModifyOperation;
565 ~ModifyOperationRewrite()
override {
566 assert(!propertiesStorage &&
567 "rewrite was neither committed nor rolled back");
573 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
574 listener->notifyOperationModified(op);
576 if (propertiesStorage) {
580 name.destroyOpProperties(propCopy);
581 operator delete(propertiesStorage);
582 propertiesStorage =
nullptr;
586 void rollback()
override {
592 if (propertiesStorage) {
595 name.destroyOpProperties(propCopy);
596 operator delete(propertiesStorage);
597 propertiesStorage =
nullptr;
604 DictionaryAttr attrs;
605 SmallVector<Value, 8> operands;
606 SmallVector<Block *, 2> successors;
607 void *propertiesStorage =
nullptr;
614 class ReplaceOperationRewrite :
public OperationRewrite {
618 : OperationRewrite(
Kind::ReplaceOperation, rewriterImpl, op),
619 converter(converter) {}
621 static bool classof(
const IRRewrite *
rewrite) {
622 return rewrite->getKind() == Kind::ReplaceOperation;
627 void rollback()
override;
637 class CreateOperationRewrite :
public OperationRewrite {
641 : OperationRewrite(
Kind::CreateOperation, rewriterImpl, op) {}
643 static bool classof(
const IRRewrite *
rewrite) {
644 return rewrite->getKind() == Kind::CreateOperation;
650 listener->notifyOperationInserted(op, {});
653 void rollback()
override;
657 enum MaterializationKind {
674 class UnresolvedMaterializationRewrite :
public OperationRewrite {
677 UnrealizedConversionCastOp op,
679 MaterializationKind kind,
Type originalType);
681 static bool classof(
const IRRewrite *
rewrite) {
682 return rewrite->getKind() == Kind::UnresolvedMaterialization;
685 void rollback()
override;
687 UnrealizedConversionCastOp getOperation()
const {
688 return cast<UnrealizedConversionCastOp>(op);
693 return converterAndKind.getPointer();
697 MaterializationKind getMaterializationKind()
const {
698 return converterAndKind.getInt();
/// Returns the original type recorded for this materialization. Per the
/// constructor's assertion, it may only be non-null for target
/// materializations.
702 Type getOriginalType()
const {
return originalType; }
707 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
718 template <
typename RewriteTy,
typename R>
720 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
721 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
722 return rewriteTy && rewriteTy->getOperation() == op;
729 template <
typename RewriteTy,
typename R>
731 return any_of(std::forward<R>(rewrites), [&](
auto &
rewrite) {
732 auto *rewriteTy = dyn_cast<RewriteTy>(
rewrite.get());
733 return rewriteTy && rewriteTy->getBlock() == block;
746 : context(ctx), eraseRewriter(ctx), config(config) {}
753 RewriterState getCurrentState();
757 void applyRewrites();
760 void resetState(RewriterState state);
764 template <
typename RewriteTy,
typename... Args>
767 std::make_unique<RewriteTy>(*
this, std::forward<Args>(args)...));
/// Undoes every recorded rewrite after the first `numRewritesToKeep`,
/// visiting them in reverse order, and then truncates the rewrite list to
/// `numRewritesToKeep` entries. With the default of 0, all rewrites are undone.
772 void undoRewrites(
unsigned numRewritesToKeep = 0);
778 LogicalResult remapValues(StringRef valueDiagTag,
779 std::optional<Location> inputLoc,
806 Block *applySignatureConversion(
817 Value buildUnresolvedMaterialization(MaterializationKind kind,
844 Value findOrBuildReplacementValue(
Value value,
852 void notifyOperationInserted(
Operation *op,
859 void notifyBlockIsBeingErased(
Block *block);
862 void notifyBlockInserted(
Block *block,
Region *previous,
866 void notifyBlockBeingInlined(
Block *block,
Block *srcBlock,
896 if (wasErased(block))
898 assert(block->
empty() &&
"expected empty block");
/// Returns true if `ptr` has been recorded in the `erased` set. Takes a
/// `void *` so that it can be queried with any IR object pointer type.
903 bool wasErased(
void *ptr)
const {
return erased.contains(ptr); }
967 llvm::ScopedPrinter logger{llvm::dbgs()};
974 return rewriterImpl.
config;
977 void BlockTypeConversionRewrite::commit(
RewriterBase &rewriter) {
982 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener()))
983 for (
Operation *op : getNewBlock()->getUsers())
984 listener->notifyOperationModified(op);
987 void BlockTypeConversionRewrite::rollback() {
988 getNewBlock()->replaceAllUsesWith(getOrigBlock());
991 void ReplaceBlockArgRewrite::commit(
RewriterBase &rewriter) {
996 if (isa<BlockArgument>(repl)) {
1004 Operation *replOp = cast<OpResult>(repl).getOwner();
/// Undoes the block-argument replacement by dropping the mapping entry that
/// was recorded for `arg`.
1012 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.
mapping.erase(arg); }
1014 void ReplaceOperationRewrite::commit(
RewriterBase &rewriter) {
1016 dyn_cast_or_null<RewriterBase::Listener>(rewriter.
getListener());
1019 SmallVector<Value> replacements =
1021 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1026 listener->notifyOperationReplaced(op, replacements);
1029 for (
auto [result, newValue] :
1030 llvm::zip_equal(op->
getResults(), replacements))
1036 if (getConfig().unlegalizedOps)
1037 getConfig().unlegalizedOps->erase(op);
1043 [&](
Operation *op) { listener->notifyOperationErased(op); });
1051 void ReplaceOperationRewrite::rollback() {
1053 rewriterImpl.
mapping.erase(result);
1056 void ReplaceOperationRewrite::cleanup(
RewriterBase &rewriter) {
1060 void CreateOperationRewrite::rollback() {
1062 while (!region.getBlocks().empty())
1063 region.getBlocks().remove(region.getBlocks().begin());
1069 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1072 : OperationRewrite(
Kind::UnresolvedMaterialization, rewriterImpl, op),
1073 converterAndKind(converter, kind), originalType(originalType) {
1074 assert((!originalType || kind == MaterializationKind::Target) &&
1075 "original type is valid only for target materializations");
1079 void UnresolvedMaterializationRewrite::rollback() {
1080 if (getMaterializationKind() == MaterializationKind::Target) {
1081 for (
Value input : op->getOperands())
1082 rewriterImpl.
mapping.erase(input);
1093 for (
size_t i = 0; i <
rewrites.size(); ++i)
1113 while (
ignoredOps.size() != state.numIgnoredOperations)
1116 while (
replacedOps.size() != state.numReplacedOps)
1122 llvm::reverse(llvm::drop_begin(
rewrites, numRewritesToKeep)))
1124 rewrites.resize(numRewritesToKeep);
1128 StringRef valueDiagTag, std::optional<Location> inputLoc,
1131 remapped.reserve(llvm::size(values));
1134 Value operand = it.value();
1142 remapped.push_back(
mapping.lookupOrDefault(operand));
1150 diag <<
"unable to convert type for " << valueDiagTag <<
" #"
1151 << it.index() <<
", type was " << origType;
1156 if (legalTypes.size() != 1) {
1164 remapped.push_back(
mapping.lookupOrDefault(operand));
1169 Type desiredType = legalTypes.front();
1172 Value newOperand =
mapping.lookupOrDefault(operand, desiredType);
1173 if (newOperand.
getType() != desiredType) {
1180 newOperand, desiredType,
1182 mapping.map(newOperand, castValue);
1183 newOperand = castValue;
1185 remapped.push_back(newOperand);
1208 if (region->
empty())
1213 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1215 std::optional<TypeConverter::SignatureConversion> conversion =
1225 if (entryConversion)
1228 std::optional<TypeConverter::SignatureConversion> conversion =
1241 assert(!hasRewrite<BlockTypeConversionRewrite>(
rewrites, block) &&
1242 "block was already converted");
1254 for (
unsigned i = 0; i < origArgCount; ++i) {
1256 if (!inputMap || inputMap->replacementValue)
1259 for (
unsigned j = 0;
j < inputMap->size; ++
j)
1260 newLocs[inputMap->inputNo +
j] = origLoc;
1267 convertedTypes, newLocs);
1277 appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->
end());
1280 while (!block->
empty())
1287 for (
unsigned i = 0; i != origArgCount; ++i) {
1291 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1297 MaterializationKind::Source,
1300 origArgType,
Type(), converter);
1302 appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1306 if (
Value repl = inputMap->replacementValue) {
1308 assert(inputMap->size == 0 &&
1309 "invalid to provide a replacement value when the argument isn't "
1312 appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1321 newBlock->
getArguments().slice(inputMap->inputNo, inputMap->size);
1324 replArgs, origArg, converter);
1325 appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
1328 appendRewrite<BlockTypeConversionRewrite>(block, newBlock);
1347 assert((!originalType || kind == MaterializationKind::Target) &&
1348 "original type is valid only for target materializations");
1351 if (inputs.size() == 1 && inputs.front().
getType() == outputType)
1352 return inputs.front();
1359 builder.
create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1360 appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1362 return convertOp.getResult(0);
1372 replacements, originalType,
1374 mapping.map(originalValue, argMat);
1377 Type legalOutputType;
1379 legalOutputType = converter->
convertType(originalType);
1380 }
else if (replacements.size() == 1) {
1387 legalOutputType = replacements[0].
getType();
1389 if (legalOutputType && legalOutputType != originalType) {
1392 argMat, legalOutputType,
1393 originalType, converter);
1394 mapping.map(argMat, targetMat);
1410 [&](
Operation *op) { return replacedOps.contains(op); }) &&
1417 repl =
mapping.lookupOrNull(value);
1430 mapping.map(value, castValue);
1440 logger.startLine() <<
"** Insert : '" << op->
getName() <<
"'(" << op
1444 "attempting to insert into a block within a replaced/erased op");
1446 if (!previous.
isSet()) {
1448 appendRewrite<CreateOperationRewrite>(op);
1454 appendRewrite<MoveOperationRewrite>(op, previous.
getBlock(), prevOp);
1460 assert(!
ignoredOps.contains(op) &&
"operation was already replaced");
1464 bool isUnresolvedMaterialization =
false;
1465 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1467 isUnresolvedMaterialization =
true;
1470 for (
auto [n, result] : llvm::zip_equal(newValues, op->
getResults())) {
1474 if (isUnresolvedMaterialization) {
1486 repl.push_back(sourceMat);
1494 assert(!isUnresolvedMaterialization &&
1495 "attempting to replace an unresolved materialization");
1502 if (repl.size() == 1) {
1504 mapping.map(result, repl.front());
1519 appendRewrite<EraseBlockRewrite>(block);
1525 "attempting to insert into a region within a replaced/erased op");
1530 logger.startLine() <<
"** Insert Block into : '" << parent->
getName()
1531 <<
"'(" << parent <<
")\n";
1534 <<
"** Insert Block into detached Region (nullptr parent op)'";
1540 appendRewrite<CreateBlockRewrite>(block);
1543 Block *prevBlock = previousIt == previous->
end() ? nullptr : &*previousIt;
1544 appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
1549 appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
1556 reasonCallback(
diag);
1557 logger.startLine() <<
"** Failure : " <<
diag.str() <<
"\n";
1567 ConversionPatternRewriter::ConversionPatternRewriter(
1571 setListener(
impl.get());
1577 assert(op && newOp &&
"expected non-null op");
1583 "incorrect # of replacement values");
1585 impl->logger.startLine()
1586 <<
"** Replace : '" << op->
getName() <<
"'(" << op <<
")\n";
1591 newVals[index].push_back(val);
1592 impl->notifyOpReplaced(op, newVals);
1598 "incorrect # of replacement values");
1600 impl->logger.startLine()
1601 <<
"** Replace : '" << op->
getName() <<
"'(" << op <<
")\n";
1605 llvm::append_range(newVals[index], val);
1606 impl->notifyOpReplaced(op, newVals);
1611 impl->logger.startLine()
1612 <<
"** Erase : '" << op->
getName() <<
"'(" << op <<
")\n";
1615 impl->notifyOpReplaced(op, nullRepls);
1620 "attempting to erase a block within a replaced/erased op");
1630 impl->notifyBlockIsBeingErased(block);
1638 "attempting to apply a signature conversion to a block within a "
1639 "replaced/erased op");
1640 return impl->applySignatureConversion(*
this, block, converter, conversion);
1647 "attempting to apply a signature conversion to a block within a "
1648 "replaced/erased op");
1649 return impl->convertRegionTypes(*
this, region, converter, entryConversion);
1656 impl->logger.startLine() <<
"** Replace Argument : '" << from
1657 <<
"'(in region of '" << parentOp->
getName()
1660 impl->appendRewrite<ReplaceBlockArgRewrite>(from.
getOwner(), from,
1661 impl->currentTypeConverter);
1662 impl->mapping.map(
impl->mapping.lookupOrDefault(from), to);
1667 if (failed(
impl->remapValues(
"value", std::nullopt, *
this, key,
1670 return remappedValues.front();
1678 return impl->remapValues(
"value", std::nullopt, *
this, keys,
1687 "incorrect # of argument replacement values");
1689 "attempting to inline a block from a replaced/erased op");
1691 "attempting to inline a block into a replaced/erased op");
1692 auto opIgnored = [&](
Operation *op) {
return impl->isOpIgnored(op); };
1695 assert(llvm::all_of(source->
getUsers(), opIgnored) &&
1696 "expected 'source' to have no predecessors");
1705 bool fastPath = !
impl->config.listener;
1708 impl->notifyBlockBeingInlined(dest, source, before);
1711 for (
auto it : llvm::zip(source->
getArguments(), argValues))
1712 replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
1719 while (!source->
empty())
1720 moveOpBefore(&source->
front(), dest, before);
1728 assert(!
impl->wasOpReplaced(op) &&
1729 "attempting to modify a replaced/erased op");
1731 impl->pendingRootUpdates.insert(op);
1733 impl->appendRewrite<ModifyOperationRewrite>(op);
1737 assert(!
impl->wasOpReplaced(op) &&
1738 "attempting to modify a replaced/erased op");
1743 assert(
impl->pendingRootUpdates.erase(op) &&
1744 "operation did not have a pending in-place update");
1750 assert(
impl->pendingRootUpdates.erase(op) &&
1751 "operation did not have a pending in-place update");
1754 auto it = llvm::find_if(
1755 llvm::reverse(
impl->rewrites), [&](std::unique_ptr<IRRewrite> &
rewrite) {
1756 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1757 return modifyRewrite && modifyRewrite->getOperation() == op;
1759 assert(it !=
impl->rewrites.rend() &&
"no root update started on op");
1761 int updateIdx = std::prev(
impl->rewrites.rend()) - it;
1762 impl->rewrites.erase(
impl->rewrites.begin() + updateIdx);
1777 auto &rewriterImpl = dialectRewriter.getImpl();
1781 getTypeConverter());
1789 return matchAndRewrite(op, operands, dialectRewriter);
1801 class OperationLegalizer {
1821 LogicalResult legalizeWithFold(
Operation *op,
1826 LogicalResult legalizeWithPattern(
Operation *op,
1837 RewriterState &curState);
1841 legalizePatternBlockRewrites(
Operation *op,
1844 RewriterState &state, RewriterState &newState);
1845 LogicalResult legalizePatternCreatedOperations(
1847 RewriterState &state, RewriterState &newState);
1850 RewriterState &state,
1851 RewriterState &newState);
1861 void buildLegalizationGraph(
1862 LegalizationPatterns &anyOpLegalizerPatterns,
1873 void computeLegalizationGraphBenefit(
1874 LegalizationPatterns &anyOpLegalizerPatterns,
1879 unsigned computeOpLegalizationDepth(
1886 unsigned applyCostModelToPatterns(
1887 LegalizationPatterns &patterns,
1908 : target(targetInfo), applicator(patterns), config(config) {
1912 LegalizationPatterns anyOpLegalizerPatterns;
1914 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
1915 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
1918 bool OperationLegalizer::isIllegal(
Operation *op)
const {
1919 return target.isIllegal(op);
1923 OperationLegalizer::legalize(
Operation *op,
1926 const char *logLineComment =
1927 "//===-------------------------------------------===//\n";
1932 logger.getOStream() <<
"\n";
1933 logger.startLine() << logLineComment;
1934 logger.startLine() <<
"Legalizing operation : '" << op->
getName() <<
"'("
1940 op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
1941 logger.getOStream() <<
"\n\n";
1946 if (
auto legalityInfo = target.isLegal(op)) {
1949 logger,
"operation marked legal by the target{0}",
1950 legalityInfo->isRecursivelyLegal
1951 ?
"; NOTE: operation is recursively legal; skipping internals"
1953 logger.startLine() << logLineComment;
1958 if (legalityInfo->isRecursivelyLegal) {
1971 logSuccess(logger,
"operation marked 'ignored' during conversion");
1972 logger.startLine() << logLineComment;
1980 if (succeeded(legalizeWithFold(op, rewriter))) {
1983 logger.startLine() << logLineComment;
1989 if (succeeded(legalizeWithPattern(op, rewriter))) {
1992 logger.startLine() << logLineComment;
1998 logFailure(logger,
"no matched legalization pattern");
1999 logger.startLine() << logLineComment;
2005 OperationLegalizer::legalizeWithFold(
Operation *op,
2007 auto &rewriterImpl = rewriter.
getImpl();
2011 rewriterImpl.
logger.startLine() <<
"* Fold {\n";
2012 rewriterImpl.
logger.indent();
2016 SmallVector<Value, 2> replacementValues;
2018 if (failed(rewriter.
tryFold(op, replacementValues))) {
2024 if (replacementValues.empty())
2025 return legalize(op, rewriter);
2028 rewriter.
replaceOp(op, replacementValues);
2031 for (
unsigned i = curState.numRewrites, e = rewriterImpl.
rewrites.size();
2034 dyn_cast<CreateOperationRewrite>(rewriterImpl.
rewrites[i].get());
2037 if (failed(legalize(createOp->getOperation(), rewriter))) {
2039 "failed to legalize generated constant '{0}'",
2040 createOp->getOperation()->getName()));
2051 OperationLegalizer::legalizeWithPattern(
Operation *op,
2053 auto &rewriterImpl = rewriter.
getImpl();
2056 auto canApply = [&](
const Pattern &pattern) {
2057 bool canApply = canApplyPattern(op, pattern, rewriter);
2058 if (canApply && config.listener)
2059 config.listener->notifyPatternBegin(pattern, op);
2065 auto onFailure = [&](
const Pattern &pattern) {
2071 diag <<
"Failed to apply pattern \"" << pattern.getDebugName()
2077 if (config.listener)
2078 config.listener->notifyPatternEnd(pattern, failure());
2080 appliedPatterns.erase(&pattern);
2085 auto onSuccess = [&](
const Pattern &pattern) {
2087 auto result = legalizePatternResult(op, pattern, rewriter, curState);
2088 appliedPatterns.erase(&pattern);
2091 if (config.listener)
2092 config.listener->notifyPatternEnd(pattern, result);
2097 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2101 bool OperationLegalizer::canApplyPattern(
Operation *op,
const Pattern &pattern,
2104 auto &os = rewriter.
getImpl().logger;
2105 os.getOStream() <<
"\n";
2106 os.startLine() <<
"* Pattern : '" << op->
getName() <<
" -> (";
2108 os.getOStream() <<
")' {\n";
2115 !appliedPatterns.insert(&pattern).second) {
2124 OperationLegalizer::legalizePatternResult(
Operation *op,
const Pattern &pattern,
2126 RewriterState &curState) {
2130 assert(
impl.pendingRootUpdates.empty() &&
"dangling root updates");
2132 auto newRewrites = llvm::drop_begin(
impl.rewrites, curState.numRewrites);
2133 auto replacedRoot = [&] {
2134 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2136 auto updatedRootInPlace = [&] {
2137 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2139 assert((replacedRoot() || updatedRootInPlace()) &&
2140 "expected pattern to replace the root operation");
2144 RewriterState newState =
impl.getCurrentState();
2145 if (failed(legalizePatternBlockRewrites(op, rewriter,
impl, curState,
2147 failed(legalizePatternRootUpdates(rewriter,
impl, curState, newState)) ||
2148 failed(legalizePatternCreatedOperations(rewriter,
impl, curState,
2153 LLVM_DEBUG(
logSuccess(
impl.logger,
"pattern applied successfully"));
2157 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2160 RewriterState &newState) {
2165 for (
int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2166 BlockRewrite *
rewrite = dyn_cast<BlockRewrite>(
impl.rewrites[i].get());
2170 if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2171 ReplaceBlockArgRewrite>(
rewrite))
2180 if (
auto *converter =
impl.regionToConverter.lookup(block->
getParent())) {
2181 std::optional<TypeConverter::SignatureConversion> conversion =
2184 LLVM_DEBUG(
logFailure(
impl.logger,
"failed to convert types of moved "
2188 impl.applySignatureConversion(rewriter, block, converter, *conversion);
2196 if (operationsToIgnore.empty()) {
2197 for (
unsigned i = state.numRewrites, e =
impl.rewrites.size(); i != e;
2200 dyn_cast<CreateOperationRewrite>(
impl.rewrites[i].get());
2203 operationsToIgnore.insert(createOp->getOperation());
2208 if (operationsToIgnore.insert(parentOp).second &&
2209 failed(legalize(parentOp, rewriter))) {
2211 "operation '{0}'({1}) became illegal after rewrite",
2212 parentOp->
getName(), parentOp));
2219 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2221 RewriterState &state, RewriterState &newState) {
2222 for (
int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2223 auto *createOp = dyn_cast<CreateOperationRewrite>(
impl.rewrites[i].get());
2226 Operation *op = createOp->getOperation();
2227 if (failed(legalize(op, rewriter))) {
2229 "failed to legalize generated operation '{0}'({1})",
2237 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2239 RewriterState &state, RewriterState &newState) {
2240 for (
int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2241 auto *
rewrite = dyn_cast<ModifyOperationRewrite>(
impl.rewrites[i].get());
2245 if (failed(legalize(op, rewriter))) {
2247 impl.logger,
"failed to legalize operation updated in-place '{0}'",
2258 void OperationLegalizer::buildLegalizationGraph(
2259 LegalizationPatterns &anyOpLegalizerPatterns,
2270 applicator.walkAllPatterns([&](
const Pattern &pattern) {
2271 std::optional<OperationName> root = pattern.
getRootKind();
2277 anyOpLegalizerPatterns.push_back(&pattern);
2282 if (target.getOpAction(*root) == LegalizationAction::Legal)
2287 invalidPatterns[*root].insert(&pattern);
2289 parentOps[op].insert(*root);
2292 patternWorklist.insert(&pattern);
2300 if (!anyOpLegalizerPatterns.empty()) {
2301 for (
const Pattern *pattern : patternWorklist)
2302 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2306 while (!patternWorklist.empty()) {
2307 auto *pattern = patternWorklist.pop_back_val();
2311 std::optional<LegalizationAction> action = target.getOpAction(op);
2312 return !legalizerPatterns.count(op) &&
2313 (!action || action == LegalizationAction::Illegal);
2319 legalizerPatterns[*pattern->
getRootKind()].push_back(pattern);
2320 invalidPatterns[*pattern->
getRootKind()].erase(pattern);
2324 for (
auto op : parentOps[*pattern->
getRootKind()])
2325 patternWorklist.set_union(invalidPatterns[op]);
2329 void OperationLegalizer::computeLegalizationGraphBenefit(
2330 LegalizationPatterns &anyOpLegalizerPatterns,
2336 for (
auto &opIt : legalizerPatterns)
2337 if (!minOpPatternDepth.count(opIt.first))
2338 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2344 if (!anyOpLegalizerPatterns.empty())
2345 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
2351 applicator.applyCostModel([&](
const Pattern &pattern) {
2353 if (std::optional<OperationName> rootName = pattern.
getRootKind())
2354 orderedPatternList = legalizerPatterns[*rootName];
2356 orderedPatternList = anyOpLegalizerPatterns;
2359 auto *it = llvm::find(orderedPatternList, &pattern);
2360 if (it == orderedPatternList.end())
2364 return PatternBenefit(std::distance(it, orderedPatternList.end()));
2368 unsigned OperationLegalizer::computeOpLegalizationDepth(
2372 auto depthIt = minOpPatternDepth.find(op);
2373 if (depthIt != minOpPatternDepth.end())
2374 return depthIt->second;
2378 auto opPatternsIt = legalizerPatterns.find(op);
2379 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
2388 unsigned minDepth = applyCostModelToPatterns(
2389 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
2390 minOpPatternDepth[op] = minDepth;
2394 unsigned OperationLegalizer::applyCostModelToPatterns(
2395 LegalizationPatterns &patterns,
2401 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
2402 patternsByDepth.reserve(patterns.size());
2403 for (
const Pattern *pattern : patterns) {
2406 unsigned generatedOpDepth = computeOpLegalizationDepth(
2407 generatedOp, minOpPatternDepth, legalizerPatterns);
2408 depth =
std::max(depth, generatedOpDepth + 1);
2410 patternsByDepth.emplace_back(pattern, depth);
2413 minDepth =
std::min(minDepth, depth);
2418 if (patternsByDepth.size() == 1)
2422 std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(),
2423 [](
const std::pair<const Pattern *, unsigned> &lhs,
2424 const std::pair<const Pattern *, unsigned> &rhs) {
2427 if (lhs.second != rhs.second)
2428 return lhs.second < rhs.second;
2431 auto lhsBenefit = lhs.first->getBenefit();
2432 auto rhsBenefit = rhs.first->getBenefit();
2433 return lhsBenefit > rhsBenefit;
2438 for (
auto &patternIt : patternsByDepth)
2439 patterns.push_back(patternIt.first);
2447 enum OpConversionMode {
2470 OpConversionMode mode)
2471 : config(config), opLegalizer(target, patterns, this->config),
2485 OperationLegalizer opLegalizer;
2488 OpConversionMode mode;
2495 if (failed(opLegalizer.legalize(op, rewriter))) {
2498 if (mode == OpConversionMode::Full)
2500 <<
"failed to legalize operation '" << op->
getName() <<
"'";
2504 if (mode == OpConversionMode::Partial) {
2505 if (opLegalizer.isIllegal(op))
2507 <<
"failed to legalize operation '" << op->
getName()
2508 <<
"' that was explicitly marked illegal";
2512 }
else if (mode == OpConversionMode::Analysis) {
2522 static LogicalResult
2524 UnresolvedMaterializationRewrite *
rewrite) {
2525 UnrealizedConversionCastOp op =
rewrite->getOperation();
2526 assert(!op.use_empty() &&
2527 "expected that dead materializations have already been DCE'd");
2529 Type outputType = op.getResultTypes()[0];
2534 Value newMaterialization;
2535 switch (
rewrite->getMaterializationKind()) {
2539 rewriter, op->getLoc(), outputType, inputOperands);
2540 if (newMaterialization)
2545 case MaterializationKind::Target:
2547 rewriter, op->getLoc(), outputType, inputOperands,
2550 case MaterializationKind::Source:
2552 rewriter, op->getLoc(), outputType, inputOperands);
2555 if (newMaterialization) {
2556 assert(newMaterialization.
getType() == outputType &&
2557 "materialization callback produced value of incorrect type");
2558 rewriter.
replaceOp(op, newMaterialization);
2564 op->emitError() <<
"failed to legalize unresolved materialization "
2566 << inputOperands.
getTypes() <<
") to (" << outputType
2567 <<
") that remained live after conversion";
2568 diag.attachNote(op->getUsers().begin()->getLoc())
2569 <<
"see existing live user here: " << *op->getUsers().begin();
2580 for (
auto *op : ops) {
2583 toConvert.push_back(op);
2586 auto legalityInfo = target.
isLegal(op);
2587 if (legalityInfo && legalityInfo->isRecursivelyLegal)
2597 for (
auto *op : toConvert)
2598 if (failed(convert(rewriter, op)))
2608 for (
auto it : materializations) {
2611 allCastOps.push_back(it.first);
2623 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2624 auto it = materializations.find(castOp);
2625 assert(it != materializations.end() &&
"inconsistent state");
2648 auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2649 for (
Value v : castOp.getInputs())
2650 if (
auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2651 worklist.insert(inputCastOp);
2658 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2659 if (castOp.getInputs().empty())
2662 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2665 if (inputCastOp.getOutputs() != castOp.getInputs())
2671 while (!worklist.empty()) {
2672 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2673 if (castOp->use_empty()) {
2676 enqueueOperands(castOp);
2677 if (remainingCastOps)
2678 erasedOps.insert(castOp.getOperation());
2685 UnrealizedConversionCastOp nextCast = castOp;
2687 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2691 enqueueOperands(castOp);
2692 castOp.replaceAllUsesWith(nextCast.getInputs());
2693 if (remainingCastOps)
2694 erasedOps.insert(castOp.getOperation());
2698 nextCast = getInputCast(nextCast);
2702 if (remainingCastOps)
2703 for (UnrealizedConversionCastOp op : castOps)
2704 if (!erasedOps.contains(op.getOperation()))
2705 remainingCastOps->push_back(op);
2714 assert(!types.empty() &&
"expected valid types");
2715 remapInput(origInputNo, argTypes.size(), types.size());
2720 assert(!types.empty() &&
2721 "1->0 type remappings don't need to be added explicitly");
2722 argTypes.append(types.begin(), types.end());
2726 unsigned newInputNo,
2727 unsigned newInputCount) {
2728 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
2729 assert(newInputCount != 0 &&
"expected valid input count");
2730 remappedInputs[origInputNo] =
2731 InputMapping{newInputNo, newInputCount,
nullptr};
2735 Value replacementValue) {
2736 assert(!remappedInputs[origInputNo] &&
"input has already been remapped");
2737 remappedInputs[origInputNo] =
2744 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
2747 cacheReadLock.lock();
2748 auto existingIt = cachedDirectConversions.find(t);
2749 if (existingIt != cachedDirectConversions.end()) {
2750 if (existingIt->second)
2751 results.push_back(existingIt->second);
2752 return success(existingIt->second !=
nullptr);
2754 auto multiIt = cachedMultiConversions.find(t);
2755 if (multiIt != cachedMultiConversions.end()) {
2756 results.append(multiIt->second.begin(), multiIt->second.end());
2762 size_t currentCount = results.size();
2764 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
2767 for (
const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2768 if (std::optional<LogicalResult> result = converter(t, results)) {
2770 cacheWriteLock.lock();
2771 if (!succeeded(*result)) {
2772 cachedDirectConversions.try_emplace(t,
nullptr);
2775 auto newTypes =
ArrayRef<Type>(results).drop_front(currentCount);
2776 if (newTypes.size() == 1)
2777 cachedDirectConversions.try_emplace(t, newTypes.front());
2779 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
2793 return results.size() == 1 ? results.front() :
nullptr;
2799 for (
Type type : types)
2813 return llvm::all_of(*region, [
this](
Block &block) {
2819 return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
2831 if (convertedTypes.empty())
2835 result.
addInputs(inputNo, convertedTypes);
2841 unsigned origInputOffset)
const {
2842 for (
unsigned i = 0, e = types.size(); i != e; ++i)
2852 for (
const MaterializationCallbackFn &fn :
2853 llvm::reverse(argumentMaterializations))
2854 if (
Value result = fn(builder, resultType, inputs, loc))
2862 for (
const MaterializationCallbackFn &fn :
2863 llvm::reverse(sourceMaterializations))
2864 if (
Value result = fn(builder, resultType, inputs, loc))
2872 Type originalType)
const {
2874 builder, loc,
TypeRange(resultType), inputs, originalType);
2877 assert(result.size() == 1 &&
"expected single result");
2878 return result.front();
2883 Type originalType)
const {
2884 for (
const TargetMaterializationCallbackFn &fn :
2885 llvm::reverse(targetMaterializations)) {
2887 fn(builder, resultTypes, inputs, loc, originalType);
2891 "callback produced incorrect number of values or values with "
2898 std::optional<TypeConverter::SignatureConversion>
2902 return std::nullopt;
2925 return impl.getInt() == resultTag;
2929 return impl.getInt() == naTag;
2933 return impl.getInt() == abortTag;
2937 assert(hasResult() &&
"Cannot get result from N/A or abort");
2938 return impl.getPointer();
2941 std::optional<Attribute>
2943 for (
const TypeAttributeConversionCallbackFn &fn :
2944 llvm::reverse(typeAttributeConversions)) {
2949 return std::nullopt;
2951 return std::nullopt;
2961 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
2967 SmallVector<Type, 1> newResults;
2969 failed(typeConverter.
convertTypes(type.getResults(), newResults)) ||
2971 typeConverter, &result)))
2988 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
2996 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3001 struct AnyFunctionOpInterfaceSignatureConversion
3013 FailureOr<Operation *>
3017 assert(op &&
"Invalid op");
3031 return rewriter.
create(newOp);
3037 patterns.
add<FunctionOpInterfaceSignatureConversion>(
3038 functionLikeOpName, patterns.
getContext(), converter);
3043 patterns.
add<AnyFunctionOpInterfaceSignatureConversion>(
3053 legalOperations[op].action = action;
3058 for (StringRef dialect : dialectNames)
3059 legalDialects[dialect] = action;
3063 -> std::optional<LegalizationAction> {
3064 std::optional<LegalizationInfo> info = getOpInfo(op);
3065 return info ? info->action : std::optional<LegalizationAction>();
3069 -> std::optional<LegalOpDetails> {
3070 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3072 return std::nullopt;
3075 auto isOpLegal = [&] {
3077 if (info->action == LegalizationAction::Dynamic) {
3078 std::optional<bool> result = info->legalityFn(op);
3084 return info->action == LegalizationAction::Legal;
3087 return std::nullopt;
3091 if (info->isRecursivelyLegal) {
3092 auto legalityFnIt = opRecursiveLegalityFns.find(op->
getName());
3093 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3095 legalityFnIt->second(op).value_or(
true);
3100 return legalityDetails;
3104 std::optional<LegalizationInfo> info = getOpInfo(op->
getName());
3108 if (info->action == LegalizationAction::Dynamic) {
3109 std::optional<bool> result = info->legalityFn(op);
3116 return info->action == LegalizationAction::Illegal;
3125 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3127 if (std::optional<bool> result = newCl(op))
3135 void ConversionTarget::setLegalityCallback(
3136 OperationName name,
const DynamicLegalityCallbackFn &callback) {
3137 assert(callback &&
"expected valid legality callback");
3138 auto *infoIt = legalOperations.find(name);
3139 assert(infoIt != legalOperations.end() &&
3140 infoIt->second.action == LegalizationAction::Dynamic &&
3141 "expected operation to already be marked as dynamically legal");
3142 infoIt->second.legalityFn =
3148 auto *infoIt = legalOperations.find(name);
3149 assert(infoIt != legalOperations.end() &&
3150 infoIt->second.action != LegalizationAction::Illegal &&
3151 "expected operation to already be marked as legal");
3152 infoIt->second.isRecursivelyLegal =
true;
3155 std::move(opRecursiveLegalityFns[name]), callback);
3157 opRecursiveLegalityFns.erase(name);
3160 void ConversionTarget::setLegalityCallback(
3162 assert(callback &&
"expected valid legality callback");
3163 for (StringRef dialect : dialects)
3165 std::move(dialectLegalityFns[dialect]), callback);
3168 void ConversionTarget::setLegalityCallback(
3169 const DynamicLegalityCallbackFn &callback) {
3170 assert(callback &&
"expected valid legality callback");
3175 -> std::optional<LegalizationInfo> {
3177 const auto *it = legalOperations.find(op);
3178 if (it != legalOperations.end())
3181 auto dialectIt = legalDialects.find(op.getDialectNamespace());
3182 if (dialectIt != legalDialects.end()) {
3183 DynamicLegalityCallbackFn callback;
3184 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
3185 if (dialectFn != dialectLegalityFns.end())
3186 callback = dialectFn->second;
3187 return LegalizationInfo{dialectIt->second,
false,
3191 if (unknownLegalityFn)
3192 return LegalizationInfo{LegalizationAction::Dynamic,
3193 false, unknownLegalityFn};
3194 return std::nullopt;
3197 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3203 auto &rewriterImpl =
3209 auto &rewriterImpl =
3216 static FailureOr<SmallVector<Value>>
3218 SmallVector<Value> mappedValues;
3221 return std::move(mappedValues);
3230 if (failed(results))
3232 return results->front();
3242 auto &rewriterImpl =
3255 TypeRange types) -> FailureOr<SmallVector<Type>> {
3256 auto &rewriterImpl =
3263 if (failed(converter->
convertTypes(types, remappedTypes)))
3265 return std::move(remappedTypes);
3281 OpConversionMode::Partial);
3299 OpConversionMode::Full);
3322 "expected top-level op to be isolated from above");
3325 "expected ops to have a common ancestor");
3334 for (
Operation *op : ops.drop_front()) {
3338 assert(commonAncestor &&
3339 "expected to find a common isolated from above ancestor");
3343 return commonAncestor;
3361 inverseOperationMap[it.second] = it.first;
3367 OpConversionMode::Analysis);
3368 LogicalResult status = opConverter.convertOperations(opsToConvert);
3375 originalLegalizableOps.insert(inverseOperationMap[op]);
3380 clonedAncestor->
erase();
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
SmallVector< Value, 1 > ReplacementValues
A list of replacement SSA values.
static FailureOr< SmallVector< Value > > pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values)
Remap the given value using the rewriter and the type converter in the provided config.
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnresolvedMaterializationRewrite *rewrite)
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static bool hasRewrite(R &&rewrites, Operation *op)
Return "true" if there is an operation rewrite that matches the specified rewrite type and operation ...
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
void dropAllDefinedValueUses()
This drops all uses of values defined in this block or in the blocks of nested regions wherever the u...
OpListType & getOperations()
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
LogicalResult getRemappedValues(ValueRange keys, SmallVectorImpl< Value > &results)
Return the converted values that replace 'keys' with types defined by the type converter of the curre...
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt) override
PatternRewriter hook for inlining the ops of a block into another block.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceOpWithMultiple(Operation *op, ArrayRef< ValueRange > newValues)
Replace the given operation with the new value ranges.
detail::ConversionPatternRewriterImpl & getImpl()
Return a reference to the internal implementation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
Value getRemappedValue(Value key)
Return the converted value of 'key' with a type defined by the type converter of the currently execut...
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
~ConversionPatternRewriter() override
Base class for the conversion patterns.
virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const
Hook for derived classes to implement combined matching and rewriting.
This class describes a specific conversion target.
void setDialectAction(ArrayRef< StringRef > dialectNames, LegalizationAction action)
Register a legality action for the given dialects.
void setOpAction(OperationName op, LegalizationAction action)
Register a legality action for the given operation.
std::optional< LegalOpDetails > isLegal(Operation *op) const
If the given operation instance is legal on this target, a structure containing legality information ...
std::optional< LegalizationAction > getOpAction(OperationName op) const
Get the legality action for the given operation.
LegalizationAction
This enumeration corresponds to the specific action to take when considering an operation legal for t...
void markOpRecursivelyLegal(OperationName name, const DynamicLegalityCallbackFn &callback)
Mark an operation, that must have either been set as Legal or DynamicallyLegal, as being recursively ...
std::function< std::optional< bool >(Operation *)> DynamicLegalityCallbackFn
The signature of the callback used to determine if an operation is dynamically legal on the target.
bool isIllegal(Operation *op) const
Returns true is operation instance is illegal on this target.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This is a utility class for mapping one set of IR entities to another.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
auto lookup(T from) const
Lookup a mapped value within the map.
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
Location objects represent source locations information in MLIR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Block::iterator getPoint() const
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results)
Attempts to fold the given operation and places new results within results.
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
type_range getTypes() const
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
void dropAllUses()
Drop all uses of results of this operation.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
result_range getResults()
int getPropertiesStorageSize() const
Returns the properties storage size.
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
void notifyRewriteEnd(PatternRewriter &rewriter) final
void notifyRewriteBegin(PatternRewriter &rewriter) final
Hooks that are invoked at the beginning and end of a rewrite of a matched pattern.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
MLIRContext * getContext() const
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
The general result of a type attribute conversion callback, allowing for early termination.
Attribute getResult() const
static AttributeConversionResult abort()
static AttributeConversionResult na()
static AttributeConversionResult result(Attribute attr)
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertSignatureArgs(TypeRange types, SignatureConversion &result, unsigned origInputOffset=0) const
LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) const
This method allows for converting a specific argument of a signature.
Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const
Materialize a conversion from a set of types into one result type by generating a cast sequence of so...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::optional< SignatureConversion > convertBlockSignature(Block *block) const
This function converts the type signature of the given block, by invoking 'convertSignatureArg' for e...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType={}) const
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
@ Full
Documents are synced by always sending the full content of the document.
Kind
An enumeration of the kinds of predicates.
llvm::PointerUnion< NamedAttribute *, NamedProperty *, NamedTypeConstraint * > Argument
Include the generated interface declarations.
void populateFunctionOpInterfaceTypeConversionPattern(StringRef functionLikeOpName, RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the signature of a FunctionOpInterface op with the...
void populateAnyFunctionOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
LogicalResult applyAnalysisConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply an analysis conversion on the given operations, and all nested operations.
void reconcileUnrealizedCasts(ArrayRef< UnrealizedConversionCastOp > castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps=nullptr)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void registerConversionPDLFunctions(RewritePatternSet &patterns)
Register the dialect conversion PDL functions with the given pattern set.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
function_ref< void(Diagnostic &)> notifyCallback
An optional callback used to notify about match failure diagnostics during the conversion.
DenseSet< Operation * > * legalizableOps
Analysis conversion only.
DenseSet< Operation * > * unlegalizedOps
Partial conversion only.
bool buildMaterializations
If set to "true", the dialect conversion attempts to build source/target/ argument materializations t...
A structure containing additional information describing a specific legal operation instance.
bool isRecursivelyLegal
A flag that indicates if this operation is 'recursively' legal.
This iterator enumerates elements according to their dominance relationship.
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
A rewriter that keeps track of erased ops and blocks.
bool wasErased(void *ptr) const
SingleEraseRewriter(MLIRContext *context)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config)
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationRewrite * > unresolvedMaterializations
A mapping of all unresolved materializations (UnrealizedConversionCastOp) to the corresponding rewrit...
void resetState(RewriterState state)
Reset the state of the rewriter to a previously saved point.
Block * applySignatureConversion(ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
FailureOr< Block * > convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
ConversionValueMapping mapping
void applyRewrites()
Apply all requested operation rewrites.
void undoRewrites(unsigned numRewritesToKeep=0)
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
RewriterState getCurrentState()
Return the current state of the rewriter.
void notifyOpReplaced(Operation *op, ArrayRef< ReplacementValues > newValues)
Notifies that an op is about to be replaced with the given values.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void notifyBlockBeingInlined(Block *block, Block *srcBlock, Block::iterator before)
Notifies that a block is being inlined into another block.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc, ValueRange replacements, Value originalValue, const TypeConverter *converter)
Build an N:1 materialization for the given original value that was replaced with the given replacemen...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVectorImpl< Value > &remapped)
Remap the given values to those with potentially different types.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
SingleEraseRewriter eraseRewriter
A rewriter that keeps track of ops/block that were already erased and skips duplicate op/block erasur...
MLIRContext * context
MLIR context.
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
Value buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueRange inputs, Type outputType, Type originalType, const TypeConverter *converter)
Build an unresolved materialization operation given an output type and set of input operands.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
void notifyBlockIsBeingErased(Block *block)
Notifies that a block is about to be erased.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.