MLIR 23.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/Config/mlir-config.h"
11#include "mlir/IR/Block.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "mlir/IR/Dominance.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Iterators.h"
17#include "mlir/IR/Operation.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/ErrorHandling.h"
25#include "llvm/Support/FormatVariadic.h"
26#include "llvm/Support/SaveAndRestore.h"
27#include "llvm/Support/ScopedPrinter.h"
28#include <optional>
29#include <utility>
30
31using namespace mlir;
32using namespace mlir::detail;
33
34#define DEBUG_TYPE "dialect-conversion"
35
36/// A utility function to log a successful result for the given reason.
37template <typename... Args>
38static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
39 LLVM_DEBUG({
40 os.unindent();
41 os.startLine() << "} -> SUCCESS";
42 if (!fmt.empty())
43 os.getOStream() << " : "
44 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
45 os.getOStream() << "\n";
46 });
47}
48
49/// A utility function to log a failure result for the given reason.
50template <typename... Args>
51static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
52 LLVM_DEBUG({
53 os.unindent();
54 os.startLine() << "} -> FAILURE : "
55 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
56 << "\n";
57 });
58}
59
/// Helper function that computes an insertion point where the given value is
/// defined and can be used without a dominance violation: right after the
/// defining operation if the value is an op result, otherwise at the start of
/// the value's parent block.
  Block *insertBlock = value.getParentBlock();
  // Default: the beginning of the block that contains the value.
  Block::iterator insertPt = insertBlock->begin();
  // An op result only becomes available after its defining op.
  if (OpResult inputRes = dyn_cast<OpResult>(value))
    insertPt = ++inputRes.getOwner()->getIterator();
  return OpBuilder::InsertPoint(insertBlock, insertPt);
}
69
/// Helper function that computes an insertion point where the given values are
/// defined and can be used without a dominance violation.
///
/// The loop keeps the "later" of each pair of candidate insertion points
/// (`pt` / `nextPt`) according to dominance. In debug builds, an assertion
/// fires if two candidates are not related by dominance, because then no
/// single valid insertion point exists.
  assert(!vals.empty() && "expected at least one value");
  DominanceInfo domInfo;
  for (Value v : vals.drop_front()) {
    // Choose the "later" insertion point.
    if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
                          nextPt.getPoint())) {
      // pt is before nextPt => choose nextPt.
      pt = nextPt;
    } else {
#ifndef NDEBUG
      // nextPt should be before pt => keep pt.
      // If pt and nextPt have no dominance relationship, then there is no
      // valid insertion point at which all given values are defined.
      bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
                                   pt.getBlock(), pt.getPoint());
      assert(dom && "unable to find valid insertion point");
#endif // NDEBUG
    }
  }
  return pt;
}
96
namespace {
/// How the driver treats operations that cannot be legalized.
enum OpConversionMode {
  /// Failed conversions are tolerated, so illegal operations may remain in
  /// the IR next to legal ones.
  Partial = 0,

  /// The conversion only succeeds if every operation ends up legal for the
  /// given target.
  Full = 1,

  /// Operations are only analyzed for legality; on success, no actual
  /// rewrites are applied to them.
  Analysis = 2,
};
} // namespace
112
113//===----------------------------------------------------------------------===//
114// ConversionValueMapping
115//===----------------------------------------------------------------------===//
116
117/// A vector of SSA values, optimized for the most common case of one or two
118/// values. Inline size chosen empirically based on compilation profiling.
119/// Profiled: 2.3M calls, avg=2.0+-0.3. N=2 covers 98% of cases inline.
121
122namespace {
123
124/// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
125struct ValueVectorMapInfo {
126 static ::llvm::hash_code getHashValue(const ValueVector &val) {
127 return ::llvm::hash_combine_range(val);
128 }
129 static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
130 return LHS == RHS;
131 }
132};
133
/// Maps vectors of SSA values to the vectors of values that replace them.
///
/// `lookup` performs a single-step lookup: it returns the directly mapped
/// values and does not follow chains of mappings. In debug builds, `map`
/// walks the existing chain that starts at the new values and asserts that
/// it never reaches the values being mapped (i.e., no cycle is inserted).
struct ConversionValueMapping {
  /// Return "true" if an SSA value is mapped to the given value. May return
  /// false positives: `erase` does not remove entries from `mappedTo`.
  bool isMappedTo(Value value) const { return mappedTo.contains(value); }

  /// Look up `from` in the mapping. Returns an empty vector if there is no
  /// mapping for exactly this vector of values.
  ValueVector lookup(const ValueVector &from) const;

  /// Type trait that is true iff `T` decays to `ValueVector`. Used to select
  /// between the `map` overloads below.
  template <typename T>
  struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};

  /// Map a value vector to the one provided. An existing mapping for `oldVal`
  /// is overwritten.
  template <typename OldVal, typename NewVal>
  std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
  map(OldVal &&oldVal, NewVal &&newVal) {
    LLVM_DEBUG({
      // Follow existing mappings starting at `newVal`. Reaching `oldVal` would
      // mean that this insertion creates a cycle.
      ValueVector next(newVal);
      while (true) {
        assert(next != oldVal && "inserting cyclic mapping");
        auto it = mapping.find(next);
        if (it == mapping.end())
          break;
        next = it->second;
      }
    });
    // Remember every new value so that `isMappedTo` can answer queries.
    mappedTo.insert_range(newVal);

    mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
  }

  /// Map a value vector or single value to the one provided. Whichever side
  /// is not already a `ValueVector` is brace-initialized into one (a single
  /// element when it is a `Value`), then the overload above is used.
  template <typename OldVal, typename NewVal>
  std::enable_if_t<!IsValueVector<OldVal>::value ||
                   !IsValueVector<NewVal>::value>
  map(OldVal &&oldVal, NewVal &&newVal) {
    if constexpr (IsValueVector<OldVal>{}) {
      map(std::forward<OldVal>(oldVal), ValueVector{newVal});
    } else if constexpr (IsValueVector<NewVal>{}) {
      map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
    } else {
      map(ValueVector{oldVal}, ValueVector{newVal});
    }
  }

  /// Map a single value to a `SmallVector<Value>`, whose contents are moved
  /// into a `ValueVector`.
  void map(Value oldVal, SmallVector<Value> &&newVal) {
    map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
  }

  /// Remove the mapping whose key is exactly `value`, if any. `mappedTo` is
  /// intentionally left untouched (hence its false positives).
  void erase(const ValueVector &value) { mapping.erase(value); }

private:
  /// Current value mappings.

  /// All SSA values that are mapped to. May contain false positives.
  DenseSet<Value> mappedTo;
};
194} // namespace
195
196/// Marker attribute for pure type conversions. I.e., mappings whose only
197/// purpose is to resolve a type mismatch. (In contrast, mappings that point to
198/// the replacement values of a "replaceOp" call, etc., are not pure type
199/// conversions.)
200static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
201
/// Return the operation that defines all values in the vector. Return nullptr
/// if the values are not defined by the same operation, or if the first value
/// has no defining operation.
  assert(!values.empty() && "expected non-empty value vector");
  Operation *op = values.front().getDefiningOp();
  // Every remaining value must share the first value's defining op.
  for (Value v : llvm::drop_begin(values)) {
    if (v.getDefiningOp() != op)
      return nullptr;
  }
  return op;
}
213
214/// A vector of values is a pure type conversion if all values are defined by
215/// the same operation and the operation has the `kPureTypeConversionMarker`
216/// attribute.
217static bool isPureTypeConversion(const ValueVector &values) {
218 assert(!values.empty() && "expected non-empty value vector");
219 Operation *op = getCommonDefiningOp(values);
220 return op && op->hasAttr(kPureTypeConversionMarker);
221}
222
223ValueVector ConversionValueMapping::lookup(const ValueVector &from) const {
224 auto it = mapping.find(from);
225 if (it == mapping.end()) {
226 // No mapping found: The lookup stops here.
227 return {};
228 }
229 return it->second;
230}
231
232//===----------------------------------------------------------------------===//
233// Rewriter and Translation State
234//===----------------------------------------------------------------------===//
235namespace {
/// Snapshot of the conversion rewriter's bookkeeping counters. Taking one and
/// restoring it later is how a set of rewrites is undone.
struct RewriterState {
  RewriterState(unsigned rewriteCount, unsigned ignoredOpCount,
                unsigned replacedOpCount)
      : numRewrites(rewriteCount), numIgnoredOperations(ignoredOpCount),
        numReplacedOps(replacedOpCount) {}

  /// Number of rewrites recorded at snapshot time.
  unsigned numRewrites;
  /// Number of ignored operations at snapshot time.
  unsigned numIgnoredOperations;
  /// Number of replaced ops scheduled for erasure at snapshot time.
  unsigned numReplacedOps;
};
253
254//===----------------------------------------------------------------------===//
255// IR rewrites
256//===----------------------------------------------------------------------===//
257
258static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
259
260/// Notify the listener that the given block and its contents are being erased.
261static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
262 for (Operation &op : b)
263 notifyIRErased(listener, op);
264 listener->notifyBlockErased(&b);
265}
266
267/// Notify the listener that the given operation and its contents are being
268/// erased.
269static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
270 for (Region &r : op.getRegions()) {
271 for (Block &b : r) {
272 notifyIRErased(listener, b);
273 }
274 }
275 listener->notifyOperationErased(&op);
276}
277
278/// An IR rewrite that can be committed (upon success) or rolled back (upon
279/// failure).
280///
281/// The dialect conversion keeps track of IR modifications (requested by the
282/// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
283/// are directly applied to the IR as the rewriter API is used, some are applied
284/// partially, and some are delayed until the `IRRewrite` objects are committed.
285class IRRewrite {
286public:
287 /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
288 /// Enum values are ordered, so that they can be used in `classof`: first all
289 /// block rewrites, then all operation rewrites.
290 enum class Kind {
291 // Block rewrites
292 CreateBlock,
293 EraseBlock,
294 InlineBlock,
295 MoveBlock,
296 BlockTypeConversion,
297 // Operation rewrites
298 MoveOperation,
299 ModifyOperation,
300 ReplaceOperation,
301 CreateOperation,
302 UnresolvedMaterialization,
303 // Value rewrites
304 ReplaceValue
305 };
306
307 virtual ~IRRewrite() = default;
308
309 /// Roll back the rewrite. Operations may be erased during rollback.
310 virtual void rollback() = 0;
311
312 /// Commit the rewrite. At this point, it is certain that the dialect
313 /// conversion will succeed. All IR modifications, except for operation/block
314 /// erasure, must be performed through the given rewriter.
315 ///
316 /// Instead of erasing operations/blocks, they should merely be unlinked
317 /// commit phase and finally be erased during the cleanup phase. This is
318 /// because internal dialect conversion state (such as `mapping`) may still
319 /// be using them.
320 ///
321 /// Any IR modification that was already performed before the commit phase
322 /// (e.g., insertion of an op) must be communicated to the listener that may
323 /// be attached to the given rewriter.
324 virtual void commit(RewriterBase &rewriter) {}
325
326 /// Cleanup operations/blocks. Cleanup is called after commit.
327 virtual void cleanup(RewriterBase &rewriter) {}
328
329 Kind getKind() const { return kind; }
330
331 static bool classof(const IRRewrite *rewrite) { return true; }
332
333protected:
334 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
335 : kind(kind), rewriterImpl(rewriterImpl) {}
336
337 const ConversionConfig &getConfig() const;
338
339 const Kind kind;
340 ConversionPatternRewriterImpl &rewriterImpl;
341};
342
343/// A block rewrite.
344class BlockRewrite : public IRRewrite {
345public:
346 /// Return the block that this rewrite operates on.
347 Block *getBlock() const { return block; }
348
349 static bool classof(const IRRewrite *rewrite) {
350 return rewrite->getKind() >= Kind::CreateBlock &&
351 rewrite->getKind() <= Kind::BlockTypeConversion;
352 }
353
354protected:
355 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
356 Block *block)
357 : IRRewrite(kind, rewriterImpl), block(block) {}
358
359 // The block that this rewrite operates on.
360 Block *block;
361};
362
363/// A value rewrite.
364class ValueRewrite : public IRRewrite {
365public:
366 /// Return the value that this rewrite operates on.
367 Value getValue() const { return value; }
368
369 static bool classof(const IRRewrite *rewrite) {
370 return rewrite->getKind() == Kind::ReplaceValue;
371 }
372
373protected:
374 ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
375 Value value)
376 : IRRewrite(kind, rewriterImpl), value(value) {}
377
378 // The value that this rewrite operates on.
379 Value value;
380};
381
382/// Creation of a block. Block creations are immediately reflected in the IR.
383/// There is no extra work to commit the rewrite. During rollback, the newly
384/// created block is erased.
385class CreateBlockRewrite : public BlockRewrite {
386public:
387 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
388 : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
389
390 static bool classof(const IRRewrite *rewrite) {
391 return rewrite->getKind() == Kind::CreateBlock;
392 }
393
394 void commit(RewriterBase &rewriter) override {
395 // The block was already created and inserted. Just inform the listener.
396 if (auto *listener = rewriter.getListener())
397 listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
398 }
399
400 void rollback() override {
401 // Unlink all of the operations within this block, they will be deleted
402 // separately.
403 auto &blockOps = block->getOperations();
404 while (!blockOps.empty())
405 blockOps.remove(blockOps.begin());
406 block->dropAllUses();
407 if (block->getParent())
408 block->erase();
409 else
410 delete block;
411 }
412};
413
414/// Erasure of a block. Block erasures are partially reflected in the IR. Erased
415/// blocks are immediately unlinked, but only erased during cleanup. This makes
416/// it easier to rollback a block erasure: the block is simply inserted into its
417/// original location.
418class EraseBlockRewrite : public BlockRewrite {
419public:
420 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
421 : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
422 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
423
424 static bool classof(const IRRewrite *rewrite) {
425 return rewrite->getKind() == Kind::EraseBlock;
426 }
427
428 ~EraseBlockRewrite() override {
429 assert(!block &&
430 "rewrite was neither rolled back nor committed/cleaned up");
431 }
432
433 void rollback() override {
434 // The block (owned by this rewrite) was not actually erased yet. It was
435 // just unlinked. Put it back into its original position.
436 assert(block && "expected block");
437 auto &blockList = region->getBlocks();
438 Region::iterator before = insertBeforeBlock
439 ? Region::iterator(insertBeforeBlock)
440 : blockList.end();
441 blockList.insert(before, block);
442 block = nullptr;
443 }
444
445 void commit(RewriterBase &rewriter) override {
446 assert(block && "expected block");
447
448 // Notify the listener that the block and its contents are being erased.
449 if (auto *listener =
450 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
451 notifyIRErased(listener, *block);
452 }
453
454 void cleanup(RewriterBase &rewriter) override {
455 // Erase the contents of the block.
456 for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
457 rewriter.eraseOp(&op);
458 assert(block->empty() && "expected empty block");
459
460 // Erase the block.
461 block->dropAllDefinedValueUses();
462 delete block;
463 block = nullptr;
464 }
465
466private:
467 // The region in which this block was previously contained.
468 Region *region;
469
470 // The original successor of this block before it was unlinked. "nullptr" if
471 // this block was the only block in the region.
472 Block *insertBeforeBlock;
473};
474
475/// Inlining of a block. This rewrite is immediately reflected in the IR.
476/// Note: This rewrite represents only the inlining of the operations. The
477/// erasure of the inlined block is a separate rewrite.
478class InlineBlockRewrite : public BlockRewrite {
479public:
480 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
481 Block *sourceBlock, Block::iterator before)
482 : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
483 sourceBlock(sourceBlock),
484 firstInlinedInst(sourceBlock->empty() ? nullptr
485 : &sourceBlock->front()),
486 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
487 // If a listener is attached to the dialect conversion, ops must be moved
488 // one-by-one. When they are moved in bulk, notifications cannot be sent
489 // because the ops that used to be in the source block at the time of the
490 // inlining (before the "commit" phase) are unknown at the time when
491 // notifications are sent (which is during the "commit" phase).
492 assert(!getConfig().listener &&
493 "InlineBlockRewrite not supported if listener is attached");
494 }
495
496 static bool classof(const IRRewrite *rewrite) {
497 return rewrite->getKind() == Kind::InlineBlock;
498 }
499
500 void rollback() override {
501 // Put the operations from the destination block (owned by the rewrite)
502 // back into the source block.
503 if (firstInlinedInst) {
504 assert(lastInlinedInst && "expected operation");
505 sourceBlock->getOperations().splice(sourceBlock->begin(),
506 block->getOperations(),
507 Block::iterator(firstInlinedInst),
508 ++Block::iterator(lastInlinedInst));
509 }
510 }
511
512private:
513 // The block that originally contained the operations.
514 Block *sourceBlock;
515
516 // The first inlined operation.
517 Operation *firstInlinedInst;
518
519 // The last inlined operation.
520 Operation *lastInlinedInst;
521};
522
523/// Moving of a block. This rewrite is immediately reflected in the IR.
524class MoveBlockRewrite : public BlockRewrite {
525public:
526 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
527 Region *previousRegion, Region::iterator previousIt)
528 : BlockRewrite(Kind::MoveBlock, rewriterImpl, block),
529 region(previousRegion),
530 insertBeforeBlock(previousIt == previousRegion->end() ? nullptr
531 : &*previousIt) {}
532
533 static bool classof(const IRRewrite *rewrite) {
534 return rewrite->getKind() == Kind::MoveBlock;
535 }
536
537 void commit(RewriterBase &rewriter) override {
538 // The block was already moved. Just inform the listener.
539 if (auto *listener = rewriter.getListener()) {
540 // Note: `previousIt` cannot be passed because this is a delayed
541 // notification and iterators into past IR state cannot be represented.
542 listener->notifyBlockInserted(block, /*previous=*/region,
543 /*previousIt=*/{});
544 }
545 }
546
547 void rollback() override {
548 // Move the block back to its original position.
549 Region::iterator before =
550 insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
551 if (Region *currentParent = block->getParent()) {
552 // Block is still in a region, use cheap splice to move it back.
553 region->getBlocks().splice(before, currentParent->getBlocks(), block);
554 return;
555 }
556 // Block was orphaned by a prior rollback, can't splice.
557 region->getBlocks().insert(before, block);
558 }
559
560private:
561 // The region in which this block was previously contained.
562 Region *region;
563
564 // The original successor of this block before it was moved. "nullptr" if
565 // this block was the only block in the region.
566 Block *insertBeforeBlock;
567};
568
569/// Block type conversion. This rewrite is partially reflected in the IR.
570class BlockTypeConversionRewrite : public BlockRewrite {
571public:
572 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
573 Block *origBlock, Block *newBlock)
574 : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
575 newBlock(newBlock) {}
576
577 static bool classof(const IRRewrite *rewrite) {
578 return rewrite->getKind() == Kind::BlockTypeConversion;
579 }
580
581 Block *getOrigBlock() const { return block; }
582
583 Block *getNewBlock() const { return newBlock; }
584
585 void commit(RewriterBase &rewriter) override;
586
587 void rollback() override;
588
589private:
590 /// The new block that was created as part of this signature conversion.
591 Block *newBlock;
592};
593
594/// Replacing a value. This rewrite is not immediately reflected in the
595/// IR. An internal IR mapping is updated, but the actual replacement is delayed
596/// until the rewrite is committed.
597class ReplaceValueRewrite : public ValueRewrite {
598public:
599 ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
600 const TypeConverter *converter)
601 : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
602 converter(converter) {}
603
604 static bool classof(const IRRewrite *rewrite) {
605 return rewrite->getKind() == Kind::ReplaceValue;
606 }
607
608 void commit(RewriterBase &rewriter) override;
609
610 void rollback() override;
611
612private:
613 /// The current type converter when the value was replaced.
614 const TypeConverter *converter;
615};
616
617/// An operation rewrite.
618class OperationRewrite : public IRRewrite {
619public:
620 /// Return the operation that this rewrite operates on.
621 Operation *getOperation() const { return op; }
622
623 static bool classof(const IRRewrite *rewrite) {
624 return rewrite->getKind() >= Kind::MoveOperation &&
625 rewrite->getKind() <= Kind::UnresolvedMaterialization;
626 }
627
628protected:
629 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
630 Operation *op)
631 : IRRewrite(kind, rewriterImpl), op(op) {}
632
633 // The operation that this rewrite operates on.
634 Operation *op;
635};
636
637/// Moving of an operation. This rewrite is immediately reflected in the IR.
638class MoveOperationRewrite : public OperationRewrite {
639public:
640 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
641 Operation *op, OpBuilder::InsertPoint previous)
642 : OperationRewrite(Kind::MoveOperation, rewriterImpl, op),
643 block(previous.getBlock()),
644 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
645 ? nullptr
646 : &*previous.getPoint()) {}
647
648 static bool classof(const IRRewrite *rewrite) {
649 return rewrite->getKind() == Kind::MoveOperation;
650 }
651
652 void commit(RewriterBase &rewriter) override {
653 // The operation was already moved. Just inform the listener.
654 if (auto *listener = rewriter.getListener()) {
655 // Note: `previousIt` cannot be passed because this is a delayed
656 // notification and iterators into past IR state cannot be represented.
657 listener->notifyOperationInserted(
658 op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
659 /*insertPt=*/{}));
660 }
661 }
662
663 void rollback() override {
664 // Move the operation back to its original position.
665 Block::iterator before =
666 insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
667 block->getOperations().splice(before, op->getBlock()->getOperations(), op);
668 }
669
670private:
671 // The block in which this operation was previously contained.
672 Block *block;
673
674 // The original successor of this operation before it was moved. "nullptr"
675 // if this operation was the only operation in the region.
676 Operation *insertBeforeOp;
677};
678
679/// In-place modification of an op. This rewrite is immediately reflected in
680/// the IR. The previous state of the operation is stored in this object.
681class ModifyOperationRewrite : public OperationRewrite {
682public:
683 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
684 Operation *op)
685 : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
686 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
687 operands(op->operand_begin(), op->operand_end()),
688 successors(op->successor_begin(), op->successor_end()) {
689 if (PropertyRef prop = op->getPropertiesStorage()) {
690 // Make a copy of the properties.
691 propertiesStorage = operator new(op->getPropertiesStorageSize());
692 PropertyRef propCopy(name.getOpPropertiesTypeID(), propertiesStorage);
693 name.initOpProperties(propCopy, /*init=*/prop);
694 }
695 }
696
697 static bool classof(const IRRewrite *rewrite) {
698 return rewrite->getKind() == Kind::ModifyOperation;
699 }
700
701 ~ModifyOperationRewrite() override {
702 assert(!propertiesStorage &&
703 "rewrite was neither committed nor rolled back");
704 }
705
706 void commit(RewriterBase &rewriter) override {
707 // Notify the listener that the operation was modified in-place.
708 if (auto *listener =
709 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
710 listener->notifyOperationModified(op);
711
712 if (propertiesStorage) {
713 PropertyRef propCopy(name.getOpPropertiesTypeID(), propertiesStorage);
714 // Note: The operation may have been erased in the mean time, so
715 // OperationName must be stored in this object.
716 name.destroyOpProperties(propCopy);
717 operator delete(propertiesStorage);
718 propertiesStorage = nullptr;
719 }
720 }
721
722 void rollback() override {
723 op->setLoc(loc);
724 op->setAttrs(attrs);
725 op->setOperands(operands);
726 for (const auto &it : llvm::enumerate(successors))
727 op->setSuccessor(it.value(), it.index());
728 if (propertiesStorage) {
729 PropertyRef propCopy(name.getOpPropertiesTypeID(), propertiesStorage);
730 op->copyProperties(propCopy);
731 name.destroyOpProperties(propCopy);
732 operator delete(propertiesStorage);
733 propertiesStorage = nullptr;
734 }
735 }
736
737private:
738 OperationName name;
739 LocationAttr loc;
740 DictionaryAttr attrs;
741 SmallVector<Value, 8> operands;
742 SmallVector<Block *, 2> successors;
743 void *propertiesStorage = nullptr;
744};
745
746/// Replacing an operation. Erasing an operation is treated as a special case
747/// with "null" replacements. This rewrite is not immediately reflected in the
748/// IR. An internal IR mapping is updated, but values are not replaced and the
749/// original op is not erased until the rewrite is committed.
750class ReplaceOperationRewrite : public OperationRewrite {
751public:
752 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
753 Operation *op, const TypeConverter *converter)
754 : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
755 converter(converter) {}
756
757 static bool classof(const IRRewrite *rewrite) {
758 return rewrite->getKind() == Kind::ReplaceOperation;
759 }
760
761 void commit(RewriterBase &rewriter) override;
762
763 void rollback() override;
764
765 void cleanup(RewriterBase &rewriter) override;
766
767private:
768 /// An optional type converter that can be used to materialize conversions
769 /// between the new and old values if necessary.
770 const TypeConverter *converter;
771};
772
773class CreateOperationRewrite : public OperationRewrite {
774public:
775 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
776 Operation *op)
777 : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
778
779 static bool classof(const IRRewrite *rewrite) {
780 return rewrite->getKind() == Kind::CreateOperation;
781 }
782
783 void commit(RewriterBase &rewriter) override {
784 // The operation was already created and inserted. Just inform the listener.
785 if (auto *listener = rewriter.getListener())
786 listener->notifyOperationInserted(op, /*previous=*/{});
787 }
788
789 void rollback() override;
790};
791
/// Direction of a materialization.
enum MaterializationKind {
  /// Converts a value of an illegal type into a legal one.
  Target = 0,

  /// Converts a value of a legal type back into an illegal one.
  Source = 1
};
802
803/// Helper class that stores metadata about an unresolved materialization.
804class UnresolvedMaterializationInfo {
805public:
806 UnresolvedMaterializationInfo() = default;
807 UnresolvedMaterializationInfo(const TypeConverter *converter,
808 MaterializationKind kind, Type originalType)
809 : converterAndKind(converter, kind), originalType(originalType) {}
810
811 /// Return the type converter of this materialization (which may be null).
812 const TypeConverter *getConverter() const {
813 return converterAndKind.getPointer();
814 }
815
816 /// Return the kind of this materialization.
817 MaterializationKind getMaterializationKind() const {
818 return converterAndKind.getInt();
819 }
820
821 /// Return the original type of the SSA value.
822 Type getOriginalType() const { return originalType; }
823
824private:
825 /// The corresponding type converter to use when resolving this
826 /// materialization, and the kind of this materialization.
827 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
828 converterAndKind;
829
830 /// The original type of the SSA value. Only used for target
831 /// materializations.
832 Type originalType;
833};
834
835/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
836/// op. Unresolved materializations fold away or are replaced with
837/// source/target materializations at the end of the dialect conversion.
838class UnresolvedMaterializationRewrite : public OperationRewrite {
839public:
840 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
841 UnrealizedConversionCastOp op,
842 ValueVector mappedValues)
843 : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
844 mappedValues(std::move(mappedValues)) {}
845
846 static bool classof(const IRRewrite *rewrite) {
847 return rewrite->getKind() == Kind::UnresolvedMaterialization;
848 }
849
850 void rollback() override;
851
852 UnrealizedConversionCastOp getOperation() const {
853 return cast<UnrealizedConversionCastOp>(op);
854 }
855
856private:
857 /// The values in the conversion value mapping that are being replaced by the
858 /// results of this unresolved materialization.
859 ValueVector mappedValues;
860};
861} // namespace
862
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` whose
/// operation is `op`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Operation *op) {
  auto targetsOp = [&](auto &rewrite) -> bool {
    if (auto *typed = dyn_cast<RewriteTy>(rewrite.get()))
      return typed->getOperation() == op;
    return false;
  };
  return any_of(std::forward<R>(rewrites), targetsOp);
}

/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` whose
/// block is `block`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Block *block) {
  auto targetsBlock = [&](auto &rewrite) -> bool {
    if (auto *typed = dyn_cast<RewriteTy>(rewrite.get()))
      return typed->getBlock() == block;
    return false;
  };
  return any_of(std::forward<R>(rewrites), targetsBlock);
}
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
884
885//===----------------------------------------------------------------------===//
886// ConversionPatternRewriterImpl
887//===----------------------------------------------------------------------===//
888namespace mlir {
889namespace detail {
891 explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
892 const ConversionConfig &config,
896
897 //===--------------------------------------------------------------------===//
898 // State Management
899 //===--------------------------------------------------------------------===//
900
901 /// Return the current state of the rewriter.
902 RewriterState getCurrentState();
903
904 /// Apply all requested operation rewrites. This method is invoked when the
905 /// conversion process succeeds.
906 void applyRewrites();
907
  /// Reset the state of the rewriter to a previously saved point. Optionally,
  /// the name of the pattern that triggered the rollback can be specified for
  /// debugging purposes.
911 void resetState(RewriterState state, StringRef patternName = "");
912
913 /// Append a rewrite. Rewrites are committed upon success and rolled back upon
914 /// failure.
915 template <typename RewriteTy, typename... Args>
916 void appendRewrite(Args &&...args) {
917 assert(config.allowPatternRollback && "appending rewrites is not allowed");
918 rewrites.push_back(
919 std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
920 }
921
  /// Undo the rewrites (motions, splits) one by one in reverse order until
  /// "numRewritesToKeep" rewrites remain. Optionally, the name of the pattern
  /// that triggered the rollback can be specified for debugging purposes.
925 void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
926
927 /// Remap the given values to those with potentially different types. Returns
928 /// success if the values could be remapped, failure otherwise. `valueDiagTag`
929 /// is the tag used when describing a value within a diagnostic, e.g.
930 /// "operand".
931 LogicalResult remapValues(StringRef valueDiagTag,
932 std::optional<Location> inputLoc, ValueRange values,
933 SmallVector<ValueVector> &remapped);
934
935 /// Return "true" if the given operation is ignored, and does not need to be
936 /// converted.
937 bool isOpIgnored(Operation *op) const;
938
939 /// Return "true" if the given operation was replaced or erased.
940 bool wasOpReplaced(Operation *op) const;
941
942 /// Lookup the most recently mapped values with the desired types in the
943 /// mapping, taking into account only replacements. Perform a best-effort
944 /// search for existing materializations with the desired types.
945 ///
946 /// If `skipPureTypeConversions` is "true", materializations that are pure
947 /// type conversions are not considered.
948 ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
949 bool skipPureTypeConversions = false) const;
950
951 /// Lookup the given value within the map, or return an empty vector if the
952 /// value is not mapped. If it is mapped, this follows the same behavior
953 /// as `lookupOrDefault`.
954 ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
955
956 //===--------------------------------------------------------------------===//
957 // IR Rewrites / Type Conversion
958 //===--------------------------------------------------------------------===//
959
960 /// Convert the types of block arguments within the given region.
961 FailureOr<Block *>
962 convertRegionTypes(Region *region, const TypeConverter &converter,
963 TypeConverter::SignatureConversion *entryConversion);
964
965 /// Apply the given signature conversion on the given block. The new block
966 /// containing the updated signature is returned. If no conversions were
967 /// necessary, e.g. if the block has no arguments, `block` is returned.
968 /// `converter` is used to generate any necessary cast operations that
969 /// translate between the origin argument types and those specified in the
970 /// signature conversion.
971 Block *applySignatureConversion(
972 Block *block, const TypeConverter *converter,
973 TypeConverter::SignatureConversion &signatureConversion);
974
975 /// Replace the results of the given operation with the given values and
976 /// erase the operation.
977 ///
978 /// There can be multiple replacement values for each result (1:N
979 /// replacement). If the replacement values are empty, the respective result
980 /// is dropped and a source materialization is built if the result still has
981 /// uses.
982 void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
983
984 /// Replace the uses of the given value with the given values. The specified
985 /// converter is used to build materializations (if necessary). If `functor`
986 /// is specified, only the uses that the functor returns "true" for are
987 /// replaced.
988 void replaceValueUses(Value from, ValueRange to,
989 const TypeConverter *converter,
990 function_ref<bool(OpOperand &)> functor = nullptr);
991
992 /// Erase the given block and its contents.
993 void eraseBlock(Block *block);
994
995 /// Inline the source block into the destination block before the given
996 /// iterator.
997 void inlineBlockBefore(Block *source, Block *dest, Block::iterator before);
998
999 //===--------------------------------------------------------------------===//
1000 // Materializations
1001 //===--------------------------------------------------------------------===//
1002
  /// Build an unresolved materialization operation (an
  /// `unrealized_conversion_cast`) given a range of output types and a list of
  /// input operands, and return its results. The input types must differ from
  /// the output types: callers must not request a materialization that is not
  /// necessary.
  ///
  /// If `valuesToMap` is non-empty, those values are mapped to the results of
  /// the unresolved materialization in the conversion value mapping (this is
  /// done only when `allowPatternRollback` is enabled).
1013 ///
1014 /// If `isPureTypeConversion` is "true", the materialization is created only
1015 /// to resolve a type mismatch. That means it is not a regular value
1016 /// replacement issued by the user. (Replacement values that are created
1017 /// "out of thin air" appear like unresolved materializations because they are
1018 /// unrealized_conversion_cast ops. However, they must be treated like
1019 /// regular value replacements.)
1020 ValueRange buildUnresolvedMaterialization(
1021 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1022 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1023 Type originalType, const TypeConverter *converter,
1024 bool isPureTypeConversion = true);
1025
1026 /// Find a replacement value for the given SSA value in the conversion value
1027 /// mapping. The replacement value must have the same type as the given SSA
1028 /// value. If there is no replacement value with the correct type, find the
1029 /// latest replacement value (regardless of the type) and build a source
1030 /// materialization.
1031 Value findOrBuildReplacementValue(Value value,
1032 const TypeConverter *converter);
1033
1034 //===--------------------------------------------------------------------===//
1035 // Rewriter Notification Hooks
1036 //===--------------------------------------------------------------------===//
1037
  /// Notifies that an op was inserted.
1039 void notifyOperationInserted(Operation *op,
1040 OpBuilder::InsertPoint previous) override;
1041
1042 /// Notifies that a block was inserted.
1043 void notifyBlockInserted(Block *block, Region *previous,
1044 Region::iterator previousIt) override;
1045
1046 /// Notifies that a pattern match failed for the given reason.
1047 void
1048 notifyMatchFailure(Location loc,
1049 function_ref<void(Diagnostic &)> reasonCallback) override;
1050
1051 //===--------------------------------------------------------------------===//
1052 // IR Erasure
1053 //===--------------------------------------------------------------------===//
1054
1055 /// A rewriter that keeps track of erased ops and blocks. It ensures that no
1056 /// operation or block is erased multiple times. This rewriter assumes that
1057 /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
1059 public:
1062 std::function<void(Operation *)> opErasedCallback = nullptr)
1063 : RewriterBase(context, /*listener=*/this),
1064 opErasedCallback(std::move(opErasedCallback)) {}
1065
1066 /// Erase the given op (unless it was already erased).
1067 void eraseOp(Operation *op) override {
1068 if (wasErased(op))
1069 return;
1070 op->dropAllUses();
1072 }
1073
1074 /// Erase the given block (unless it was already erased).
1075 void eraseBlock(Block *block) override {
1076 if (wasErased(block))
1077 return;
1078 assert(block->empty() && "expected empty block");
1079 block->dropAllDefinedValueUses();
1081 }
1082
1083 bool wasErased(void *ptr) const { return erased.contains(ptr); }
1084
1086 erased.insert(op);
1087 if (opErasedCallback)
1088 opErasedCallback(op);
1089 }
1090
1091 void notifyBlockErased(Block *block) override { erased.insert(block); }
1092
1093 private:
1094 /// Pointers to all erased operations and blocks.
1095 DenseSet<void *> erased;
1096
1097 /// A callback that is invoked when an operation is erased.
1098 std::function<void(Operation *)> opErasedCallback;
1099 };
1100
1101 //===--------------------------------------------------------------------===//
1102 // State
1103 //===--------------------------------------------------------------------===//
1104
1105 /// The rewriter that is used to perform the conversion.
1106 ConversionPatternRewriter &rewriter;
1107
1108 // Mapping between replaced values that differ in type. This happens when
1109 // replacing a value with one of a different type.
1110 ConversionValueMapping mapping;
1111
1112 /// Ordered list of block operations (creations, splits, motions).
1113 /// This vector is maintained only if `allowPatternRollback` is set to
1114 /// "true". Otherwise, all IR rewrites are materialized immediately and no
1115 /// bookkeeping is needed.
1117
1118 /// A set of operations that should no longer be considered for legalization.
1119 /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1120 /// tracked separately.
1122
  /// A set of operations that were replaced/erased. Such ops are not erased
  /// immediately but only when the dialect conversion succeeds. In the
  /// meantime, they should no longer be considered for legalization and any
  /// attempt to modify/access them is invalid rewriter API usage.
1128
1129 /// A set of operations that were created by the current pattern.
1131
1132 /// A set of operations that were modified by the current pattern.
1134
1135 /// A list of unresolved materializations that were created by the current
1136 /// pattern.
1138
1139 /// A mapping for looking up metadata of unresolved materializations.
1140 llvm::MapVector<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
1142
1143 /// The current type converter, or nullptr if no type converter is currently
1144 /// active.
1146
1147 /// A mapping of regions to type converters that should be used when
1148 /// converting the arguments of blocks within that region.
1150
1151 /// Dialect conversion configuration.
1152 const ConversionConfig &config;
1153
1154 /// The operation converter to use for recursive legalization.
1156
1157 /// A set of erased operations. This set is utilized only if
1158 /// `allowPatternRollback` is set to "false". Conceptually, this set is
1159 /// similar to `replacedOps` (which is maintained when the flag is set to
1160 /// "true"). However, erasing from a DenseSet is more efficient than erasing
1161 /// from a SetVector.
1163
1164 /// A set of erased blocks. This set is utilized only if
1165 /// `allowPatternRollback` is set to "false".
1167
1168 /// A rewriter that notifies the listener (if any) about all IR
1169 /// modifications. This rewriter is utilized only if `allowPatternRollback`
1170 /// is set to "false". If the flag is set to "true", the listener is notified
1171 /// with a separate mechanism (e.g., in `IRRewrite::commit`).
1173
1174#ifndef NDEBUG
1175 /// A set of replaced values. This set is for debugging purposes only and it
1176 /// is maintained only if `allowPatternRollback` is set to "true".
1178
1179 /// A set of operations that have pending updates. This tracking isn't
1180 /// strictly necessary, and is thus only active during debug builds for extra
1181 /// verification.
1183
1184 /// A raw output stream used to prefix the debug log.
1185 llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
1186 llvm::dbgs()};
1187
1188 /// A logger used to emit diagnostics during the conversion process.
1189 llvm::ScopedPrinter logger{os};
1190 std::string logPrefix;
1191#endif
1192};
1193} // namespace detail
1194} // namespace mlir
1195
1196const ConversionConfig &IRRewrite::getConfig() const {
1197 return rewriterImpl.config;
1198}
1199
1200void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1201 // Inform the listener about all IR modifications that have already taken
1202 // place: References to the original block have been replaced with the new
1203 // block.
1204 if (auto *listener =
1205 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1206 for (Operation *op : getNewBlock()->getUsers())
1207 listener->notifyOperationModified(op);
1208}
1209
1210void BlockTypeConversionRewrite::rollback() {
1211 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1212}
1213
1214/// Replace all uses of `from` with `repl`.
1215static void
1217 function_ref<bool(OpOperand &)> functor = nullptr) {
1218 if (isa<BlockArgument>(repl)) {
1219 // `repl` is a block argument. Directly replace all uses.
1220 if (functor) {
1221 rewriter.replaceUsesWithIf(from, repl, functor);
1222 } else {
1223 rewriter.replaceAllUsesWith(from, repl);
1224 }
1225 return;
1226 }
1227
1228 // If the replacement value is an operation, only replace those uses that:
1229 // - are in a different block than the replacement operation, or
1230 // - are in the same block but after the replacement operation.
1231 //
1232 // Example:
1233 // ^bb0(%arg0: i32):
1234 // %0 = "consumer"(%arg0) : (i32) -> (i32)
1235 // "another_consumer"(%arg0) : (i32) -> ()
1236 //
1237 // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
1238 // use in "another_consumer" but not the use in "consumer". When using the
1239 // normal RewriterBase API, this would typically be done with
1240 // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
1241 // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
1242 // it cannot be supported efficiently with `allowPatternRollback` set to
1243 // "true". Therefore, the conversion driver is trying to be smart and replaces
1244 // only those uses that do not lead to a dominance violation. E.g., the
1245 // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
1246 // behavior.
1247 //
1248 // TODO: As we move more and more towards `allowPatternRollback` set to
1249 // "false", we should remove this special handling, in order to align the
1250 // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
1251 Operation *replOp = repl.getDefiningOp();
1252 Block *replBlock = replOp->getBlock();
1253 rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
1254 Operation *user = operand.getOwner();
1255 bool result =
1256 user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1257 if (result && functor)
1258 result &= functor(operand);
1259 return result;
1260 });
1261}
1262
1263void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1264 Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
1265 if (!repl)
1266 return;
1267 performReplaceValue(rewriter, value, repl);
1268}
1269
1270void ReplaceValueRewrite::rollback() {
1271 rewriterImpl.mapping.erase({value});
1272#ifndef NDEBUG
1273 rewriterImpl.replacedValues.erase(value);
1274#endif // NDEBUG
1275}
1276
1277void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1278 auto *listener =
1279 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1280
1281 // Compute replacement values.
1282 SmallVector<Value> replacements =
1283 llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1284 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1285 });
1286
1287 // Notify the listener that the operation is about to be replaced.
1288 if (listener)
1289 listener->notifyOperationReplaced(op, replacements);
1290
1291 // Replace all uses with the new values.
1292 for (auto [result, newValue] :
1293 llvm::zip_equal(op->getResults(), replacements))
1294 if (newValue)
1295 rewriter.replaceAllUsesWith(result, newValue);
1296
1297 // The original op will be erased, so remove it from the set of unlegalized
1298 // ops.
1299 if (getConfig().unlegalizedOps)
1300 getConfig().unlegalizedOps->erase(op);
1301
1302 // Notify the listener that the operation and its contents are being erased.
1303 if (listener)
1304 notifyIRErased(listener, *op);
1305
1306 // Do not erase the operation yet. It may still be referenced in `mapping`.
1307 // Just unlink it for now and erase it during cleanup.
1308 if (!op->getBlock())
1309 llvm::reportFatalInternalError(
1310 "dialect conversion attempted to replace a root operation that has no "
1311 "parent block; the pass must ensure its target op is nested in a "
1312 "block");
1313 op->getBlock()->getOperations().remove(op);
1314}
1315
1316void ReplaceOperationRewrite::rollback() {
1317 for (auto result : op->getResults())
1318 rewriterImpl.mapping.erase({result});
1319}
1320
1321void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1322 rewriter.eraseOp(op);
1323}
1324
1325void CreateOperationRewrite::rollback() {
1326 for (Region &region : op->getRegions()) {
1327 while (!region.getBlocks().empty())
1328 region.getBlocks().remove(region.getBlocks().begin());
1329 }
1330 op->dropAllUses();
1331 op->erase();
1332}
1333
1334void UnresolvedMaterializationRewrite::rollback() {
1335 if (!mappedValues.empty())
1336 rewriterImpl.mapping.erase(mappedValues);
1337 rewriterImpl.unresolvedMaterializations.erase(getOperation());
1338 op->erase();
1339}
1340
1342 // Commit all rewrites. Use a new rewriter, so the modifications are not
1343 // tracked for rollback purposes etc.
1344 IRRewriter irRewriter(rewriter.getContext(), config.listener);
1345 // Note: New rewrites may be added during the "commit" phase and the
1346 // `rewrites` vector may reallocate.
1347 for (size_t i = 0; i < rewrites.size(); ++i)
1348 rewrites[i]->commit(irRewriter);
1349
1350 // Clean up all rewrites.
1351 SingleEraseRewriter eraseRewriter(
1352 rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
1353 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1354 unresolvedMaterializations.erase(castOp);
1355 });
1356 for (auto &rewrite : rewrites)
1357 rewrite->cleanup(eraseRewriter);
1358}
1359
1360//===----------------------------------------------------------------------===//
1361// State Management
1362//===----------------------------------------------------------------------===//
1363
1365 Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
1366 // Helper function that looks up a single value.
1367 auto lookup = [&](const ValueVector &values) -> ValueVector {
1368 assert(!values.empty() && "expected non-empty value vector");
1369
1370 // If the pattern rollback is enabled, use the mapping to look up the
1371 // values.
1372 if (config.allowPatternRollback)
1373 return mapping.lookup(values);
1374
1375 // Otherwise, look up values by examining the IR. All replacements have
1376 // already been materialized in IR.
1377 Operation *op = getCommonDefiningOp(values);
1378 if (!op)
1379 return {};
1380 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1381 if (!castOp)
1382 return {};
1383 if (!this->unresolvedMaterializations.contains(castOp))
1384 return {};
1385 if (castOp.getOutputs() != values)
1386 return {};
1387 return castOp.getInputs();
1388 };
1389
1390 // Helper function that looks up each value in `values` individually and then
1391 // composes the results. If that fails, it tries to look up the entire vector
1392 // at once.
1393 auto composedLookup = [&](const ValueVector &values) -> ValueVector {
1394 // If possible, replace each value with (one or multiple) mapped values.
1395 ValueVector next;
1396 for (Value v : values) {
1397 ValueVector r = lookup({v});
1398 if (!r.empty()) {
1399 llvm::append_range(next, r);
1400 } else {
1401 next.push_back(v);
1402 }
1403 }
1404 if (next != values) {
1405 // At least one value was replaced.
1406 return next;
1407 }
1408
1409 // Otherwise: Check if there is a mapping for the entire vector. Such
1410 // mappings are materializations. (N:M mapping are not supported for value
1411 // replacements.)
1412 //
1413 // Note: From a correctness point of view, materializations do not have to
1414 // be stored (and looked up) in the mapping. But for performance reasons,
1415 // we choose to reuse existing IR (when possible) instead of creating it
1416 // multiple times.
1417 ValueVector r = lookup(values);
1418 if (r.empty()) {
1419 // No mapping found: The lookup stops here.
1420 return {};
1421 }
1422 return r;
1423 };
1424
1425 // Try to find the deepest values that have the desired types. If there is no
1426 // such mapping, simply return the deepest values.
1427 ValueVector desiredValue;
1428 ValueVector current{from};
1429 ValueVector lastNonMaterialization{from};
1430 do {
1431 // Store the current value if the types match.
1432 bool match = TypeRange(ValueRange(current)) == desiredTypes;
1433 if (skipPureTypeConversions) {
1434 // Skip pure type conversions, if requested.
1435 bool pureConversion = isPureTypeConversion(current);
1436 match &= !pureConversion;
1437 // Keep track of the last mapped value that was not a pure type
1438 // conversion.
1439 if (!pureConversion)
1440 lastNonMaterialization = current;
1441 }
1442 if (match)
1443 desiredValue = current;
1444
1445 // Lookup next value in the mapping.
1446 ValueVector next = composedLookup(current);
1447 if (next.empty())
1448 break;
1449 current = std::move(next);
1450 } while (true);
1451
1452 // If the desired values were found use them, otherwise default to the leaf
1453 // values. (Skip pure type conversions, if requested.)
1454 if (!desiredTypes.empty())
1455 return desiredValue;
1456 if (skipPureTypeConversions)
1457 return lastNonMaterialization;
1458 return current;
1459}
1460
1463 TypeRange desiredTypes) const {
1464 ValueVector result = lookupOrDefault(from, desiredTypes);
1465 if (result == ValueVector{from} ||
1466 (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
1467 return {};
1468 return result;
1469}
1470
1472 return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1473}
1474
1476 StringRef patternName) {
1477 // Undo any rewrites.
1478 undoRewrites(state.numRewrites, patternName);
1479
1480 // Pop all of the recorded ignored operations that are no longer valid.
1481 while (ignoredOps.size() != state.numIgnoredOperations)
1482 ignoredOps.pop_back();
1483
1484 while (replacedOps.size() != state.numReplacedOps)
1485 replacedOps.pop_back();
1486}
1487
1488void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1489 StringRef patternName) {
1490 for (auto &rewrite :
1491 llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1492 rewrite->rollback();
1493 rewrites.resize(numRewritesToKeep);
1494}
1495
1497 StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
1498 SmallVector<ValueVector> &remapped) {
1499 remapped.reserve(llvm::size(values));
1500
1501 for (const auto &it : llvm::enumerate(values)) {
1502 Value operand = it.value();
1503 Type origType = operand.getType();
1504 Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1505
1506 if (!currentTypeConverter) {
1507 // The current pattern does not have a type converter. Pass the most
1508 // recently mapped values, excluding materializations. Materializations
1509 // are intentionally excluded because their presence may depend on other
1510 // patterns. Including materializations would make the lookup fragile
1511 // and unpredictable.
1512 remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
1513 /*skipPureTypeConversions=*/true));
1514 continue;
1515 }
1516
1517 // If there is no legal conversion, fail to match this pattern.
1518 SmallVector<Type, 1> legalTypes;
1519 if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
1520 notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1521 diag << "unable to convert type for " << valueDiagTag << " #"
1522 << it.index() << ", type was " << origType;
1523 });
1524 return failure();
1525 }
1526 // If a type is converted to 0 types, there is nothing to do.
1527 if (legalTypes.empty()) {
1528 remapped.push_back({});
1529 continue;
1530 }
1531
1532 ValueVector repl = lookupOrDefault(operand, legalTypes);
1533 if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1534 // Mapped values have the correct type or there is an existing
1535 // materialization. Or the operand is not mapped at all and has the
1536 // correct type.
1537 remapped.push_back(std::move(repl));
1538 continue;
1539 }
1540
1541 // Create a materialization for the most recently mapped values.
1542 repl = lookupOrDefault(operand, /*desiredTypes=*/{},
1543 /*skipPureTypeConversions=*/true);
1545 MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1546 /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1547 /*originalType=*/origType, currentTypeConverter);
1548 remapped.push_back(castValues);
1549 }
1550 return success();
1551}
1552
1554 // Check to see if this operation is ignored or was replaced.
1555 return wasOpReplaced(op) || ignoredOps.count(op);
1556}
1557
1559 // Check to see if this operation was replaced.
1560 return replacedOps.count(op) || erasedOps.count(op);
1561}
1562
1563//===----------------------------------------------------------------------===//
1564// Type Conversion
1565//===----------------------------------------------------------------------===//
1566
1568 Region *region, const TypeConverter &converter,
1569 TypeConverter::SignatureConversion *entryConversion) {
1570 regionToConverter[region] = &converter;
1571 if (region->empty())
1572 return nullptr;
1573
1574 // Convert the arguments of each non-entry block within the region.
1575 for (Block &block :
1576 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1577 // Compute the signature for the block with the provided converter.
1578 std::optional<TypeConverter::SignatureConversion> conversion =
1579 converter.convertBlockSignature(&block);
1580 if (!conversion)
1581 return failure();
1582 // Convert the block with the computed signature.
1583 applySignatureConversion(&block, &converter, *conversion);
1584 }
1585
1586 // Convert the entry block. If an entry signature conversion was provided,
1587 // use that one. Otherwise, compute the signature with the type converter.
1588 if (entryConversion)
1589 return applySignatureConversion(&region->front(), &converter,
1590 *entryConversion);
1591 std::optional<TypeConverter::SignatureConversion> conversion =
1592 converter.convertBlockSignature(&region->front());
1593 if (!conversion)
1594 return failure();
1595 return applySignatureConversion(&region->front(), &converter, *conversion);
1596}
1597
1599 Block *block, const TypeConverter *converter,
1600 TypeConverter::SignatureConversion &signatureConversion) {
1601#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1602 // A block cannot be converted multiple times.
1603 if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1604 llvm::reportFatalInternalError("block was already converted");
1605#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1606
1608
1609 // If no arguments are being changed or added, there is nothing to do.
1610 unsigned origArgCount = block->getNumArguments();
1611 auto convertedTypes = signatureConversion.getConvertedTypes();
1612 if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1613 return block;
1614
1615 // Compute the locations of all block arguments in the new block.
1616 SmallVector<Location> newLocs(convertedTypes.size(),
1617 rewriter.getUnknownLoc());
1618 for (unsigned i = 0; i < origArgCount; ++i) {
1619 auto inputMap = signatureConversion.getInputMapping(i);
1620 if (!inputMap || inputMap->replacedWithValues())
1621 continue;
1622 Location origLoc = block->getArgument(i).getLoc();
1623 for (unsigned j = 0; j < inputMap->size; ++j)
1624 newLocs[inputMap->inputNo + j] = origLoc;
1625 }
1626
1627 // Insert a new block with the converted block argument types and move all ops
1628 // from the old block to the new block.
1629 Block *newBlock =
1630 rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1631 convertedTypes, newLocs);
1632
1633 // If a listener is attached to the dialect conversion, ops cannot be moved
1634 // to the destination block in bulk ("fast path"). This is because at the time
1635 // the notifications are sent, it is unknown which ops were moved. Instead,
1636 // ops should be moved one-by-one ("slow path"), so that a separate
1637 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1638 // a bit more efficient, so we try to do that when possible.
1639 bool fastPath = !config.listener;
1640 if (fastPath) {
1641 if (config.allowPatternRollback)
1642 appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1643 newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1644 } else {
1645 while (!block->empty())
1646 rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1647 }
1648
1649 // Replace all uses of the old block with the new block.
1650 block->replaceAllUsesWith(newBlock);
1651
1652 for (unsigned i = 0; i != origArgCount; ++i) {
1653 BlockArgument origArg = block->getArgument(i);
1654 Type origArgType = origArg.getType();
1655
1656 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1657 signatureConversion.getInputMapping(i);
1658 if (!inputMap) {
1659 // This block argument was dropped and no replacement value was provided.
1660 // Materialize a replacement value "out of thin air".
1661 // Note: Materialization must be built here because we cannot find a
1662 // valid insertion point in the new block. (Will point to the old block.)
1663 Value mat =
1665 MaterializationKind::Source,
1666 OpBuilder::InsertPoint(newBlock, newBlock->begin()),
1667 origArg.getLoc(),
1668 /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1669 /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
1670 /*isPureTypeConversion=*/false)
1671 .front();
1672 replaceValueUses(origArg, mat, converter);
1673 continue;
1674 }
1675
1676 if (inputMap->replacedWithValues()) {
1677 // This block argument was dropped and replacement values were provided.
1678 assert(inputMap->size == 0 &&
1679 "invalid to provide a replacement value when the argument isn't "
1680 "dropped");
1681 replaceValueUses(origArg, inputMap->replacementValues, converter);
1682 continue;
1683 }
1684
1685 // This is a 1->1+ mapping.
1686 auto replArgs =
1687 newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1688 replaceValueUses(origArg, replArgs, converter);
1689 }
1690
1691 if (config.allowPatternRollback)
1692 appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1693
1694 // Erase the old block. (It is just unlinked for now and will be erased during
1695 // cleanup.)
1696 rewriter.eraseBlock(block);
1697
1698 return newBlock;
1699}
1700
1701//===----------------------------------------------------------------------===//
1702// Materializations
1703//===----------------------------------------------------------------------===//
1704
1705/// Build an unresolved materialization operation given an output type and set
1706/// of input operands.
///
/// Creates an `UnrealizedConversionCastOp` from `inputs` to `outputTypes` at
/// `ip`. A fresh OpBuilder is used so that the cast is not tracked like other
/// operations. The cast is recorded in `unresolvedMaterializations` together
/// with `converter`, `kind` and `originalType`, and its results are returned.
/// In rollback mode the results are additionally mapped from `valuesToMap`
/// (when non-empty); in "no rollback" mode the cast is added to
/// `patternMaterializations` instead.
///
/// Preconditions (asserted): `originalType` may only be set for target
/// materializations, and `inputs` must not already have `outputTypes`.
/// `outputTypes` must be non-empty (`front()` is called on it).
///
// NOTE(review): the line carrying this definition's return type and qualified
// name is missing from this listing; the parameter list below starts
// mid-signature. Confirm against the upstream source before editing.
1708 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1709 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1710 Type originalType, const TypeConverter *converter,
1711 bool isPureTypeConversion) {
1712 assert((!originalType || kind == MaterializationKind::Target) &&
1713 "original type is valid only for target materializations");
1714 assert(TypeRange(inputs) != outputTypes &&
1715 "materialization is not necessary");
1716
1717 // Create an unresolved materialization. We use a new OpBuilder to avoid
1718 // tracking the materialization like we do for other operations.
1719 OpBuilder builder(outputTypes.front().getContext());
1720 builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1721 UnrealizedConversionCastOp convertOp =
1722 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
// Debugging aid: tag the cast with "source"/"target" when requested.
1723 if (config.attachDebugMaterializationKind) {
1724 StringRef kindStr =
1725 kind == MaterializationKind::Source ? "source" : "target";
1726 convertOp->setAttr("__kind__", builder.getStringAttr(kindStr));
1727 }
// NOTE(review): `isPureTypeConversion` is not read anywhere in this listing,
// so as shown the marker below is attached to every materialization. A line
// guarding it appears to have been dropped here; confirm upstream.
1729 convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
1730
1731 // Register the materialization.
1732 unresolvedMaterializations[convertOp] =
1733 UnresolvedMaterializationInfo(converter, kind, originalType);
1734 if (config.allowPatternRollback) {
1735 if (!valuesToMap.empty())
1736 mapping.map(valuesToMap, convertOp.getResults());
// NOTE(review): the call that owns the argument list ending on the next line
// (it consumes `valuesToMap` via std::move) is missing from this listing.
1738 std::move(valuesToMap));
1739 } else {
1740 patternMaterializations.insert(convertOp);
1741 }
1742 return convertOp.getResults();
1743}
1744
// Returns a replacement for `value` that has `value`'s own type. This is only
// valid in rollback mode (asserted).
//  - If the conversion mapping already holds a same-typed replacement, its
//    first value is returned (this reuses cached materializations).
//  - If every user of `value` is already in `replacedOps` and `value` is not
//    itself the target of a mapping, a null Value is returned (the check is
//    approximate, see below).
//  - Otherwise a source materialization is built from the latest replacement
//    values (or from no inputs if there are none), and its first result is
//    returned.
//
// NOTE(review): the signature line (return type and qualified name) of this
// definition is missing from this listing. Confirm against upstream before
// editing.
1746 Value value, const TypeConverter *converter) {
1747 assert(config.allowPatternRollback &&
1748 "this code path is valid only in rollback mode");
1749
1750 // Try to find a replacement value with the same type in the conversion value
1751 // mapping. This includes cached materializations. We try to reuse those
1752 // instead of generating duplicate IR.
1753 ValueVector repl = lookupOrNull(value, value.getType());
1754 if (!repl.empty())
1755 return repl.front();
1756
1757 // Check if the value is dead. No replacement value is needed in that case.
1758 // This is an approximate check that may have false negatives but does not
1759 // require computing and traversing an inverse mapping. (We may end up
1760 // building source materializations that are never used and that fold away.)
1761 if (llvm::all_of(value.getUsers(),
1762 [&](Operation *op) { return replacedOps.contains(op); }) &&
1763 !mapping.isMappedTo(value))
1764 return Value();
1765
1766 // No replacement value was found. Get the latest replacement value
1767 // (regardless of the type) and build a source materialization to the
1768 // original type.
1769 repl = lookupOrNull(value);
1770
1771 // Compute the insertion point of the materialization.
// NOTE(review): the declaration of `ip` (assigned in both branches below) is
// missing from this listing.
1773 if (repl.empty()) {
1774 // The source materialization has no inputs. Insert it right before the
1775 // value that it is replacing.
1776 ip = computeInsertPoint(value);
1777 } else {
1778 // Compute the "earliest" insertion point at which all values in `repl` are
1779 // defined. It is important to emit the materialization at that location
1780 // because the same materialization may be reused in a different context.
1781 // (That's because materializations are cached in the conversion value
1782 // mapping.) The insertion point of the materialization must be valid for
1783 // all future users that may be created later in the conversion process.
1784 ip = computeInsertPoint(repl);
1785 }
// NOTE(review): the head of the call below, which also declares `castValue`,
// is missing from this listing. Its arguments map `repl` to the new cast and
// mark it as a pure type conversion only when it has inputs.
1787 MaterializationKind::Source, ip, value.getLoc(),
1788 /*valuesToMap=*/repl, /*inputs=*/repl,
1789 /*outputTypes=*/value.getType(),
1790 /*originalType=*/Type(), converter,
1791 /*isPureTypeConversion=*/!repl.empty())
1792 .front();
1793 return castValue;
1794}
1795
1796//===----------------------------------------------------------------------===//
1797// Rewriter Notification Hooks
1798//===----------------------------------------------------------------------===//
1799
// Hook run when `op` is inserted into a block; `previous` is the op's former
// insertion point, or unset if the op was detached. Presumably this is the
// listener `notifyOperationInserted` hook (the signature line is missing).
//  - In "no rollback" mode the user listener (if any) is notified immediately.
//  - A previously detached op is added to `patternNewOps` so that it gets
//    legalized. In "no rollback" mode it is also removed from `erasedOps`.
//  - For an op that was moved (not detached), only rollback mode does work.
//
// NOTE(review): this listing is missing:
//  - the signature line;
//  - the statement inside the rollback branch of the detached case;
//  - the body of the final `if (config.allowPatternRollback)`. As shown, that
//    `if` is immediately followed by the closing brace and would not compile.
// Presumably the missing statements record rewrites for rollback; confirm
// upstream.
1801 Operation *op, OpBuilder::InsertPoint previous) {
1802 // If no previous insertion point is provided, the op used to be detached.
1803 bool wasDetached = !previous.isSet();
1804 LLVM_DEBUG({
1805 logger.startLine() << "** Insert : '" << op->getName() << "' (" << op
1806 << ")";
1807 if (wasDetached)
1808 logger.getOStream() << " (was detached)";
1809 logger.getOStream() << "\n";
1810 });
1811
1812 // In rollback mode, it is easier to misuse the API, so perform extra error
1813 // checking.
1814 assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
1815 "attempting to insert into a block within a replaced/erased op");
1816
1817 // In "no rollback" mode, the listener is always notified immediately.
1818 if (!config.allowPatternRollback && config.listener)
1819 config.listener->notifyOperationInserted(op, previous);
1820
1821 if (wasDetached) {
1822 // If the op was detached, it is most likely a newly created op. Add it to the
1823 // set of newly created ops, so that it will be legalized. If this op is
1824 // not a newly created op, it will be legalized a second time, which is
1825 // inefficient but harmless.
1826 patternNewOps.insert(op);
1827
1828 if (config.allowPatternRollback) {
1829 // TODO: If the same op is inserted multiple times from a detached
1830 // state, the rollback mechanism may erase the same op multiple times.
1831 // This is a bug in the rollback-based dialect conversion driver.
1833 } else {
1834 // In "no rollback" mode, there is an extra data structure for tracking
1835 // erased operations that must be kept up to date.
1836 erasedOps.erase(op);
1837 }
1838 return;
1839 }
1840
1841 // The op was moved from one place to another.
1842 if (config.allowPatternRollback)
1844}
1845
1846/// Given that `fromRange` is about to be replaced with `toRange`, compute
1847/// replacement values with the types of `fromRange`.
///
/// Returns one entry per element of `fromRange`:
///  - a null Value if the replaced value has no uses;
///  - a source materialization with no inputs if its replacement list is
///    empty (the value is dropped);
///  - the replacement itself if it is a single value of the same type;
///  - otherwise a source materialization from the replacement values back to
///    the original type.
/// This is only valid in "no rollback" mode (asserted). `fromRange` and
/// `toRange` must have the same length (iterated with `llvm::zip_equal`).
///
// NOTE(review): the line declaring this function's name and its leading
// parameters is missing from this listing. `impl` and `fromRange` are used
// below but not declared. Confirm against upstream before editing.
1848static SmallVector<Value>
1850 const SmallVector<SmallVector<Value>> &toRange,
1851 const TypeConverter *converter) {
1852 assert(!impl.config.allowPatternRollback &&
1853 "this code path is valid only in 'no rollback' mode");
1854 SmallVector<Value> repls;
1855 for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1856 if (from.use_empty()) {
1857 // The replaced value is dead. No replacement value is needed.
1858 repls.push_back(Value());
1859 continue;
1860 }
1861
1862 if (to.empty()) {
1863 // The replaced value is dropped. Materialize a replacement value "out of
1864 // thin air".
1865 Value srcMat = impl.buildUnresolvedMaterialization(
1866 MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
1867 /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1868 /*outputTypes=*/from.getType(), /*originalType=*/Type(),
1869 converter)[0];
1870 repls.push_back(srcMat);
1871 continue;
1872 }
1873
1874 if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
1875 // The replacement value already has the correct type. Use it directly.
1876 repls.push_back(to[0]);
1877 continue;
1878 }
1879
1880 // The replacement value has the wrong type. Build a source materialization
1881 // to the original type.
1882 // TODO: This is a bit inefficient. We should try to reuse existing
1883 // materializations if possible. This would require an extension of the
1884 // `lookupOrDefault` API.
1885 Value srcMat = impl.buildUnresolvedMaterialization(
1886 MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
1887 /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
1888 /*originalType=*/Type(), converter)[0];
1889 repls.push_back(srcMat);
1890 }
1891
1892 return repls;
1893}
1894
// Replaces `op`: result i is replaced by the (possibly empty or multi-value)
// list `newValues[i]`.
//  - In debug builds, logs a note when a replacement's types differ from what
//    `currentTypeConverter` produces for the result type.
//  - In "no rollback" mode, computes same-typed replacement values, removes
//    `op` and every nested op and block from the driver's tracking structures
//    (so none of them keep dangling pointers), and replaces the op through
//    `notifyingRewriter`.
//  - In rollback mode, maps every result to its replacement values and marks
//    `op` and all nested ops as replaced. The lines shown do not erase any IR.
//
// NOTE(review): this listing is missing:
//  - the signature line;
//  - a line inside the LLVM_DEBUG block. The braces there are unbalanced by
//    one; presumably a guard such as a null check on `currentTypeConverter`,
//    which is dereferenced below;
//  - the statement that declares `repls` in the "no rollback" branch, whose
//    trailing arguments appear on its own line below;
//  - one statement just before the final walk.
// Confirm against upstream before editing.
1896 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1897 assert(newValues.size() == op->getNumResults() &&
1898 "incorrect number of replacement values");
1899 LLVM_DEBUG({
1900 logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
1901 << ")\n";
1903 // If the user-provided replacement types are different from the
1904 // legalized types, as per the current type converter, print a note.
1905 // In most cases, the replacement types are expected to match the types
1906 // produced by the type converter, so this could indicate a bug in the
1907 // user code.
1908 for (auto [result, repls] :
1909 llvm::zip_equal(op->getResults(), newValues)) {
1910 Type resultType = result.getType();
1911 auto logProlog = [&, repls = repls]() {
1912 logger.startLine() << " Note: Replacing op result of type "
1913 << resultType << " with value(s) of type (";
1914 llvm::interleaveComma(repls, logger.getOStream(), [&](Value v) {
1915 logger.getOStream() << v.getType();
1916 });
1917 logger.getOStream() << ")";
1918 };
1919 SmallVector<Type> convertedTypes;
1920 if (failed(currentTypeConverter->convertTypes(resultType,
1921 convertedTypes))) {
1922 logProlog();
1923 logger.getOStream() << ", but the type converter failed to legalize "
1924 "the original type.\n";
1925 continue;
1926 }
1927 if (TypeRange(convertedTypes) != TypeRange(ValueRange(repls))) {
1928 logProlog();
1929 logger.getOStream() << ", but the legalized type(s) is/are (";
1930 llvm::interleaveComma(convertedTypes, logger.getOStream(),
1931 [&](Type t) { logger.getOStream() << t; });
1932 logger.getOStream() << ")\n";
1933 }
1934 }
1935 }
1936 });
1937
1938 if (!config.allowPatternRollback) {
1939 // Pattern rollback is not allowed: materialize all IR changes immediately.
1941 *this, op->getResults(), newValues, currentTypeConverter);
1942 // Update internal data structures, so that there are no dangling pointers
1943 // to erased IR.
1944 op->walk([&](Operation *op) {
1945 erasedOps.insert(op);
1946 ignoredOps.remove(op);
1947 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1948 unresolvedMaterializations.erase(castOp);
1949 patternMaterializations.erase(castOp);
1950 }
1951 // The original op will be erased, so remove it from the set of
1952 // unlegalized ops.
1953 if (config.unlegalizedOps)
1954 config.unlegalizedOps->erase(op);
1955 });
1956 op->walk([&](Block *block) { erasedBlocks.insert(block); });
1957 // Replace the op with the replacement values and notify the listener.
1958 notifyingRewriter.replaceOp(op, repls);
1959 return;
1960 }
1961
1962 assert(!ignoredOps.contains(op) && "operation was already replaced");
1963#ifndef NDEBUG
1964 for (Value v : op->getResults())
1965 assert(!replacedValues.contains(v) &&
1966 "attempting to replace a value that was already replaced");
1967#endif // NDEBUG
1968
1969 // Check if replaced op is an unresolved materialization, i.e., an
1970 // unrealized_conversion_cast op that was created by the conversion driver.
1971 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1972 // Make sure that the user does not mess with unresolved materializations
1973 // that were inserted by the conversion driver. We keep track of these
1974 // ops in internal data structures.
1975 assert(!unresolvedMaterializations.contains(castOp) &&
1976 "attempting to replace/erase an unresolved materialization");
1977 }
1978
1979 // Create mappings for each of the new result values.
1980 for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults()))
1981 mapping.map(static_cast<Value>(result), std::move(repl));
1982
1984 // Mark this operation and all nested ops as replaced.
1985 op->walk([&](Operation *op) { replacedOps.insert(op); });
1986}
1987
// Replaces uses of `from` with `to`. When `functor` is set, only the uses it
// accepts are replaced (conditional replacement).
//  - "No rollback" mode: a same-typed replacement is computed with
//    `getReplacementValues` and applied immediately via `performReplaceValue`.
//    Nothing happens if that yields a null value (e.g. `from` has no uses).
//  - Rollback mode: `from` is mapped to `to` and a ReplaceValueRewrite is
//    recorded. Conditional replacement is a fatal error in this mode. In debug
//    builds, replacing the same value twice, or a result of a replaced op, is
//    asserted against.
//
// NOTE(review): the signature line of this definition is missing from this
// listing.
1989 Value from, ValueRange to, const TypeConverter *converter,
1990 function_ref<bool(OpOperand &)> functor) {
// NOTE(review): this debug line is never terminated with "\n", so the next log
// line is appended to it. Also, the " (in region of '...' (<ptr>)" text opens
// two parentheses but closes only one.
1991 LLVM_DEBUG({
1992 logger.startLine() << "** Replace Value : '" << from << "'";
1993 if (auto blockArg = dyn_cast<BlockArgument>(from)) {
1994 if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
1995 logger.getOStream() << " (in region of '" << parentOp->getName()
1996 << "' (" << parentOp << ")";
1997 } else {
1998 logger.getOStream() << " (unlinked block)";
1999 }
2000 }
2001 if (functor) {
2002 logger.getOStream() << ", conditional replacement";
2003 }
2004 });
2005
2006 if (!config.allowPatternRollback) {
// Compute a replacement with the type of `from`. `repls` has exactly one
// entry, which is null when `from` is dead.
2007 SmallVector<Value> toConv = llvm::to_vector(to);
2008 SmallVector<Value> repls =
2009 getReplacementValues(*this, from, {toConv}, converter);
2010 IRRewriter r(from.getContext());
2011 Value repl = repls.front();
2012 if (!repl)
2013 return;
2014
2015 performReplaceValue(r, from, repl, functor);
2016 return;
2017 }
2018
2019#ifndef NDEBUG
2020 // Make sure that a value is not replaced multiple times. In rollback mode,
2021 // `replaceAllUsesWith` replaces not only all current uses of the given value,
2022 // but also all future uses that may be introduced by future pattern
2023 // applications. Therefore, it does not make sense to call
2024 // `replaceAllUsesWith` multiple times with the same value. Doing so would
2025 // overwrite the mapping and mess with the internal state of the dialect
2026 // conversion driver.
2027 assert(!replacedValues.contains(from) &&
2028 "attempting to replace a value that was already replaced");
2029 assert(!wasOpReplaced(from.getDefiningOp()) &&
2030 "attempting to replace a op result that was already replaced");
2031 replacedValues.insert(from);
2032#endif // NDEBUG
2033
2034 if (functor)
2035 llvm::reportFatalInternalError(
2036 "conditional value replacement is not supported in rollback mode");
2037 mapping.map(from, to);
2038 appendRewrite<ReplaceValueRewrite>(from, converter);
2039}
2040
// Erases `block`.
//  - "No rollback" mode: every nested op and block is removed from the
//    driver's tracking structures, then the block is erased through
//    `notifyingRewriter`.
//  - Rollback mode: the block is only unlinked from its region, so that the
//    removal can be undone, and its nested ops are marked as replaced.
//
// NOTE(review): the signature line is missing from this listing. So is a
// statement between the assertion and the unlinking; presumably it records
// the rewrite that keeps the unlinked block alive, as the comment below
// describes.
2042 if (!config.allowPatternRollback) {
2043 // Pattern rollback is not allowed: materialize all IR changes immediately.
2044 // Update internal data structures, so that there are no dangling pointers
2045 // to erased IR.
2046 block->walk([&](Operation *op) {
2047 erasedOps.insert(op);
2048 ignoredOps.remove(op);
2049 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
2050 unresolvedMaterializations.erase(castOp);
2051 patternMaterializations.erase(castOp);
2052 }
2053 // These ops will be erased together with the block, so remove them from
2054 // the set of unlegalized ops.
2055 if (config.unlegalizedOps)
2056 config.unlegalizedOps->erase(op);
2057 });
// Note: the lambda parameter below shadows the outer `block`; nested blocks
// (not only the outer one) are recorded.
2058 block->walk([&](Block *block) { erasedBlocks.insert(block); });
2059 // Erase the block and notify the listener.
2060 notifyingRewriter.eraseBlock(block);
2061 return;
2062 }
2063
2064 assert(!wasOpReplaced(block->getParentOp()) &&
2065 "attempting to erase a block within a replaced/erased op");
2067
2068 // Unlink the block from its parent region. The block is kept in the rewrite
2069 // object and will be actually destroyed when rewrites are applied. This
2070 // allows us to keep the operations in the block live and undo the removal by
2071 // re-inserting the block.
2072 block->getParent()->getBlocks().remove(block);
2073
2074 // Mark all nested ops as erased.
2075 block->walk([&](Operation *op) { replacedOps.insert(op); });
2076}
2077
// Hook run when `block` is inserted into a region. `previous`/`previousIt`
// give its former position; `previous` is null if the block was detached.
//  - In "no rollback" mode the user listener (if any) is notified immediately,
//    and a previously detached block is removed from `erasedBlocks`.
//  - In rollback mode a moved block is recorded with a MoveBlockRewrite.
//
// NOTE(review): the signature line is missing from this listing. So is the
// statement inside the rollback branch of the detached case; as shown, that
// branch contains only comments.
2079 Block *block, Region *previous, Region::iterator previousIt) {
2080 // If no previous insertion point is provided, the block used to be detached.
2081 bool wasDetached = !previous;
2082 Operation *newParentOp = block->getParentOp();
2083 LLVM_DEBUG(
2084 {
2085 Operation *parent = newParentOp;
2086 if (parent) {
2087 logger.startLine() << "** Insert Block into : '" << parent->getName()
2088 << "' (" << parent << ")";
2089 } else {
2090 logger.startLine()
2091 << "** Insert Block into detached Region (nullptr parent op)";
2092 }
2093 if (wasDetached)
2094 logger.getOStream() << " (was detached)";
2095 logger.getOStream() << "\n";
2096 });
2097
2098 // In rollback mode, it is easier to misuse the API, so perform extra error
2099 // checking.
2100 assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) &&
2101 "attempting to insert into a region within a replaced/erased op");
// `newParentOp` is otherwise only read inside LLVM_DEBUG and assert, both of
// which compile away in NDEBUG builds; this cast avoids an unused warning.
2102 (void)newParentOp;
2103
2104 // In "no rollback" mode, the listener is always notified immediately.
2105 if (!config.allowPatternRollback && config.listener)
2106 config.listener->notifyBlockInserted(block, previous, previousIt);
2107
2108 if (wasDetached) {
2109 // If the block was detached, it is most likely a newly created block.
2110 if (config.allowPatternRollback) {
2111 // TODO: If the same block is inserted multiple times from a detached
2112 // state, the rollback mechanism may erase the same block multiple times.
2113 // This is a bug in the rollback-based dialect conversion driver.
2115 } else {
2116 // In "no rollback" mode, there is an extra data structure for tracking
2117 // erased blocks that must be kept up to date.
2118 erasedBlocks.erase(block);
2119 }
2120 return;
2121 }
2122
2123 // The block was moved from one place to another.
2124 if (config.allowPatternRollback)
2125 appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
2126}
2127
// Records an InlineBlockRewrite for inlining `source` into `dest` before
// `before`. It does not move any ops itself. The public
// ConversionPatternRewriter::inlineBlockBefore calls it only in rollback mode
// without a listener, before moving the ops.
// NOTE(review): the first signature line (return type, qualified name and the
// `source` parameter) is missing from this listing.
2129 Block *dest,
2130 Block::iterator before) {
2131 appendRewrite<InlineBlockRewrite>(dest, source, before);
2132}
2133
// Logs a failure at `loc`; `reasonCallback` fills in the diagnostic text. The
// whole body is inside LLVM_DEBUG, so it only runs in assertions builds with
// debug output enabled for this DEBUG_TYPE.
// NOTE(review): because the `config.notifyCallback` invocation is also inside
// LLVM_DEBUG, that callback never fires in release builds or when debug
// output is off. Confirm this is intended.
// NOTE(review): the signature line and the declaration of `diag` are missing
// from this listing.
2135 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
2136 LLVM_DEBUG({
2138 reasonCallback(diag);
2139 logger.startLine() << "** Failure : " << diag.str() << "\n";
2140 if (config.notifyCallback)
2141 config.notifyCallback(diag);
2142 });
2143}
2144
2145//===----------------------------------------------------------------------===//
2146// ConversionPatternRewriter
2147//===----------------------------------------------------------------------===//
2148
// Creates the rewriter together with its implementation object (constructed
// from this rewriter, `config` and `opConverter`). It then installs that
// implementation as the rewriter's listener, so IR changes made through this
// rewriter are reported to it.
2149ConversionPatternRewriter::ConversionPatternRewriter(
2150 MLIRContext *ctx, const ConversionConfig &config,
2151 OperationConverter &opConverter)
// NOTE(review): the line that starts the member-initializer list (the
// construction of `impl`, whose trailing arguments follow) is missing from
// this listing.
2153 *this, config, opConverter)) {
2154 setListener(impl.get());
2155}
2156
2157ConversionPatternRewriter::~ConversionPatternRewriter() = default;
2158
2159const ConversionConfig &ConversionPatternRewriter::getConfig() const {
2160 return impl->config;
2161}
2162
2163void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
2164 assert(op && newOp && "expected non-null op");
2165 replaceOp(op, newOp->getResults());
2166}
2167
// Replaces `op` with `newValues`, one value per result. A null value becomes
// an empty replacement list (the corresponding result is dropped).
2168void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
2169 assert(op->getNumResults() == newValues.size() &&
2170 "incorrect # of replacement values");
2171
2172 // If the current insertion point is before the erased operation, we adjust
2173 // the insertion point to be after the operation.
// NOTE(review): the statement guarded by this `if` is missing from this
// listing. As shown, the `if` would instead guard the declaration of
// `newVals`, which would then be out of scope at its use. Presumably the
// missing line moves the insertion point past `op`; confirm upstream.
2174 if (getInsertionPoint() == op->getIterator())
2176
2177 SmallVector<SmallVector<Value>> newVals =
2178 llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2179 return v ? SmallVector<Value>{v} : SmallVector<Value>();
2180 });
2181 impl->replaceOp(op, std::move(newVals));
2182}
2183
// Replaces each result of `op` with a list of values (1:N replacement);
// `newValues` must have one entry per result.
2184void ConversionPatternRewriter::replaceOpWithMultiple(
2185 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2186 assert(op->getNumResults() == newValues.size() &&
2187 "incorrect # of replacement values");
2188
2189 // If the current insertion point is before the erased operation, we adjust
2190 // the insertion point to be after the operation.
// NOTE(review): the statement guarded by this `if` is missing from this
// listing. As shown, the `if` would guard the `impl->replaceOp` call, so `op`
// would only be replaced when the insertion point is at `op`. Confirm
// upstream.
2191 if (getInsertionPoint() == op->getIterator())
2193
2194 impl->replaceOp(op, std::move(newValues));
2195}
2196
// Erases `op` by replacing every one of its results with an empty replacement
// list.
2197void ConversionPatternRewriter::eraseOp(Operation *op) {
2198 LLVM_DEBUG({
2199 impl->logger.startLine()
2200 << "** Erase : '" << op->getName() << "'(" << op << ")\n";
2201 });
2202
2203 // If the current insertion point is before the erased operation, we adjust
2204 // the insertion point to be after the operation.
// NOTE(review): the statement guarded by this `if` is missing from this
// listing. As shown, the `if` would instead guard the declaration of
// `nullRepls`, which would then be out of scope at its use. Confirm upstream.
2205 if (getInsertionPoint() == op->getIterator())
2207
2208 SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
2209 impl->replaceOp(op, std::move(nullRepls));
2210}
2211
2212void ConversionPatternRewriter::eraseBlock(Block *block) {
2213 impl->eraseBlock(block);
2214}
2215
2216Block *ConversionPatternRewriter::applySignatureConversion(
2217 Block *block, TypeConverter::SignatureConversion &conversion,
2218 const TypeConverter *converter) {
2219 assert(!impl->wasOpReplaced(block->getParentOp()) &&
2220 "attempting to apply a signature conversion to a block within a "
2221 "replaced/erased op");
2222 return impl->applySignatureConversion(block, converter, conversion);
2223}
2224
2225FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
2226 Region *region, const TypeConverter &converter,
2227 TypeConverter::SignatureConversion *entryConversion) {
2228 assert(!impl->wasOpReplaced(region->getParentOp()) &&
2229 "attempting to apply a signature conversion to a block within a "
2230 "replaced/erased op");
2231 return impl->convertRegionTypes(region, converter, entryConversion);
2232}
2233
2234void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
2235 impl->replaceValueUses(from, to, impl->currentTypeConverter);
2236}
2237
2238void ConversionPatternRewriter::replaceUsesWithIf(
2239 Value from, ValueRange to, function_ref<bool(OpOperand &)> functor,
2240 bool *allUsesReplaced) {
2241 assert(!allUsesReplaced &&
2242 "allUsesReplaced is not supported in a dialect conversion");
2243 impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
2244}
2245
2246Value ConversionPatternRewriter::getRemappedValue(Value key) {
2247 SmallVector<ValueVector> remappedValues;
2248 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
2249 remappedValues)))
2250 return nullptr;
2251 assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
2252 return remappedValues.front().front();
2253}
2254
2255LogicalResult
2256ConversionPatternRewriter::getRemappedValues(ValueRange keys,
2257 SmallVectorImpl<Value> &results) {
2258 if (keys.empty())
2259 return success();
2260 SmallVector<ValueVector> remapped;
2261 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
2262 remapped)))
2263 return failure();
2264 for (const auto &values : remapped) {
2265 assert(values.size() == 1 && "1:N conversion not supported");
2266 results.push_back(values.front());
2267 }
2268 return success();
2269}
2270
2271void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
2272 Block::iterator before,
2273 ValueRange argValues) {
2274#ifndef NDEBUG
2275 assert(argValues.size() == source->getNumArguments() &&
2276 "incorrect # of argument replacement values");
2277 assert(!impl->wasOpReplaced(source->getParentOp()) &&
2278 "attempting to inline a block from a replaced/erased op");
2279 assert(!impl->wasOpReplaced(dest->getParentOp()) &&
2280 "attempting to inline a block into a replaced/erased op");
2281 auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
2282 // The source block will be deleted, so it should not have any users (i.e.,
2283 // there should be no predecessors).
2284 assert(llvm::all_of(source->getUsers(), opIgnored) &&
2285 "expected 'source' to have no predecessors");
2286#endif // NDEBUG
2287
2288 // If a listener is attached to the dialect conversion, ops cannot be moved
2289 // to the destination block in bulk ("fast path"). This is because at the time
2290 // the notifications are sent, it is unknown which ops were moved. Instead,
2291 // ops should be moved one-by-one ("slow path"), so that a separate
2292 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
2293 // a bit more efficient, so we try to do that when possible.
2294 bool fastPath = !getConfig().listener;
2295
2296 if (fastPath && impl->config.allowPatternRollback)
2297 impl->inlineBlockBefore(source, dest, before);
2298
2299 // Replace all uses of block arguments.
2300 for (auto it : llvm::zip(source->getArguments(), argValues))
2301 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2302
2303 if (fastPath) {
2304 // Move all ops at once.
2305 dest->getOperations().splice(before, source->getOperations());
2306 } else {
2307 // Move op by op.
2308 while (!source->empty())
2309 moveOpBefore(&source->front(), dest, before);
2310 }
2311
2312 // If the current insertion point is within the source block, adjust the
2313 // insertion point to the destination block.
2314 if (getInsertionBlock() == source)
2315 setInsertionPoint(dest, getInsertionPoint());
2316
2317 // Erase the source block.
2318 eraseBlock(source);
2319}
2320
// Begins an in-place modification of `op`. In rollback mode this records a
// ModifyOperationRewrite so that the modification can be undone. In debug
// builds it also tracks `op` as having a pending update.
2321void ConversionPatternRewriter::startOpModification(Operation *op) {
2322 if (!impl->config.allowPatternRollback) {
2323 // Pattern rollback is not allowed: no extra bookkeeping is needed.
// NOTE(review): a statement here is missing from this listing; as shown this
// branch does nothing but return.
2325 return;
2326 }
2327 assert(!impl->wasOpReplaced(op) &&
2328 "attempting to modify a replaced/erased op");
2329#ifndef NDEBUG
2330 impl->pendingRootUpdates.insert(op);
2331#endif
2332 impl->appendRewrite<ModifyOperationRewrite>(op);
2333}
2334
// Ends an in-place modification of `op` and adds it to `patternModifiedOps`.
// In "no rollback" mode the user listener (if any) is notified immediately. In
// rollback mode only debug-build consistency checks are performed.
2335void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
2336 impl->patternModifiedOps.insert(op);
2337 if (!impl->config.allowPatternRollback) {
// NOTE(review): one statement at the start of this branch is missing from this
// listing.
2339 if (getConfig().listener)
2340 getConfig().listener->notifyOperationModified(op);
2341 return;
2342 }
2343
2344 // There is nothing to do here, we only need to track the operation at the
2345 // start of the update.
// The `erase` inside the assert below is intentional. `pendingRootUpdates` is
// only populated in debug builds (startOpModification inserts under the same
// #ifndef NDEBUG), and this check is compiled only in those builds.
2346#ifndef NDEBUG
2347 assert(!impl->wasOpReplaced(op) &&
2348 "attempting to modify a replaced/erased op");
2349 assert(impl->pendingRootUpdates.erase(op) &&
2350 "operation did not have a pending in-place update");
2351#endif
2352}
2353
// Cancels an in-place modification of `op`. In rollback mode, the most recent
// ModifyOperationRewrite for `op` is rolled back and removed from the rewrite
// list.
2354void ConversionPatternRewriter::cancelOpModification(Operation *op) {
2355 if (!impl->config.allowPatternRollback) {
// NOTE(review): a statement here is missing from this listing; as shown this
// branch does nothing but return.
2357 return;
2358 }
2359#ifndef NDEBUG
2360 assert(impl->pendingRootUpdates.erase(op) &&
2361 "operation did not have a pending in-place update");
2362#endif
2363 // Erase the last update for this operation.
2364 auto it = llvm::find_if(
2365 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
2366 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2367 return modifyRewrite && modifyRewrite->getOperation() == op;
2368 });
2369 assert(it != impl->rewrites.rend() && "no root update started on op");
2370 (*it)->rollback();
// `it` is a reverse iterator. `std::prev(rend()) - it` converts it to the
// forward index of the same element, which `erase` needs.
2371 int updateIdx = std::prev(impl->rewrites.rend()) - it;
2372 impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2373}
2374
2375detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
2376 return *impl;
2377}
2378
2379//===----------------------------------------------------------------------===//
2380// ConversionPattern
2381//===----------------------------------------------------------------------===//
2382
2383FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
2384 ArrayRef<ValueRange> operands) const {
2385 SmallVector<Value> oneToOneOperands;
2386 oneToOneOperands.reserve(operands.size());
2387 for (ValueRange operand : operands) {
2388 if (operand.size() != 1)
2389 return failure();
2390
2391 oneToOneOperands.push_back(operand.front());
2392 }
2393 return std::move(oneToOneOperands);
2394}
2395
2396LogicalResult
2397ConversionPattern::matchAndRewrite(Operation *op,
2398 PatternRewriter &rewriter) const {
2399 auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
2400 auto &rewriterImpl = dialectRewriter.getImpl();
2401
2402 // Track the current conversion pattern type converter in the rewriter.
2403 llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
2404 getTypeConverter());
2405
2406 // Remap the operands of the operation.
2407 SmallVector<ValueVector> remapped;
2408 if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
2409 op->getOperands(), remapped))) {
2410 return failure();
2411 }
2412 SmallVector<ValueRange> remappedAsRange =
2413 llvm::to_vector_of<ValueRange>(remapped);
2414 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2415}
2416
2417//===----------------------------------------------------------------------===//
2418// OperationLegalizer
2419//===----------------------------------------------------------------------===//
2420
2421namespace {
2422/// A set of rewrite patterns that can be used to legalize a given operation.
2423using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2424
2425/// This class defines a recursive operation legalizer.
/// It holds the rewriter, the conversion target and a PatternApplicator built
/// from the frozen pattern set, plus the set of patterns applied so far.
// NOTE(review): four member declarations below end in a trailing comma because
// their final parameter line is missing from this listing. Confirm against
// upstream before editing.
2426class OperationLegalizer {
2427public:
2428 using LegalizationAction = ConversionTarget::LegalizationAction;
2429
2430 OperationLegalizer(ConversionPatternRewriter &rewriter,
2431 const ConversionTarget &targetInfo,
2432 const FrozenRewritePatternSet &patterns);
2433
2434 /// Returns true if the given operation is known to be illegal on the target.
2435 bool isIllegal(Operation *op) const;
2436
2437 /// Attempt to legalize the given operation. Returns success if the operation
2438 /// was legalized, failure otherwise.
2439 LogicalResult legalize(Operation *op);
2440
2441 /// Returns the conversion target in use by the legalizer.
 // NOTE(review): this accessor does not modify state and could be `const`.
2442 const ConversionTarget &getTarget() { return target; }
2443
2444private:
2445 /// Attempt to legalize the given operation by folding it.
2446 LogicalResult legalizeWithFold(Operation *op);
2447
2448 /// Attempt to legalize the given operation by applying a pattern. Returns
2449 /// success if the operation was legalized, failure otherwise.
2450 LogicalResult legalizeWithPattern(Operation *op);
2451
2452 /// Return true if the given pattern may be applied to the given operation,
2453 /// false otherwise.
2454 bool canApplyPattern(Operation *op, const Pattern &pattern);
2455
2456 /// Legalize the resultant IR after successfully applying the given pattern.
2457 LogicalResult
2458 legalizePatternResult(Operation *op, const Pattern &pattern,
2459 const RewriterState &curState,
2460 const SetVector<Operation *> &newOps,
2461 const SetVector<Operation *> &modifiedOps);
2462
 /// Presumably legalizes the ops created (`newOps`) by a successful pattern
 /// application; see the definition to confirm.
2463 LogicalResult
2464 legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
 /// Presumably legalizes the ops modified in place (`modifiedOps`) by a
 /// successful pattern application; see the definition to confirm.
2465 LogicalResult
2466 legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
2467
2468 //===--------------------------------------------------------------------===//
2469 // Cost Model
2470 //===--------------------------------------------------------------------===//
2471
2472 /// Build an optimistic legalization graph given the provided patterns. This
2473 /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
2474 /// patterns for operations that are not directly legal, but may be
2475 /// transitively legal for the current target given the provided patterns.
2476 void buildLegalizationGraph(
2477 LegalizationPatterns &anyOpLegalizerPatterns,
 // NOTE(review): final parameter (presumably 'legalizerPatterns') missing.
2479
2480 /// Compute the benefit of each node within the computed legalization graph.
2481 /// This orders the patterns within 'legalizerPatterns' based upon two
2482 /// criteria:
2483 /// 1) Prefer patterns that have the lowest legalization depth, i.e.
2484 /// represent the more direct mapping to the target.
2485 /// 2) When comparing patterns with the same legalization depth, prefer the
2486 /// pattern with the highest PatternBenefit. This allows for users to
2487 /// prefer specific legalizations over others.
2488 void computeLegalizationGraphBenefit(
2489 LegalizationPatterns &anyOpLegalizerPatterns,
 // NOTE(review): final parameter (presumably 'legalizerPatterns') missing.
2491
2492 /// Compute the legalization depth when legalizing an operation of the given
2493 /// type.
2494 unsigned computeOpLegalizationDepth(
2495 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
 // NOTE(review): final parameter missing from this listing.
2497
2498 /// Apply the conversion cost model to the given set of patterns, and return
2499 /// the smallest legalization depth of any of the patterns. See
2500 /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
2501 unsigned applyCostModelToPatterns(
2502 LegalizationPatterns &patterns,
2503 DenseMap<OperationName, unsigned> &minOpPatternDepth,
 // NOTE(review): final parameter missing from this listing.
2505
2506 /// The current set of patterns that have been applied.
2507 SmallPtrSet<const Pattern *, 8> appliedPatterns;
2508
2509 /// The rewriter to use when converting operations.
2510 ConversionPatternRewriter &rewriter;
2511
2512 /// The legalization information provided by the target.
2513 const ConversionTarget &target;
2514
2515 /// The pattern applicator to use for conversions.
2516 PatternApplicator applicator;
2517};
2518} // namespace
2519
// Stores the rewriter and target, builds the pattern applicator, and then
// builds the legalization graph and orders its patterns with the cost model.
// The pattern maps used for this are locals of the constructor.
2520OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2521 const ConversionTarget &targetInfo,
2522 const FrozenRewritePatternSet &patterns)
2523 : rewriter(rewriter), target(targetInfo), applicator(patterns) {
2524 // The set of patterns that can be applied to illegal operations to transform
2525 // them into legal ones.
// NOTE(review): the declaration of `legalizerPatterns` (used below) is missing
// from this listing.
2527 LegalizationPatterns anyOpLegalizerPatterns;
2528
2529 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2530 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2531}
2532
2533bool OperationLegalizer::isIllegal(Operation *op) const {
2534 return target.isIllegal(op);
2535}
2536
2537LogicalResult OperationLegalizer::legalize(Operation *op) {
2538#ifndef NDEBUG
2539 const char *logLineComment =
2540 "//===-------------------------------------------===//\n";
2541
2542 auto &logger = rewriter.getImpl().logger;
2543#endif
2544
2545 // Check to see if the operation is ignored and doesn't need to be converted.
2546 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2547
2548 LLVM_DEBUG({
2549 logger.getOStream() << "\n";
2550 logger.startLine() << logLineComment;
2551 logger.startLine() << "Legalizing operation : ";
2552 // Do not print the operation name if the operation is ignored. Ignored ops
2553 // may have been erased and should not be accessed. The pointer can be
2554 // printed safely.
2555 if (!isIgnored)
2556 logger.getOStream() << "'" << op->getName() << "' ";
2557 logger.getOStream() << "(" << op << ") {\n";
2558 logger.indent();
2559
2560 // If the operation has no regions, just print it here.
2561 if (!isIgnored && op->getNumRegions() == 0) {
2562 logger.startLine() << OpWithFlags(op,
2563 OpPrintingFlags().printGenericOpForm())
2564 << "\n";
2565 }
2566 });
2567
2568 if (isIgnored) {
2569 LLVM_DEBUG({
2570 logSuccess(logger, "operation marked 'ignored' during conversion");
2571 logger.startLine() << logLineComment;
2572 });
2573 return success();
2574 }
2575
2576 // Check if this operation is legal on the target.
2577 if (auto legalityInfo = target.isLegal(op)) {
2578 LLVM_DEBUG({
2579 logSuccess(
2580 logger, "operation marked legal by the target{0}",
2581 legalityInfo->isRecursivelyLegal
2582 ? "; NOTE: operation is recursively legal; skipping internals"
2583 : "");
2584 logger.startLine() << logLineComment;
2585 });
2586
2587 // If this operation is recursively legal, mark its children as ignored so
2588 // that we don't consider them for legalization.
2589 if (legalityInfo->isRecursivelyLegal) {
2590 op->walk([&](Operation *nested) {
2591 if (op != nested)
2592 rewriter.getImpl().ignoredOps.insert(nested);
2593 });
2594 }
2595
2596 return success();
2597 }
2598
2599 // If the operation is not legal, try to fold it in-place if the folding mode
2600 // is 'BeforePatterns'. 'Never' will skip this.
2601 const ConversionConfig &config = rewriter.getConfig();
2602 if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2603 if (succeeded(legalizeWithFold(op))) {
2604 LLVM_DEBUG({
2605 logSuccess(logger, "operation was folded");
2606 logger.startLine() << logLineComment;
2607 });
2608 return success();
2609 }
2610 }
2611
2612 // Otherwise, we need to apply a legalization pattern to this operation.
2613 if (succeeded(legalizeWithPattern(op))) {
2614 LLVM_DEBUG({
2615 logSuccess(logger, "");
2616 logger.startLine() << logLineComment;
2617 });
2618 return success();
2619 }
2620
2621 // If the operation can't be legalized via patterns, try to fold it in-place
2622 // if the folding mode is 'AfterPatterns'.
2623 if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2624 if (succeeded(legalizeWithFold(op))) {
2625 LLVM_DEBUG({
2626 logSuccess(logger, "operation was folded");
2627 logger.startLine() << logLineComment;
2628 });
2629 return success();
2630 }
2631 }
2632
2633 LLVM_DEBUG({
2634 logFailure(logger, "no matched legalization pattern");
2635 logger.startLine() << logLineComment;
2636 });
2637 return failure();
2638}
2639
/// Take the contents out of `obj` and return them. `obj` is left holding a
/// freshly value-initialized `T` (a valid, empty state) rather than an
/// unspecified moved-from state, so it can be reused immediately.
template <typename T>
static T moveAndReset(T &obj) {
  return std::exchange(obj, T());
}
2648
2649LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2650 auto &rewriterImpl = rewriter.getImpl();
2651 LLVM_DEBUG({
2652 rewriterImpl.logger.startLine() << "* Fold {\n";
2653 rewriterImpl.logger.indent();
2654 });
2655
2656 // Clear pattern state, so that the next pattern application starts with a
2657 // clean slate. (The op/block sets are populated by listener notifications.)
2658 llvm::scope_exit cleanup([&]() {
2659 rewriterImpl.patternNewOps.clear();
2660 rewriterImpl.patternModifiedOps.clear();
2661 });
2662
2663 // Upon failure, undo all changes made by the folder.
2664 RewriterState curState = rewriterImpl.getCurrentState();
2665
2666 // Try to fold the operation.
2667 StringRef opName = op->getName().getStringRef();
2668 SmallVector<Value, 2> replacementValues;
2669 SmallVector<Operation *, 2> newOps;
2670 rewriter.setInsertionPoint(op);
2671 rewriter.startOpModification(op);
2672 if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
2673 LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2674 rewriter.cancelOpModification(op);
2675 return failure();
2676 }
2677 rewriter.finalizeOpModification(op);
2678
2679 // An empty list of replacement values indicates that the fold was in-place.
2680 // As the operation changed, a new legalization needs to be attempted.
2681 if (replacementValues.empty())
2682 return legalize(op);
2683
2684 // Insert a replacement for 'op' with the folded replacement values.
2685 rewriter.replaceOp(op, replacementValues);
2686
2687 // Recursively legalize any new constant operations.
2688 for (Operation *newOp : newOps) {
2689 if (failed(legalize(newOp))) {
2690 LLVM_DEBUG(logFailure(rewriterImpl.logger,
2691 "failed to legalize generated constant '{0}'",
2692 newOp->getName()));
2693 if (!rewriter.getConfig().allowPatternRollback) {
2694 // Rolling back a folder is like rolling back a pattern.
2695 llvm::reportFatalInternalError(
2696 "op '" + opName +
2697 "' folder rollback of IR modifications requested");
2698 }
2699 rewriterImpl.resetState(
2700 curState, std::string(op->getName().getStringRef()) + " folder");
2701 return failure();
2702 }
2703 }
2704
2705 LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2706 return success();
2707}
2708
/// Report a fatal error indicating that newly produced or modified IR could
/// not be legalized.
///
/// The message names `pattern` by its debug name. It lists the names of the
/// ops in `newOps` (created by the pattern) and in `modifiedOps` (modified in
/// place by the pattern). This is used when pattern rollback is disabled, so
/// the pattern's changes cannot be undone and the conversion must abort.
static void
    const SetVector<Operation *> &newOps,
    const SetVector<Operation *> &modifiedOps) {
  // Project each op onto its name so the lists can be joined into the message.
  auto newOpNames = llvm::map_range(
      newOps, [](Operation *op) { return op->getName().getStringRef(); });
  auto modifiedOpNames = llvm::map_range(
      modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
  llvm::reportFatalInternalError("pattern '" + pattern.getDebugName() +
                                 "' produced IR that could not be legalized. " +
                                 "new ops: {" + llvm::join(newOpNames, ", ") +
                                 "}, " + "modified ops: {" +
                                 llvm::join(modifiedOpNames, ", ") + "}");
}
2725
2726LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2727 auto &rewriterImpl = rewriter.getImpl();
2728 const ConversionConfig &config = rewriter.getConfig();
2729
2730#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2731 Operation *checkOp;
2732 std::optional<OperationFingerPrint> topLevelFingerPrint;
2733 if (!rewriterImpl.config.allowPatternRollback) {
2734 // The op may be getting erased, so we have to check the parent op.
2735 // (In rare cases, a pattern may even erase the parent op, which will cause
2736 // a crash here. Expensive checks are "best effort".) Skip the check if the
2737 // op does not have a parent op.
2738 if ((checkOp = op->getParentOp())) {
2739 if (!op->getContext()->isMultithreadingEnabled()) {
2740 topLevelFingerPrint = OperationFingerPrint(checkOp);
2741 } else {
2742 // Another thread may be modifying a sibling operation. Therefore, the
2743 // fingerprinting mechanism of the parent op works only in
2744 // single-threaded mode.
2745 LLVM_DEBUG({
2746 rewriterImpl.logger.startLine()
2747 << "WARNING: Multi-threadeding is enabled. Some dialect "
2748 "conversion expensive checks are skipped in multithreading "
2749 "mode!\n";
2750 });
2751 }
2752 }
2753 }
2754#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2755
2756 // Functor that returns if the given pattern may be applied.
2757 auto canApply = [&](const Pattern &pattern) {
2758 bool canApply = canApplyPattern(op, pattern);
2759 if (canApply && config.listener)
2760 config.listener->notifyPatternBegin(pattern, op);
2761 return canApply;
2762 };
2763
2764 // Functor that cleans up the rewriter state after a pattern failed to match.
2765 RewriterState curState = rewriterImpl.getCurrentState();
2766 auto onFailure = [&](const Pattern &pattern) {
2767 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2768 if (!rewriterImpl.config.allowPatternRollback) {
2769 // Erase all unresolved materializations.
2770 for (auto op : rewriterImpl.patternMaterializations) {
2771 rewriterImpl.unresolvedMaterializations.erase(op);
2772 op.erase();
2773 }
2774 rewriterImpl.patternMaterializations.clear();
2775#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2776 // Expensive pattern check that can detect API violations.
2777 if (checkOp && topLevelFingerPrint) {
2778 OperationFingerPrint fingerPrintAfterPattern(checkOp);
2779 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2780 llvm::reportFatalInternalError(
2781 "pattern '" + pattern.getDebugName() +
2782 "' returned failure but IR did change");
2783 }
2784#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2785 }
2786 rewriterImpl.patternNewOps.clear();
2787 rewriterImpl.patternModifiedOps.clear();
2788 LLVM_DEBUG({
2789 logFailure(rewriterImpl.logger, "pattern failed to match");
2790 if (rewriterImpl.config.notifyCallback) {
2791 Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
2792 diag << "Failed to apply pattern \"" << pattern.getDebugName()
2793 << "\" on op:\n"
2794 << *op;
2795 rewriterImpl.config.notifyCallback(diag);
2796 }
2797 });
2798 if (config.listener)
2799 config.listener->notifyPatternEnd(pattern, failure());
2800 rewriterImpl.resetState(curState, pattern.getDebugName());
2801 appliedPatterns.erase(&pattern);
2802 };
2803
2804 // Functor that performs additional legalization when a pattern is
2805 // successfully applied.
2806 auto onSuccess = [&](const Pattern &pattern) {
2807 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2808 if (!rewriterImpl.config.allowPatternRollback) {
2809 // Eagerly erase unused materializations.
2810 for (auto op : rewriterImpl.patternMaterializations) {
2811 if (op->use_empty()) {
2812 rewriterImpl.unresolvedMaterializations.erase(op);
2813 op.erase();
2814 }
2815 }
2816 rewriterImpl.patternMaterializations.clear();
2817 }
2818 SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
2819 SetVector<Operation *> modifiedOps =
2820 moveAndReset(rewriterImpl.patternModifiedOps);
2821 auto result =
2822 legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
2823 appliedPatterns.erase(&pattern);
2824 if (failed(result)) {
2825 if (!rewriterImpl.config.allowPatternRollback)
2826 reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);
2827 rewriterImpl.resetState(curState, pattern.getDebugName());
2828 }
2829 if (config.listener)
2830 config.listener->notifyPatternEnd(pattern, result);
2831 return result;
2832 };
2833
2834 // Try to match and rewrite a pattern on this operation.
2835 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2836 onSuccess);
2837}
2838
2839bool OperationLegalizer::canApplyPattern(Operation *op,
2840 const Pattern &pattern) {
2841 LLVM_DEBUG({
2842 auto &os = rewriter.getImpl().logger;
2843 os.getOStream() << "\n";
2844 os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2845 llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2846 os.getOStream() << ")' {\n";
2847 os.indent();
2848 });
2849
2850 // Ensure that we don't cycle by not allowing the same pattern to be
2851 // applied twice in the same recursion stack if it is not known to be safe.
2852 if (!pattern.hasBoundedRewriteRecursion() &&
2853 !appliedPatterns.insert(&pattern).second) {
2854 LLVM_DEBUG(
2855 logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2856 return false;
2857 }
2858 return true;
2859}
2860
2861LogicalResult OperationLegalizer::legalizePatternResult(
2862 Operation *op, const Pattern &pattern, const RewriterState &curState,
2863 const SetVector<Operation *> &newOps,
2864 const SetVector<Operation *> &modifiedOps) {
2865 [[maybe_unused]] auto &impl = rewriter.getImpl();
2866 assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2867
2868#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2869 if (impl.config.allowPatternRollback) {
2870 // Check that the root was either replaced or updated in place.
2871 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2872 auto replacedRoot = [&] {
2873 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2874 };
2875 auto updatedRootInPlace = [&] {
2876 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2877 };
2878 if (!replacedRoot() && !updatedRootInPlace())
2879 llvm::reportFatalInternalError(
2880 "expected pattern to replace the root operation "
2881 "or modify it in place");
2882 }
2883#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2884
2885 // Legalize each of the actions registered during application.
2886 if (failed(legalizePatternRootUpdates(modifiedOps)) ||
2887 failed(legalizePatternCreatedOperations(newOps))) {
2888 return failure();
2889 }
2890
2891 LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2892 return success();
2893}
2894
2895LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2896 const SetVector<Operation *> &newOps) {
2897 for (Operation *op : newOps) {
2898 if (failed(legalize(op))) {
2899 LLVM_DEBUG(logFailure(rewriter.getImpl().logger,
2900 "failed to legalize generated operation '{0}'({1})",
2901 op->getName(), op));
2902 return failure();
2903 }
2904 }
2905 return success();
2906}
2907
2908LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2909 const SetVector<Operation *> &modifiedOps) {
2910 for (Operation *op : modifiedOps) {
2911 if (failed(legalize(op))) {
2912 LLVM_DEBUG(
2913 logFailure(rewriter.getImpl().logger,
2914 "failed to legalize operation updated in-place '{0}'",
2915 op->getName()));
2916 return failure();
2917 }
2918 }
2919 return success();
2920}
2921
2922//===----------------------------------------------------------------------===//
2923// Cost Model
2924//===----------------------------------------------------------------------===//
2925
/// Determine which root-specific patterns are worth keeping and record them,
/// per root op, in `legalizerPatterns`.
///
/// - Patterns with no root kind go into `anyOpLegalizerPatterns`.
/// - Patterns whose root op the target always treats as legal are dropped.
/// - If any root-agnostic pattern exists, every remaining pattern is kept
///   as-is.
/// - Otherwise a pattern is kept only once every op it generates is either
///   legal or itself has a kept pattern. This is computed as a worklist
///   fixpoint.
void OperationLegalizer::buildLegalizationGraph(
    LegalizationPatterns &anyOpLegalizerPatterns,
  // A mapping between an operation and a set of operations that can be used to
  // generate it.
  // A mapping between an operation and any currently invalid patterns it has.
  // A worklist of patterns to consider for legality.
  SetVector<const Pattern *> patternWorklist;

  // Build the mapping from operations to the parent ops that may generate them.
  applicator.walkAllPatterns([&](const Pattern &pattern) {
    std::optional<OperationName> root = pattern.getRootKind();

    // If the pattern has no specific root, we can't analyze the relationship
    // between the root op and generated operations. Given that, add all such
    // patterns to the legalization set.
    if (!root) {
      anyOpLegalizerPatterns.push_back(&pattern);
      return;
    }

    // Skip operations that are always known to be legal.
    if (target.getOpAction(*root) == LegalizationAction::Legal)
      return;

    // Add this pattern to the invalid set for the root op and record this root
    // as a parent for any generated operations.
    invalidPatterns[*root].insert(&pattern);
    for (auto op : pattern.getGeneratedOps())
      parentOps[op].insert(*root);

    // Add this pattern to the worklist.
    patternWorklist.insert(&pattern);
  });

  // If there are any patterns that don't have a specific root kind, we can't
  // make direct assumptions about what operations will never be legalized.
  // Note: Technically we could, but it would require an analysis that may
  // recurse into itself. It would be better to perform this kind of filtering
  // at a higher level than here anyways.
  if (!anyOpLegalizerPatterns.empty()) {
    for (const Pattern *pattern : patternWorklist)
      legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
    return;
  }

  while (!patternWorklist.empty()) {
    auto *pattern = patternWorklist.pop_back_val();

    // Check to see if any of the generated operations are invalid. A generated
    // op is invalid if it has no kept pattern yet and the target either has no
    // action for it or marks it Illegal.
    if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
          std::optional<LegalizationAction> action = target.getOpAction(op);
          return !legalizerPatterns.count(op) &&
                 (!action || action == LegalizationAction::Illegal);
        }))
      continue;

    // Otherwise, all of the generated operations are valid, so this pattern
    // can legalize its root op: keep it and drop it from the invalid set.
    legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
    invalidPatterns[*pattern->getRootKind()].erase(pattern);

    // Add any invalid patterns of the parent operations to see if they have now
    // become legal.
    for (auto op : parentOps[*pattern->getRootKind()])
      patternWorklist.set_union(invalidPatterns[op]);
  }
}
2996
/// Order the kept patterns and install the resulting order on the applicator.
///
/// - Compute the legalization depth of every op with kept patterns. As a side
///   effect this sorts each op's pattern list by (depth, benefit).
/// - Sort the root-agnostic patterns the same way.
/// - Install a cost model that ranks every pattern by its position in its
///   sorted list.
void OperationLegalizer::computeLegalizationGraphBenefit(
    LegalizationPatterns &anyOpLegalizerPatterns,
  // The smallest pattern depth, when legalizing an operation.
  DenseMap<OperationName, unsigned> minOpPatternDepth;

  // For each operation that is transitively legal, compute a cost for it.
  for (auto &opIt : legalizerPatterns)
    if (!minOpPatternDepth.count(opIt.first))
      computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
                                 legalizerPatterns);

  // Apply the cost model to the patterns that can match any operation. Those
  // with a specific operation type are already resolved when computing the op
  // legalization depth.
  if (!anyOpLegalizerPatterns.empty())
    applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
                             legalizerPatterns);

  // Apply a cost model to the pattern applicator. We order patterns first by
  // depth then benefit. `legalizerPatterns` contains per-op patterns by
  // decreasing benefit.
  applicator.applyCostModel([&](const Pattern &pattern) {
    // Note: `operator[]` yields an empty list for a root with no kept
    // patterns, so such a pattern is simply not found below.
    ArrayRef<const Pattern *> orderedPatternList;
    if (std::optional<OperationName> rootName = pattern.getRootKind())
      orderedPatternList = legalizerPatterns[*rootName];
    else
      orderedPatternList = anyOpLegalizerPatterns;

    // If the pattern is not found, then it was removed and cannot be matched.
    auto *it = llvm::find(orderedPatternList, &pattern);
    if (it == orderedPatternList.end())

    // Patterns found earlier in the list have higher benefit.
    return PatternBenefit(std::distance(it, orderedPatternList.end()));
  });
}
3035
/// Return the minimum legalization depth of `op`, memoized in
/// `minOpPatternDepth`. An op with no kept patterns has depth 0. As a side
/// effect, `op`'s pattern list in `legalizerPatterns` is sorted by the cost
/// model.
unsigned OperationLegalizer::computeOpLegalizationDepth(
    OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
  // Check for existing depth.
  auto depthIt = minOpPatternDepth.find(op);
  if (depthIt != minOpPatternDepth.end())
    return depthIt->second;

  // If a mapping for this operation does not exist, then this operation
  // is always legal. Return 0 as the depth for a directly legal operation.
  auto opPatternsIt = legalizerPatterns.find(op);
  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
    return 0u;

  // Record this initial depth in case we encounter this op again when
  // recursively computing the depth. A re-entrant query for `op` (a cycle in
  // the legalization graph) therefore returns UINT_MAX via the lookup above.
  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());

  // Apply the cost model to the operation patterns, and update the minimum
  // depth.
  unsigned minDepth = applyCostModelToPatterns(
      opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
  minOpPatternDepth[op] = minDepth;
  return minDepth;
}
3061
/// Compute a legalization depth for every pattern in `patterns`. A pattern's
/// depth is 1 + the largest legalization depth of the ops it generates, and is
/// at least 1.
///
/// `patterns` is then reordered in place: first by increasing depth, then by
/// decreasing PatternBenefit. The sort is stable, so remaining ties keep their
/// original order.
///
/// Returns the smallest depth seen, or UINT_MAX if `patterns` is empty.
unsigned OperationLegalizer::applyCostModelToPatterns(
    LegalizationPatterns &patterns,
    DenseMap<OperationName, unsigned> &minOpPatternDepth,
  unsigned minDepth = std::numeric_limits<unsigned>::max();

  // Compute the depth for each pattern within the set.
  SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
  patternsByDepth.reserve(patterns.size());
  for (const Pattern *pattern : patterns) {
    unsigned depth = 1;
    for (auto generatedOp : pattern->getGeneratedOps()) {
      unsigned generatedOpDepth = computeOpLegalizationDepth(
          generatedOp, minOpPatternDepth, legalizerPatterns);
      // NOTE(review): an op whose depth is still being computed reports
      // UINT_MAX, so on a cycle `generatedOpDepth + 1` wraps to 0 and the
      // cyclic op contributes nothing to this pattern's depth. Presumably
      // intentional; confirm before changing this arithmetic.
      depth = std::max(depth, generatedOpDepth + 1);
    }
    patternsByDepth.emplace_back(pattern, depth);

    // Update the minimum depth of the pattern list.
    minDepth = std::min(minDepth, depth);
  }

  // If the operation only has one legalization pattern, there is no need to
  // sort them.
  if (patternsByDepth.size() == 1)
    return minDepth;

  // Sort the patterns by those likely to be the most beneficial.
  llvm::stable_sort(patternsByDepth,
                    [](const std::pair<const Pattern *, unsigned> &lhs,
                       const std::pair<const Pattern *, unsigned> &rhs) {
                      // First sort by the smaller pattern legalization
                      // depth.
                      if (lhs.second != rhs.second)
                        return lhs.second < rhs.second;

                      // Then sort by the larger pattern benefit.
                      auto lhsBenefit = lhs.first->getBenefit();
                      auto rhsBenefit = rhs.first->getBenefit();
                      return lhsBenefit > rhsBenefit;
                    });

  // Update the legalization pattern to use the new sorted list.
  patterns.clear();
  for (auto &patternIt : patternsByDepth)
    patterns.push_back(patternIt.first);
  return minDepth;
}
3110
3111//===----------------------------------------------------------------------===//
3112// Reconcile Unrealized Casts
3113//===----------------------------------------------------------------------===//
3114
/// Try to reconcile all given UnrealizedConversionCastOps and store the
/// left-over ops in `remainingCastOps` (if provided). See documentation in
/// DialectConversion.h for more details.
/// The `isCastOpOfInterestFn` is used to filter the cast ops to process: the
/// algorithm may visit an operand (or user) which is a cast op, but will not
/// try to reconcile it if not in the filtered set.
///
/// The algorithm runs in three phases:
///  1. For each cast, follow its chain of input casts (starting with the cast
///     itself). If some cast in the chain has inputs whose types equal this
///     cast's result types, redirect this cast's uses to those inputs.
///  2. Mark as live every cast in `castOps` that has a user which is not a
///     cast of interest. Also mark live the casts of interest that
///     transitively feed such a cast.
///  3. Erase every cast in `castOps` that is not live. Report the live ones in
///     `remainingCastOps`.
template <typename RangeT>
    RangeT castOps,
    function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
  // A worklist of cast ops to process.
  SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);

  // Helper function that returns the unrealized_conversion_cast op that
  // defines all inputs of the given op (in the same order). Return "nullptr"
  // if there is no such op.
  auto getInputCast =
      [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
    if (castOp.getInputs().empty())
      return {};
    auto inputCastOp =
        castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
    if (!inputCastOp)
      return {};
    if (inputCastOp.getOutputs() != castOp.getInputs())
      return {};
    return inputCastOp;
  };

  // Process ops in the worklist bottom-to-top. (Items are popped from the
  // back, i.e. in reverse `castOps` order; the worklist never grows.)
  while (!worklist.empty()) {
    UnrealizedConversionCastOp castOp = worklist.pop_back_val();

    // Traverse the chain of input cast ops to see if an op with the same
    // input types can be found.
    UnrealizedConversionCastOp nextCast = castOp;
    while (nextCast) {
      if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
        if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
              return v.getDefiningOp() == castOp;
            })) {
          // Ran into a cycle.
          break;
        }

        // Found a cast where the input types match the output types of the
        // matched op. We can directly use those inputs.
        castOp.replaceAllUsesWith(nextCast.getInputs());
        break;
      }
      nextCast = getInputCast(nextCast);
    }
  }

  // A set of all alive cast ops. I.e., ops whose results are (transitively)
  // used by an op that is not a cast op of interest.
  DenseSet<Operation *> liveOps;

  // Helper function that marks the given op and transitively reachable input
  // cast ops as alive.
  auto markOpLive = [&](Operation *rootOp) {
    SmallVector<Operation *> worklist;
    worklist.push_back(rootOp);
    while (!worklist.empty()) {
      Operation *op = worklist.pop_back_val();
      if (liveOps.insert(op).second) {
        // Successfully inserted: process reachable input cast ops.
        for (Value v : op->getOperands())
          if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
            if (isCastOpOfInterestFn(castOp))
              worklist.push_back(castOp);
      }
    }
  };

  // Find all alive cast ops.
  for (UnrealizedConversionCastOp op : castOps) {
    // The op may have been marked live already as being an operand of another
    // live cast op.
    if (liveOps.contains(op.getOperation()))
      continue;
    // If any of the users is not a cast op, mark the current op (and its
    // input ops) as live.
    if (llvm::any_of(op->getUsers(), [&](Operation *user) {
          auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
          return !castOp || !isCastOpOfInterestFn(castOp);
        }))
      markOpLive(op);
  }

  // Erase all dead cast ops.
  for (UnrealizedConversionCastOp op : castOps) {
    if (liveOps.contains(op)) {
      // Op is alive and was not erased. Add it to the remaining cast ops.
      if (remainingCastOps)
        remainingCastOps->push_back(op);
      continue;
    }

    // Op is dead. Erase it. Its remaining users are dead casts that may not
    // have been erased yet, so detach them first.
    op->dropAllUses();
    op->erase();
  }
}
3220
    ArrayRef<UnrealizedConversionCastOp> castOps,
    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
  // Set of all cast ops for faster lookups. Building it also drops duplicate
  // entries in `castOps`.
  // NOTE(review): the set-based overload called below iterates the DenseSet in
  // hash order. As a result, `remainingCastOps` is not filled in the order of
  // `castOps`, and the order may differ between runs. Consider preserving the
  // input order if callers depend on it.
  DenseSet<UnrealizedConversionCastOp> castOpSet;
  for (UnrealizedConversionCastOp op : castOps)
    castOpSet.insert(op);
  reconcileUnrealizedCasts(castOpSet, remainingCastOps);
}
3230
    const DenseSet<UnrealizedConversionCastOp> &castOps,
    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
  // Reconcile every cast in the set. Only members of the set count as casts of
  // interest; other casts reached along the way are left untouched.
      llvm::make_range(castOps.begin(), castOps.end()),
      [&](UnrealizedConversionCastOp castOp) {
        return castOps.contains(castOp);
      },
      remainingCastOps);
}
3241
namespace mlir {
// Overload for the map of unresolved materializations. It reconciles the map's
// keys, iterated in MapVector insertion order, and treats only those keys as
// casts of interest.
    const llvm::MapVector<UnrealizedConversionCastOp,
                          UnresolvedMaterializationInfo> &castOps,
      castOps.keys(),
      [&](UnrealizedConversionCastOp castOp) {
        return castOps.contains(castOp);
      },
      remainingCastOps);
}
} // namespace mlir
3255
3256//===----------------------------------------------------------------------===//
3257// OperationConverter
3258//===----------------------------------------------------------------------===//
3259
namespace mlir {
// This class converts operations to a given conversion target via a set of
// rewrite patterns. The conversion behaves differently depending on the
// conversion mode (Full, Partial or Analysis); see `convert` for how each mode
// treats an operation that fails to legalize.
                     const FrozenRewritePatternSet &patterns,
                     const ConversionConfig &config,
                     OpConversionMode mode)
      : rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
        mode(mode) {}

  /// Applies the conversion to the given operations (and their nested
  /// operations).
  LogicalResult applyConversion(ArrayRef<Operation *> ops);

  /// Legalizes the given operations (and their nested operations) to the
  /// conversion target. When an operation fails to convert, `onFailure` is
  /// invoked once and failure is returned immediately.
  template <typename Fn>
  LogicalResult legalizeOperations(ArrayRef<Operation *> ops, Fn onFailure,
                                   bool isRecursiveLegalization = false);
  // Overload that takes no failure callback (a no-op `onFailure` is used).
                                   bool isRecursiveLegalization = false) {
    return legalizeOperations(
        ops, /*onFailure=*/[&]() {}, isRecursiveLegalization);
  }

  /// Converts a single operation. If `isRecursiveLegalization` is "true", the
  /// conversion is a recursive legalization request, triggered from within a
  /// pattern. In that case, do not emit errors because there will be another
  /// attempt at legalizing the operation later (via the regular pre-order
  /// legalization mechanism).
  LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);

  /// Returns the conversion target used by the legalizer.
  const ConversionTarget &getTarget() { return opLegalizer.getTarget(); }

private:
  /// The rewriter to use when converting operations. Declared before
  /// `opLegalizer`, which is constructed with a reference to it, so that it is
  /// initialized first.
  ConversionPatternRewriter rewriter;

  /// The legalizer to use when converting operations.
  OperationLegalizer opLegalizer;

  /// The conversion mode to use when legalizing operations.
  OpConversionMode mode;
};
} // namespace mlir
3307
                                        bool isRecursiveLegalization) {
  const ConversionConfig &config = rewriter.getConfig();
  // Emits an error on `op` saying that it could not be legalized. It is called
  // only for non-recursive legalization requests.
  auto emitFailedToLegalizeDiag = [&](bool wasExplicitlyIllegal) {
                          << "failed to legalize operation '"
                          << op->getName() << "'";
    if (wasExplicitlyIllegal)
      diag << " that was explicitly marked illegal";
    diag << ": " << OpWithFlags(op, OpPrintingFlags().skipRegions());
  };

  // Legalize the given operation.
  if (failed(opLegalizer.legalize(op))) {
    // Handle the case of a failed conversion for each of the different modes.
    // Full conversions expect all operations to be converted.
    if (mode == OpConversionMode::Full) {
      if (!isRecursiveLegalization)
        emitFailedToLegalizeDiag(/*wasExplicitlyIllegal=*/false);
      return failure();
    }
    // Partial conversions allow conversions to fail iff the operation was not
    // explicitly marked as illegal. If the user provided an `unlegalizedOps`
    // set, non-legalizable ops are added to that set.
    if (mode == OpConversionMode::Partial) {
      if (opLegalizer.isIllegal(op)) {
        if (!isRecursiveLegalization)
          emitFailedToLegalizeDiag(/*wasExplicitlyIllegal=*/true);
        return failure();
      }
      if (config.unlegalizedOps && !isRecursiveLegalization)
        config.unlegalizedOps->insert(op);
    }
    // In Analysis mode a failure is not an error: the op is simply not
    // recorded as legalizable.
  } else if (mode == OpConversionMode::Analysis) {
    // Analysis conversions don't fail if any operations fail to legalize,
    // they are only interested in the operations that were successfully
    // legalized.
    if (config.legalizableOps && !isRecursiveLegalization)
      config.legalizableOps->insert(op);
  }
  return success();
}
3350
/// Replace the still-live unresolved materialization `op` with a real
/// materialization built by the type converter recorded in `info`.
///
/// Fails if no converter is recorded, or if the converter's materialization
/// callback produces no value. In that case an error is emitted on `op`, with
/// a note at its first user.
static LogicalResult
    UnrealizedConversionCastOp op,
    const UnresolvedMaterializationInfo &info) {
  assert(!op.use_empty() &&
         "expected that dead materializations have already been DCE'd");
  Operation::operand_range inputOperands = op.getOperands();

  // Try to materialize the conversion. Target materializations may produce
  // several values; source materializations produce at most one.
  if (const TypeConverter *converter = info.getConverter()) {
    rewriter.setInsertionPoint(op);
    SmallVector<Value> newMaterialization;
    switch (info.getMaterializationKind()) {
    case MaterializationKind::Target:
      newMaterialization = converter->materializeTargetConversion(
          rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
          info.getOriginalType());
      break;
    case MaterializationKind::Source:
      assert(op->getNumResults() == 1 && "expected single result");
      Value sourceMat = converter->materializeSourceConversion(
          rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
      if (sourceMat)
        newMaterialization.push_back(sourceMat);
      break;
    }
    if (!newMaterialization.empty()) {
#ifndef NDEBUG
      ValueRange newMaterializationRange(newMaterialization);
      assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
             "materialization callback produced value of incorrect type");
#endif // NDEBUG
      rewriter.replaceOp(op, newMaterialization);
      return success();
    }
  }

  // No materialization was possible. `op` has at least one user (asserted
  // above), so pointing the note at the first user is safe.
  InFlightDiagnostic diag = op->emitError()
                            << "failed to legalize unresolved materialization "
                               "from ("
                            << inputOperands.getTypes() << ") to ("
                            << op.getResultTypes()
                            << ") that remained live after conversion";
  diag.attachNote(op->getUsers().begin()->getLoc())
      << "see existing live user here: " << *op->getUsers().begin();
  return failure();
}
3398
/// Collects the operations in `ops` together with their nested operations
/// into a worklist, then converts each collected operation in order via
/// `convert`. The children of an operation that the target reports as
/// recursively legal are not collected. On the first conversion failure,
/// `onFailure` is invoked and failure is returned.
3399template <typename Fn>
3400LogicalResult
3402 bool isRecursiveLegalization) {
3403 const ConversionTarget &target = opLegalizer.getTarget();
3404
3405 // Compute the list of operations to convert.
3406 SmallVector<Operation *> toConvert;
3407 for (Operation *op : ops) {
3409 [&](Operation *op) {
3410 toConvert.push_back(op);
3411 // Don't check this operation's children for conversion if the
3412 // operation is recursively legal.
3413 auto legalityInfo = target.isLegal(op);
3414 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3415 return WalkResult::skip();
3416 return WalkResult::advance();
3417 });
3418 }
// The worklist is fully built before any conversion runs, so operations
// created while converting are not added to this list.
3419 for (Operation *op : toConvert) {
3420 if (failed(convert(op, isRecursiveLegalization))) {
3421 // Failed to convert an operation.
3422 onFailure();
3423 return failure();
3424 }
3425 }
3426 return success();
3427}
3428
3429LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
3430 return impl->opConverter.legalizeOperations(op,
3431 /*isRecursiveLegalization=*/true);
3432}
3433
/// Legalizes every operation in region `r` (including nested operations) as
/// part of the conversion that is currently running.
///
/// If the current pattern has a type converter, the signature of the region's
/// entry block is converted first; if that conversion fails, this function
/// fails. Only operations that existed before the signature conversion are
/// legalized, so ops created by the signature conversion are excluded.
3434LogicalResult ConversionPatternRewriter::legalize(Region *r) {
3435 // Fast path: If the region is empty, there is nothing to legalize.
3436 if (r->empty())
3437 return success();
3438
3439 // Gather a list of all operations to legalize. This is done before
3440 // converting the entry block signature because unrealized_conversion_cast
3441 // ops should not be included.
3443 for (Block &b : *r)
3444 for (Operation &op : b)
3445 ops.push_back(&op);
3446
3447 // If the current pattern runs with a type converter, convert the entry block
3448 // signature.
3449 if (const TypeConverter *converter = impl->currentTypeConverter) {
3450 std::optional<TypeConverter::SignatureConversion> conversion =
3451 converter->convertBlockSignature(&r->front());
3452 if (!conversion)
3453 return failure();
3454 applySignatureConversion(&r->front(), *conversion, converter);
3455 }
3456
3457 // Legalize all operations in the region. This includes all nested
3458 // operations.
3459 return impl->opConverter.legalizeOperations(ops,
3460 /*isRecursiveLegalization=*/true);
3461}
3462
3464 // Convert each operation. On failure, the rewrites are rolled back if
// pattern rollback is allowed; otherwise the modifications made so far are
// applied (see the `onFailure` callback below).
3465 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
3466
3467 LogicalResult status = legalizeOperations(ops, /*onFailure=*/[&]() {
3468 // Dialect conversion failed.
3469 if (rewriterImpl.config.allowPatternRollback) {
3470 // Rollback is allowed: restore the original IR.
3471 rewriterImpl.undoRewrites();
3472 } else {
3473 // Rollback is not allowed: apply all modifications that have been
3474 // performed so far.
3475 rewriterImpl.applyRewrites();
3476 }
3477 });
3478 if (failed(status))
3479 return failure();
3480
3481 // After a successful conversion, apply rewrites.
3482 rewriterImpl.applyRewrites();
3483
3484 // Reconcile all UnrealizedConversionCastOps that were inserted by the
3485 // dialect conversion framework (not the ones that were inserted by
3486 // patterns).
3487 const llvm::MapVector<UnrealizedConversionCastOp,
3488 UnresolvedMaterializationInfo> &materializations =
3489 rewriterImpl.unresolvedMaterializations;
3491 reconcileUnrealizedCasts(materializations, &remainingCastOps);
3492
3493 // Drop the pure-type-conversion markers from the casts that remain.
3494 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3495 castOp->removeAttr(kPureTypeConversionMarker);
3496
3497 // Try to legalize all unresolved materializations.
// Note: the rewrites were already applied above, so a failure in this loop
// returns failure without restoring the original IR.
3498 if (rewriter.getConfig().buildMaterializations) {
3499 // Use a new rewriter, so the modifications are not tracked for rollback
3500 // purposes etc.
3501 IRRewriter irRewriter(rewriterImpl.rewriter.getContext(),
3502 rewriter.getConfig().listener);
3503 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3504 auto it = materializations.find(castOp);
3505 assert(it != materializations.end() && "inconsistent state");
3506 if (failed(legalizeUnresolvedMaterialization(irRewriter, castOp,
3507 it->second)))
3508 return failure();
3509 }
3510 }
3511
3512 return success();
3513}
3514
3515//===----------------------------------------------------------------------===//
3516// Type Conversion
3517//===----------------------------------------------------------------------===//
3518
3519void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
3520 ArrayRef<Type> types) {
3521 assert(!types.empty() && "expected valid types");
3522 remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
3523 addInputs(types);
3524}
3525
3526void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
3527 assert(!types.empty() &&
3528 "1->0 type remappings don't need to be added explicitly");
3529 argTypes.append(types.begin(), types.end());
3530}
3531
3532void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
3533 unsigned newInputNo,
3534 unsigned newInputCount) {
3535 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3536 assert(newInputCount != 0 && "expected valid input count");
3537 remappedInputs[origInputNo] =
3538 InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
3539}
3540
3541void TypeConverter::SignatureConversion::remapInput(
3542 unsigned origInputNo, ArrayRef<Value> replacements) {
3543 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3544 remappedInputs[origInputNo] = InputMapping{
3545 origInputNo, /*size=*/0,
3546 SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3547}
3548
3549/// Internal implementation of the type conversion.
3550/// This is used with either a Type or a Value as the first argument.
3551/// - we can cache the context-free conversions until the last registered
3552/// context-aware conversion.
3553/// - we can't cache the result of type conversion happening after context-aware
3554/// conversions, because the type converter may return different results for the
3555/// same input type.
3556LogicalResult
3557TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3558 SmallVectorImpl<Type> &results) const {
3559 assert(typeOrValue && "expected non-null type");
3560 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3561 : cast<Type>(typeOrValue);
3562 {
3563 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3564 std::defer_lock);
3566 cacheReadLock.lock();
3567 auto existingIt = cachedDirectConversions.find(t);
3568 if (existingIt != cachedDirectConversions.end()) {
3569 if (existingIt->second)
3570 results.push_back(existingIt->second);
3571 return success(existingIt->second != nullptr);
3572 }
3573 auto multiIt = cachedMultiConversions.find(t);
3574 if (multiIt != cachedMultiConversions.end()) {
3575 results.append(multiIt->second.begin(), multiIt->second.end());
3576 return success();
3577 }
3578 }
3579 // Walk the added converters in reverse order to apply the most recently
3580 // registered first.
3581 size_t currentCount = results.size();
3582
3583 // We can cache the context-free conversions until the last registered
3584 // context-aware conversion. But only if we're processing a Value right now.
3585 auto isCacheable = [&](int index) {
3586 int numberOfConversionsUntilContextAware =
3587 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3588 return index < numberOfConversionsUntilContextAware;
3589 };
3590
3591 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3592 std::defer_lock);
3593
3594 for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3595 const ConversionCallbackFn &converter = indexedConverter.value();
3596 std::optional<LogicalResult> result = converter(typeOrValue, results);
3597 if (!result) {
3598 assert(results.size() == currentCount &&
3599 "failed type conversion should not change results");
3600 continue;
3601 }
3602 if (!isCacheable(indexedConverter.index()))
3603 return success();
3605 cacheWriteLock.lock();
3606 if (!succeeded(*result)) {
3607 assert(results.size() == currentCount &&
3608 "failed type conversion should not change results");
3609 cachedDirectConversions.try_emplace(t, nullptr);
3610 return failure();
3611 }
3612 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3613 if (newTypes.size() == 1)
3614 cachedDirectConversions.try_emplace(t, newTypes.front());
3615 else
3616 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3617 return success();
3618 }
3619 return failure();
3620}
3621
3622LogicalResult TypeConverter::convertType(Type t,
3623 SmallVectorImpl<Type> &results) const {
3624 return convertTypeImpl(t, results);
3625}
3626
3627LogicalResult TypeConverter::convertType(Value v,
3628 SmallVectorImpl<Type> &results) const {
3629 return convertTypeImpl(v, results);
3630}
3631
3632Type TypeConverter::convertType(Type t) const {
3633 // Use the multi-type result version to convert the type.
3634 SmallVector<Type, 1> results;
3635 if (failed(convertType(t, results)))
3636 return nullptr;
3637
3638 // Check to ensure that only one type was produced.
3639 return results.size() == 1 ? results.front() : nullptr;
3640}
3641
3642Type TypeConverter::convertType(Value v) const {
3643 // Use the multi-type result version to convert the type.
3644 SmallVector<Type, 1> results;
3645 if (failed(convertType(v, results)))
3646 return nullptr;
3647
3648 // Check to ensure that only one type was produced.
3649 return results.size() == 1 ? results.front() : nullptr;
3650}
3651
3652LogicalResult
3653TypeConverter::convertTypes(TypeRange types,
3654 SmallVectorImpl<Type> &results) const {
3655 for (Type type : types)
3656 if (failed(convertType(type, results)))
3657 return failure();
3658 return success();
3659}
3660
3661LogicalResult
3662TypeConverter::convertTypes(ValueRange values,
3663 SmallVectorImpl<Type> &results) const {
3664 for (Value value : values)
3665 if (failed(convertType(value, results)))
3666 return failure();
3667 return success();
3668}
3669
3670bool TypeConverter::isLegal(Type type) const {
3671 return convertType(type) == type;
3672}
3673
3674bool TypeConverter::isLegal(Value value) const {
3675 return convertType(value) == value.getType();
3676}
3677
3678bool TypeConverter::isLegal(Operation *op) const {
3679 return isLegal(op->getOperands()) && isLegal(op->getResults());
3680}
3681
3682bool TypeConverter::isLegal(Region *region) const {
3683 return llvm::all_of(
3684 *region, [this](Block &block) { return isLegal(block.getArguments()); });
3685}
3686
3687bool TypeConverter::isSignatureLegal(FunctionType ty) const {
3688 if (!isLegal(ty.getInputs()))
3689 return false;
3690 if (!isLegal(ty.getResults()))
3691 return false;
3692 return true;
3693}
3694
3695LogicalResult
3696TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
3697 SignatureConversion &result) const {
3698 // Try to convert the given input type.
3699 SmallVector<Type, 1> convertedTypes;
3700 if (failed(convertType(type, convertedTypes)))
3701 return failure();
3702
3703 // If this argument is being dropped, there is nothing left to do.
3704 if (convertedTypes.empty())
3705 return success();
3706
3707 // Otherwise, add the new inputs.
3708 result.addInputs(inputNo, convertedTypes);
3709 return success();
3710}
3711LogicalResult
3712TypeConverter::convertSignatureArgs(TypeRange types,
3713 SignatureConversion &result,
3714 unsigned origInputOffset) const {
3715 for (unsigned i = 0, e = types.size(); i != e; ++i)
3716 if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3717 return failure();
3718 return success();
3719}
3720LogicalResult
3721TypeConverter::convertSignatureArg(unsigned inputNo, Value value,
3722 SignatureConversion &result) const {
3723 // Try to convert the given input type.
3724 SmallVector<Type, 1> convertedTypes;
3725 if (failed(convertType(value, convertedTypes)))
3726 return failure();
3727
3728 // If this argument is being dropped, there is nothing left to do.
3729 if (convertedTypes.empty())
3730 return success();
3731
3732 // Otherwise, add the new inputs.
3733 result.addInputs(inputNo, convertedTypes);
3734 return success();
3735}
3736LogicalResult
3737TypeConverter::convertSignatureArgs(ValueRange values,
3738 SignatureConversion &result,
3739 unsigned origInputOffset) const {
3740 for (unsigned i = 0, e = values.size(); i != e; ++i)
3741 if (failed(convertSignatureArg(origInputOffset + i, values[i], result)))
3742 return failure();
3743 return success();
3744}
3745
3746Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
3747 Location loc, Type resultType,
3748 ValueRange inputs) const {
3749 for (const SourceMaterializationCallbackFn &fn :
3750 llvm::reverse(sourceMaterializations))
3751 if (Value result = fn(builder, resultType, inputs, loc))
3752 return result;
3753 return nullptr;
3754}
3755
3756Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
3757 Location loc, Type resultType,
3758 ValueRange inputs,
3759 Type originalType) const {
3760 SmallVector<Value> result = materializeTargetConversion(
3761 builder, loc, TypeRange(resultType), inputs, originalType);
3762 if (result.empty())
3763 return nullptr;
3764 assert(result.size() == 1 && "expected single result");
3765 return result.front();
3766}
3767
3768SmallVector<Value> TypeConverter::materializeTargetConversion(
3769 OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
3770 Type originalType) const {
3771 for (const TargetMaterializationCallbackFn &fn :
3772 llvm::reverse(targetMaterializations)) {
3773 SmallVector<Value> result =
3774 fn(builder, resultTypes, inputs, loc, originalType);
3775 if (result.empty())
3776 continue;
3777 assert(TypeRange(ValueRange(result)) == resultTypes &&
3778 "callback produced incorrect number of values or values with "
3779 "incorrect types");
3780 return result;
3781 }
3782 return {};
3783}
3784
3785std::optional<TypeConverter::SignatureConversion>
3786TypeConverter::convertBlockSignature(Block *block) const {
3787 SignatureConversion conversion(block->getNumArguments());
3788 if (failed(convertSignatureArgs(block->getArguments(), conversion)))
3789 return std::nullopt;
3790 return conversion;
3791}
3792
3793//===----------------------------------------------------------------------===//
3794// Type attribute conversion
3795//===----------------------------------------------------------------------===//
3796TypeConverter::AttributeConversionResult
3797TypeConverter::AttributeConversionResult::result(Attribute attr) {
3798 return AttributeConversionResult(attr, resultTag);
3799}
3800
3801TypeConverter::AttributeConversionResult
3802TypeConverter::AttributeConversionResult::na() {
3803 return AttributeConversionResult(nullptr, naTag);
3804}
3805
3806TypeConverter::AttributeConversionResult
3807TypeConverter::AttributeConversionResult::abort() {
3808 return AttributeConversionResult(nullptr, abortTag);
3809}
3810
3811bool TypeConverter::AttributeConversionResult::hasResult() const {
3812 return impl.getInt() == resultTag;
3813}
3814
3815bool TypeConverter::AttributeConversionResult::isNa() const {
3816 return impl.getInt() == naTag;
3817}
3818
3819bool TypeConverter::AttributeConversionResult::isAbort() const {
3820 return impl.getInt() == abortTag;
3821}
3822
3823Attribute TypeConverter::AttributeConversionResult::getResult() const {
3824 assert(hasResult() && "Cannot get result from N/A or abort");
3825 return impl.getPointer();
3826}
3827
3828std::optional<Attribute>
3829TypeConverter::convertTypeAttribute(Type type, Attribute attr) const {
3830 for (const TypeAttributeConversionCallbackFn &fn :
3831 llvm::reverse(typeAttributeConversions)) {
3832 AttributeConversionResult res = fn(type, attr);
3833 if (res.hasResult())
3834 return res.getResult();
3835 if (res.isAbort())
3836 return std::nullopt;
3837 }
3838 return std::nullopt;
3839}
3840
3841//===----------------------------------------------------------------------===//
3842// FunctionOpInterfaceSignatureConversion
3843//===----------------------------------------------------------------------===//
3844
3845static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3846 const TypeConverter &typeConverter,
3847 ConversionPatternRewriter &rewriter) {
3848 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3849 if (!type)
3850 return failure();
3851
3852 // Convert the function signature (inputs and results).
3853 TypeConverter::SignatureConversion funcConversion(type.getNumInputs());
3854 SmallVector<Type, 1> newResults;
3855 if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
3856 funcConversion)) ||
3857 failed(typeConverter.convertTypes(type.getResults(), newResults)))
3858 return failure();
3859
3860 // If the function has a body, apply a separate signature conversion to the
3861 // entry block. Some function ops (e.g., gpu.func) have extra block arguments
3862 // beyond the function type inputs (e.g., workgroup memory arguments that are
3863 // not part of the public signature). Use a distinct conversion sized for all
3864 // entry block arguments so that applySignatureConversion does not access
3865 // out-of-bounds mappings.
3866 if (!funcOp.getFunctionBody().empty()) {
3867 Block *entryBlock = &funcOp.getFunctionBody().front();
3868 unsigned numEntryBlockArgs = entryBlock->getNumArguments();
3869 unsigned numFuncTypeInputs = type.getNumInputs();
3870 TypeConverter::SignatureConversion blockConversion(numEntryBlockArgs);
3871 // Convert the function-type inputs the same way as for the function type.
3872 if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
3873 blockConversion)))
3874 return failure();
3875 // Add identity mappings for extra block args beyond the function type
3876 // inputs. These arguments are preserved as-is.
3877 for (unsigned i = numFuncTypeInputs; i < numEntryBlockArgs; ++i)
3878 blockConversion.addInputs(i, entryBlock->getArgument(i).getType());
3879 rewriter.applySignatureConversion(entryBlock, blockConversion,
3880 &typeConverter);
3881 }
3882
3883 auto newType = FunctionType::get(
3884 rewriter.getContext(), funcConversion.getConvertedTypes(), newResults);
3885
3886 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3887
3888 return success();
3889}
3890
3891/// Create a default conversion pattern that rewrites the type signature of a
3892/// FunctionOpInterface op. This only supports ops which use FunctionType to
3893/// represent their type.
3894namespace {
3895struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3896 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3897 MLIRContext *ctx,
3898 const TypeConverter &converter,
3899 PatternBenefit benefit)
3900 : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3901
3902 LogicalResult
3903 matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3904 ConversionPatternRewriter &rewriter) const override {
3905 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3906 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3907 }
3908};
3909
3910struct AnyFunctionOpInterfaceSignatureConversion
3911 : public OpInterfaceConversionPattern<FunctionOpInterface> {
3912 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3913
3914 LogicalResult
3915 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3916 ConversionPatternRewriter &rewriter) const override {
3917 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3918 }
3919};
3920} // namespace
3921
3922FailureOr<Operation *>
3923mlir::convertOpResultTypes(Operation *op, ValueRange operands,
3924 const TypeConverter &converter,
3925 ConversionPatternRewriter &rewriter) {
3926 assert(op && "Invalid op");
3927 Location loc = op->getLoc();
3928 if (converter.isLegal(op))
3929 return rewriter.notifyMatchFailure(loc, "op already legal");
3930
3931 OperationState newOp(loc, op->getName());
3932 newOp.addOperands(operands);
3933
3934 SmallVector<Type> newResultTypes;
3935 if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
3936 return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3937
3938 newOp.addTypes(newResultTypes);
3939 newOp.addAttributes(op->getAttrs());
3940 return rewriter.create(newOp);
3941}
3942
3943void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3944 StringRef functionLikeOpName, RewritePatternSet &patterns,
3945 const TypeConverter &converter, PatternBenefit benefit) {
3946 patterns.add<FunctionOpInterfaceSignatureConversion>(
3947 functionLikeOpName, patterns.getContext(), converter, benefit);
3948}
3949
3950void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3951 RewritePatternSet &patterns, const TypeConverter &converter,
3952 PatternBenefit benefit) {
3953 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3954 converter, patterns.getContext(), benefit);
3955}
3956
3957//===----------------------------------------------------------------------===//
3958// ConversionTarget
3959//===----------------------------------------------------------------------===//
3960
3961void ConversionTarget::setOpAction(OperationName op,
3962 LegalizationAction action) {
3963 legalOperations[op].action = action;
3964}
3965
3966void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3967 LegalizationAction action) {
3968 for (StringRef dialect : dialectNames)
3969 legalDialects[dialect] = action;
3970}
3971
3972auto ConversionTarget::getOpAction(OperationName op) const
3973 -> std::optional<LegalizationAction> {
3974 std::optional<LegalizationInfo> info = getOpInfo(op);
3975 return info ? info->action : std::optional<LegalizationAction>();
3976}
3977
3978auto ConversionTarget::isLegal(Operation *op) const
3979 -> std::optional<LegalOpDetails> {
3980 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3981 if (!info)
3982 return std::nullopt;
3983
3984 // Returns true if this operation instance is known to be legal.
3985 auto isOpLegal = [&] {
3986 // Handle dynamic legality either with the provided legality function.
3987 if (info->action == LegalizationAction::Dynamic) {
3988 std::optional<bool> result = info->legalityFn(op);
3989 if (result)
3990 return *result;
3991 }
3992
3993 // Otherwise, the operation is only legal if it was marked 'Legal'.
3994 return info->action == LegalizationAction::Legal;
3995 };
3996 if (!isOpLegal())
3997 return std::nullopt;
3998
3999 // This operation is legal, compute any additional legality information.
4000 LegalOpDetails legalityDetails;
4001 if (info->isRecursivelyLegal) {
4002 auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
4003 if (legalityFnIt != opRecursiveLegalityFns.end()) {
4004 legalityDetails.isRecursivelyLegal =
4005 legalityFnIt->second(op).value_or(true);
4006 } else {
4007 legalityDetails.isRecursivelyLegal = true;
4008 }
4009 }
4010 return legalityDetails;
4011}
4012
4013bool ConversionTarget::isIllegal(Operation *op) const {
4014 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
4015 if (!info)
4016 return false;
4017
4018 if (info->action == LegalizationAction::Dynamic) {
4019 std::optional<bool> result = info->legalityFn(op);
4020 if (!result)
4021 return false;
4022
4023 return !(*result);
4024 }
4025
4026 return info->action == LegalizationAction::Illegal;
4027}
4028
4029static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
4030 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
4031 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
4032 if (!oldCallback)
4033 return newCallback;
4034
4035 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
4036 Operation *op) -> std::optional<bool> {
4037 if (std::optional<bool> result = newCl(op))
4038 return *result;
4039
4040 return oldCl(op);
4041 };
4042 return chain;
4043}
4044
4045void ConversionTarget::setLegalityCallback(
4046 OperationName name, const DynamicLegalityCallbackFn &callback) {
4047 assert(callback && "expected valid legality callback");
4048 auto *infoIt = legalOperations.find(name);
4049 assert(infoIt != legalOperations.end() &&
4050 infoIt->second.action == LegalizationAction::Dynamic &&
4051 "expected operation to already be marked as dynamically legal");
4052 infoIt->second.legalityFn =
4053 composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
4054}
4055
4056void ConversionTarget::markOpRecursivelyLegal(
4057 OperationName name, const DynamicLegalityCallbackFn &callback) {
4058 auto *infoIt = legalOperations.find(name);
4059 assert(infoIt != legalOperations.end() &&
4060 infoIt->second.action != LegalizationAction::Illegal &&
4061 "expected operation to already be marked as legal");
4062 infoIt->second.isRecursivelyLegal = true;
4063 if (callback)
4064 opRecursiveLegalityFns[name] = composeLegalityCallbacks(
4065 std::move(opRecursiveLegalityFns[name]), callback);
4066 else
4067 opRecursiveLegalityFns.erase(name);
4068}
4069
4070void ConversionTarget::setLegalityCallback(
4071 ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
4072 assert(callback && "expected valid legality callback");
4073 for (StringRef dialect : dialects)
4074 dialectLegalityFns[dialect] = composeLegalityCallbacks(
4075 std::move(dialectLegalityFns[dialect]), callback);
4076}
4077
4078void ConversionTarget::setLegalityCallback(
4079 const DynamicLegalityCallbackFn &callback) {
4080 assert(callback && "expected valid legality callback");
4081 unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
4082}
4083
4084auto ConversionTarget::getOpInfo(OperationName op) const
4085 -> std::optional<LegalizationInfo> {
4086 // Check for info for this specific operation.
4087 const auto *it = legalOperations.find(op);
4088 if (it != legalOperations.end())
4089 return it->second;
4090 // Check for info for the parent dialect.
4091 auto dialectIt = legalDialects.find(op.getDialectNamespace());
4092 if (dialectIt != legalDialects.end()) {
4093 DynamicLegalityCallbackFn callback;
4094 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4095 if (dialectFn != dialectLegalityFns.end())
4096 callback = dialectFn->second;
4097 return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
4098 callback};
4099 }
4100 // Otherwise, check if we mark unknown operations as dynamic.
4101 if (unknownLegalityFn)
4102 return LegalizationInfo{LegalizationAction::Dynamic,
4103 /*isRecursivelyLegal=*/false, unknownLegalityFn};
4104 return std::nullopt;
4105}
4106
4107#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4108//===----------------------------------------------------------------------===//
4109// PDL Configuration
4110//===----------------------------------------------------------------------===//
4111
4112void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
4113 auto &rewriterImpl =
4114 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4115 rewriterImpl.currentTypeConverter = getTypeConverter();
4116}
4117
4118void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
4119 auto &rewriterImpl =
4120 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4121 rewriterImpl.currentTypeConverter = nullptr;
4122}
4123
4124/// Remap the given value using the rewriter and the type converter in the
4125/// provided config.
4126static FailureOr<SmallVector<Value>>
4127pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) {
4128 SmallVector<Value> mappedValues;
4129 if (failed(rewriter.getRemappedValues(values, mappedValues)))
4130 return failure();
4131 return std::move(mappedValues);
4132}
4133
4134void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
4135 patterns.getPDLPatterns().registerRewriteFunction(
4136 "convertValue",
4137 [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
4138 auto results = pdllConvertValues(
4139 static_cast<ConversionPatternRewriter &>(rewriter), value);
4140 if (failed(results))
4141 return failure();
4142 return results->front();
4143 });
4144 patterns.getPDLPatterns().registerRewriteFunction(
4145 "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
4146 return pdllConvertValues(
4147 static_cast<ConversionPatternRewriter &>(rewriter), values);
4148 });
4149 patterns.getPDLPatterns().registerRewriteFunction(
4150 "convertType",
4151 [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
4152 auto &rewriterImpl =
4153 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4154 if (const TypeConverter *converter =
4155 rewriterImpl.currentTypeConverter) {
4156 if (Type newType = converter->convertType(type))
4157 return newType;
4158 return failure();
4159 }
4160 return type;
4161 });
4162 patterns.getPDLPatterns().registerRewriteFunction(
4163 "convertTypes",
4164 [](PatternRewriter &rewriter,
4165 TypeRange types) -> FailureOr<SmallVector<Type>> {
4166 auto &rewriterImpl =
4167 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4168 const TypeConverter *converter = rewriterImpl.currentTypeConverter;
4169 if (!converter)
4170 return SmallVector<Type>(types);
4171
4172 SmallVector<Type> remappedTypes;
4173 if (failed(converter->convertTypes(types, remappedTypes)))
4174 return failure();
4175 return std::move(remappedTypes);
4176 });
4177}
4178#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
4179
4180//===----------------------------------------------------------------------===//
4181// Op Conversion Entry Points
4182//===----------------------------------------------------------------------===//
4183
4184/// This is the type of Action that is dispatched when a conversion is applied.
4186 : public tracing::ActionImpl<ApplyConversionAction> {
4187public:
/// Identifier of this action kind; also what `print` emits.
4190 static constexpr StringLiteral tag = "apply-conversion";
/// Human-readable description of this action kind.
4191 static constexpr StringLiteral desc =
4192 "Encapsulate the application of a dialect conversion";
4193
/// Prints the action's tag to `os`.
4194 void print(raw_ostream &os) const override { os << tag; }
4195};
4196
4198 const ConversionTarget &target,
4199 const FrozenRewritePatternSet &patterns,
4200 ConversionConfig config,
4201 OpConversionMode mode) {
// Nothing to convert: trivially succeed.
4202 if (ops.empty())
4203 return success();
// The context of the first op is used for the whole list.
4204 MLIRContext *ctx = ops.front()->getContext();
4205 LogicalResult status = success();
// NOTE(review): `ctx` and `irUnits` are presumably consumed by the action
// dispatch whose opening line is not shown in this listing.
4206 SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
4208 [&] {
// The conversion runs inside this callback; its outcome is written back to
// `status`, which is returned below.
4209 OperationConverter opConverter(ops.front()->getContext(), target,
4210 patterns, config, mode);
4211 status = opConverter.applyConversion(ops);
4212 },
4213 irUnits);
4214 return status;
4215}
4216
4217//===----------------------------------------------------------------------===//
4218// Partial Conversion
4219//===----------------------------------------------------------------------===//
4220
4221LogicalResult mlir::applyPartialConversion(
4222 ArrayRef<Operation *> ops, const ConversionTarget &target,
4223 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4224 return applyConversion(ops, target, patterns, config,
4225 OpConversionMode::Partial);
4226}
4227LogicalResult
4228mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
4229 const FrozenRewritePatternSet &patterns,
4230 ConversionConfig config) {
4231 return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
4232}
4233
4234//===----------------------------------------------------------------------===//
4235// Full Conversion
4236//===----------------------------------------------------------------------===//
4237
4238LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
4239 const ConversionTarget &target,
4240 const FrozenRewritePatternSet &patterns,
4241 ConversionConfig config) {
4242 return applyConversion(ops, target, patterns, config, OpConversionMode::Full);
4243}
4244LogicalResult mlir::applyFullConversion(Operation *op,
4245 const ConversionTarget &target,
4246 const FrozenRewritePatternSet &patterns,
4247 ConversionConfig config) {
4248 return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
4249}
4250
4251//===----------------------------------------------------------------------===//
4252// Analysis Conversion
4253//===----------------------------------------------------------------------===//
4254
4255/// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
4256/// op is a top-level module op (which is expected to be isolated from above),
4257/// return that op.
4259 // Check if there is a top-level operation within `ops`. If so, return that
4260 // op.
4261 for (Operation *op : ops) {
4262 if (!op->getParentOp()) {
4263#ifndef NDEBUG
4264 assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
4265 "expected top-level op to be isolated from above");
4266 for (Operation *other : ops)
4267 assert(op->isAncestor(other) &&
4268 "expected ops to have a common ancestor");
4269#endif // NDEBUG
4270 return op;
4271 }
4272 }
4273
4274 // No top-level op. Find a common ancestor.
4275 Operation *commonAncestor =
4276 ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
4277 for (Operation *op : ops.drop_front()) {
4278 while (!commonAncestor->isProperAncestor(op)) {
4279 commonAncestor =
4281 assert(commonAncestor &&
4282 "expected to find a common isolated from above ancestor");
4283 }
4284 }
4285
4286 return commonAncestor;
4287}
4288
4289LogicalResult mlir::applyAnalysisConversion(
4290 ArrayRef<Operation *> ops, ConversionTarget &target,
4291 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4292#ifndef NDEBUG
4293 if (config.legalizableOps)
4294 assert(config.legalizableOps->empty() && "expected empty set");
4295#endif // NDEBUG
4296
4297 // Clone closted common ancestor that is isolated from above.
4298 Operation *commonAncestor = findCommonAncestor(ops);
4299 IRMapping mapping;
4300 Operation *clonedAncestor = commonAncestor->clone(mapping);
4301 // Compute inverse IR mapping.
4302 DenseMap<Operation *, Operation *> inverseOperationMap;
4303 for (auto &it : mapping.getOperationMap())
4304 inverseOperationMap[it.second] = it.first;
4305
4306 // Convert the cloned operations. The original IR will remain unchanged.
4307 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4308 ops, [&](Operation *op) { return mapping.lookup(op); });
4309 LogicalResult status = applyConversion(opsToConvert, target, patterns, config,
4310 OpConversionMode::Analysis);
4311
4312 // Remap `legalizableOps`, so that they point to the original ops and not the
4313 // cloned ops.
4314 if (config.legalizableOps) {
4315 DenseSet<Operation *> originalLegalizableOps;
4316 for (Operation *op : *config.legalizableOps)
4317 originalLegalizableOps.insert(inverseOperationMap[op]);
4318 *config.legalizableOps = std::move(originalLegalizableOps);
4319 }
4320
4321 // Erase the cloned IR.
4322 clonedAncestor->erase();
4323 return status;
4324}
4325
4326LogicalResult
4327mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
4328 const FrozenRewritePatternSet &patterns,
4329 ConversionConfig config) {
4330 return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
4331}
return success()
static void setInsertionPointAfter(OpBuilder &b, Value value)
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
SmallVector< Value, 2 > ValueVector
A vector of SSA values, optimized for the most common case of one or two values.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl, function_ref< bool(OpOperand &)> functor=nullptr)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value > > &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
#define DEBUG_TYPE
b getContext())
static std::string diag(const llvm::Value &value)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition SCCP.cpp:67
This is the type of Action that is dispatched when a conversion is applied.
tracing::ActionImpl< ApplyConversionAction > Base
static constexpr StringLiteral desc
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
static constexpr StringLiteral tag
This class represents an argument of a Block.
Definition Value.h:306
Location getLoc() const
Return the location for this argument.
Definition Value.h:321
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition Block.h:318
BlockArgListType getArguments()
Definition Block.h:97
iterator end()
Definition Block.h:154
iterator begin()
Definition Block.h:153
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
UnitAttr getUnitAttr()
Definition Builders.cpp:102
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:267
MLIRContext * context
Definition Builders.h:204
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
Definition Dominance.h:143
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:161
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition Builders.h:329
Block::iterator getPoint() const
Definition Builders.h:342
bool isSet() const
Returns true if this insert point is set.
Definition Builders.h:339
Block * getBlock() const
Definition Builders.h:341
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:322
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Definition Builders.cpp:478
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition Builders.cpp:426
This class represents an operand of an operation.
Definition Value.h:254
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition Value.h:454
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1142
type_range getTypes() const
void destroyOpProperties(PropertyRef properties) const
This hooks destroy the op properties.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
TypeID getOpPropertiesTypeID() const
Return the TypeID of the op properties.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
PropertyRef getPropertiesStorage()
Return a generic (but typed) reference to the property type storage.
Definition Operation.h:926
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition Operation.h:243
void copyProperties(PropertyRef rhs)
Copy properties from an existing other properties object.
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:877
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:774
void dropAllUses()
Drop all uses of results of this operation.
Definition Operation.h:859
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:585
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:537
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:230
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:273
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:699
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:251
OperandRange operand_range
Definition Operation.h:396
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:115
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:702
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:288
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:822
result_range getResults()
Definition Operation.h:440
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:233
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
BlockListType::iterator iterator
Definition Region.h:52
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
CRTP Implementation of an action.
Definition Action.h:76
ArrayRef< IRUnit > irUnits
Set of IR units (operations, regions, blocks, values) that are associated with this action.
Definition Action.h:66
AttrTypeReplacer.
Kind
An enumeration of the kinds of predicates.
Definition Predicate.h:44
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
static void reconcileUnrealizedCasts(const llvm::MapVector< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
This iterator enumerates elements according to their dominance relationship.
Definition Iterators.h:52
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition Builders.h:310
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition Builders.h:300
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
const ConversionTarget & getTarget()
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, bool isRecursiveLegalization=false)
LogicalResult convert(Operation *op, bool isRecursiveLegalization=false)
Converts a single operation.
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, Fn onFailure, bool isRecursiveLegalization=false)
Legalizes the given operations (and their nested operations) to the conversion target.
LogicalResult applyConversion(ArrayRef< Operation * > ops)
Applies the conversion to the given operations (and their nested operations).
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
llvm::impl::raw_ldbg_ostream os
A raw output stream used to prefix the debug log.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
void replaceValueUses(Value from, ValueRange to, const TypeConverter *converter, function_ref< bool(OpOperand &)> functor=nullptr)
Replace the uses of the given value with the given values.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
llvm::MapVector< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config, OperationConverter &opConverter)
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
void replaceOp(Operation *op, SmallVector< SmallVector< Value > > &&newValues)
Replace the results of the given operation with the given values and erase the operation.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
OperationConverter & opConverter
The operation converter to use for recursive legalization.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.