MLIR 23.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/Config/mlir-config.h"
11#include "mlir/IR/Block.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "mlir/IR/Dominance.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Iterators.h"
17#include "mlir/IR/Operation.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/ErrorHandling.h"
25#include "llvm/Support/FormatVariadic.h"
26#include "llvm/Support/SaveAndRestore.h"
27#include "llvm/Support/ScopedPrinter.h"
28#include <optional>
29#include <utility>
30
31using namespace mlir;
32using namespace mlir::detail;
33
34#define DEBUG_TYPE "dialect-conversion"
35
36/// A utility function to log a successful result for the given reason.
37template <typename... Args>
38static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
39 LLVM_DEBUG({
40 os.unindent();
41 os.startLine() << "} -> SUCCESS";
42 if (!fmt.empty())
43 os.getOStream() << " : "
44 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
45 os.getOStream() << "\n";
46 });
47}
48
49/// A utility function to log a failure result for the given reason.
50template <typename... Args>
51static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
52 LLVM_DEBUG({
53 os.unindent();
54 os.startLine() << "} -> FAILURE : "
55 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
56 << "\n";
57 });
58}
59
60/// Helper function that computes an insertion point where the given value is
61/// defined and can be used without a dominance violation.
///
/// For an op result, the insertion point is directly after the defining op.
/// For any other value (presumably a block argument), it is the start of the
/// value's parent block.
63 Block *insertBlock = value.getParentBlock();
// Default: start of the parent block.
64 Block::iterator insertPt = insertBlock->begin();
65 if (OpResult inputRes = dyn_cast<OpResult>(value))
// Op results are only available after their owner, so insert right after it.
66 insertPt = ++inputRes.getOwner()->getIterator();
67 return OpBuilder::InsertPoint(insertBlock, insertPt);
68}
69
70/// Helper function that computes an insertion point where the given values are
71/// defined and can be used without a dominance violation.
///
/// The result is the "latest" candidate insertion point under the dominance
/// relation. When assertions are enabled, it is an error for two candidates to
/// be unordered by dominance, because then no single point has all values
/// defined.
/// NOTE(review): a fresh DominanceInfo is constructed on every call.
/// NOTE(review): the declarations of `pt` and `nextPt` are not present in this
/// listing; presumably they are the insertion points of the first value and
/// of `v` respectively -- TODO confirm.
73 assert(!vals.empty() && "expected at least one value");
74 DominanceInfo domInfo;
76 for (Value v : vals.drop_front()) {
77 // Choose the "later" insertion point.
79 if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
80 nextPt.getPoint())) {
81 // pt is before nextPt => choose nextPt.
82 pt = nextPt;
83 } else {
84#ifndef NDEBUG
85 // nextPt should be before pt => choose pt.
86 // If pt and nextPt have no dominance relationship, then there is no valid
87 // insertion point at which all given values are defined.
88 bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
89 pt.getBlock(), pt.getPoint());
90 assert(dom && "unable to find valid insertion point");
91#endif // NDEBUG
92 }
93 }
94 return pt;
95}
96
namespace {
/// Selects how strictly the conversion driver treats operations that could
/// not be legalized.
enum OpConversionMode {
  /// Partial conversion: failed conversions are tolerated, so illegal
  /// operations may keep co-existing in the IR.
  Partial = 0,

  /// Full conversion: the conversion only succeeds if every operation is legal
  /// for the given target.
  Full = 1,

  /// Analysis only: operations are analyzed for legality, and no rewrites are
  /// actually applied to the operations on success.
  Analysis = 2,
};
} // namespace
112
113//===----------------------------------------------------------------------===//
114// ConversionValueMapping
115//===----------------------------------------------------------------------===//
116
117/// A vector of SSA values, optimized for the most common case of one or two
118/// values. Inline size chosen empirically based on compilation profiling.
119/// Profiled: 2.3M calls, avg=2.0+-0.3. N=2 covers 98% of cases inline.
121
122namespace {
123
124/// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
125struct ValueVectorMapInfo {
126 static ValueVector getEmptyKey() { return ValueVector{Value()}; }
127 static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
128 static ::llvm::hash_code getHashValue(const ValueVector &val) {
129 return ::llvm::hash_combine_range(val);
130 }
131 static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
132 return LHS == RHS;
133 }
134};
135
136/// Mapping from vectors of SSA values to vectors of replacement SSA values.
137/// `lookup` returns only the directly mapped vector; it does not follow chains.
138struct ConversionValueMapping {
139 /// Return "true" if an SSA value is mapped to the given value. May return
140 /// false positives.
141 bool isMappedTo(Value value) const { return mappedTo.contains(value); }
142
143 /// Return the vector `from` is directly mapped to, or an empty vector.
144 ValueVector lookup(const ValueVector &from) const;
145
// Type trait: true iff `T` (after decay) is exactly `ValueVector`. It is used
// to select between the `map` overloads below.
146 template <typename T>
147 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
148
149 /// Map a value vector to the one provided. Any existing mapping for `oldVal`
/// is overwritten, and every value of `newVal` is recorded in `mappedTo`.
150 template <typename OldVal, typename NewVal>
151 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
152 map(OldVal &&oldVal, NewVal &&newVal) {
// Follow the chain starting at `newVal` and assert that it never reaches
// `oldVal` (which would create a cycle). Because this is wrapped in
// LLVM_DEBUG, it only runs when LLVM debug code is active.
153 LLVM_DEBUG({
154 ValueVector next(newVal);
155 while (true) {
156 assert(next != oldVal && "inserting cyclic mapping");
157 auto it = mapping.find(next);
158 if (it == mapping.end())
159 break;
160 next = it->second;
161 }
162 });
163 mappedTo.insert_range(newVal);
164
165 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
166 }
167
168 /// Map a value vector or single value to the one provided.
/// Whichever argument is not already a ValueVector is wrapped in a
/// single-element ValueVector before forwarding to the overload above.
169 template <typename OldVal, typename NewVal>
170 std::enable_if_t<!IsValueVector<OldVal>::value ||
171 !IsValueVector<NewVal>::value>
172 map(OldVal &&oldVal, NewVal &&newVal) {
173 if constexpr (IsValueVector<OldVal>{}) {
174 map(std::forward<OldVal>(oldVal), ValueVector{newVal});
175 } else if constexpr (IsValueVector<NewVal>{}) {
176 map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
177 } else {
178 map(ValueVector{oldVal}, ValueVector{newVal});
179 }
180 }
181
/// Map a single value to a vector of values; `newVal` is moved from.
182 void map(Value oldVal, SmallVector<Value> &&newVal) {
183 map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
184 }
185
186 /// Remove the mapping entry for the given values.
/// Note: `mappedTo` is not updated here, which is one reason `isMappedTo` may
/// return false positives.
187 void erase(const ValueVector &value) { mapping.erase(value); }
188
189private:
190 /// Current value mappings.
// NOTE(review): the declaration of `mapping` is not present in this listing.
192
193 /// All SSA values that are mapped to. May contain false positives.
194 DenseSet<Value> mappedTo;
195};
196} // namespace
197
198/// Marker attribute for pure type conversions. I.e., mappings whose only
199/// purpose is to resolve a type mismatch. (In contrast, mappings that point to
200/// the replacement values of a "replaceOp" call, etc., are not pure type
201/// conversions.)
202static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
203
204/// Return the operation that defines all values in the vector. Return nullptr
205/// if the values are not defined by the same operation.
/// If the first value's `getDefiningOp()` is null (presumably a block
/// argument), the result is nullptr as well. `values` must be non-empty.
207 assert(!values.empty() && "expected non-empty value vector");
208 Operation *op = values.front().getDefiningOp();
// Any remaining value with a different defining op breaks the commonality.
209 for (Value v : llvm::drop_begin(values)) {
210 if (v.getDefiningOp() != op)
211 return nullptr;
212 }
213 return op;
214}
215
216/// A vector of values is a pure type conversion if all values are defined by
217/// the same operation and the operation has the `kPureTypeConversionMarker`
218/// attribute.
219static bool isPureTypeConversion(const ValueVector &values) {
220 assert(!values.empty() && "expected non-empty value vector");
221 Operation *op = getCommonDefiningOp(values);
222 return op && op->hasAttr(kPureTypeConversionMarker);
223}
224
225ValueVector ConversionValueMapping::lookup(const ValueVector &from) const {
226 auto it = mapping.find(from);
227 if (it == mapping.end()) {
228 // No mapping found: The lookup stops here.
229 return {};
230 }
231 return it->second;
232}
233
234//===----------------------------------------------------------------------===//
235// Rewriter and Translation State
236//===----------------------------------------------------------------------===//
237namespace {
/// Snapshot of the conversion rewriter's bookkeeping counters. A saved
/// snapshot makes it possible to later undo every rewrite recorded after it.
struct RewriterState {
  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
                unsigned numReplacedOps)
      : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
        numReplacedOps(numReplacedOps) {}

  /// How many rewrites had been performed when the snapshot was taken.
  unsigned numRewrites;
  /// How many operations were being ignored when the snapshot was taken.
  unsigned numIgnoredOperations;
  /// How many replaced ops were scheduled for erasure at snapshot time.
  unsigned numReplacedOps;
};
255
256//===----------------------------------------------------------------------===//
257// IR rewrites
258//===----------------------------------------------------------------------===//
259
260static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
261
262/// Notify the listener that the given block and its contents are being erased.
263static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
264 for (Operation &op : b)
265 notifyIRErased(listener, op);
266 listener->notifyBlockErased(&b);
267}
268
269/// Notify the listener that the given operation and its contents are being
270/// erased.
271static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
272 for (Region &r : op.getRegions()) {
273 for (Block &b : r) {
274 notifyIRErased(listener, b);
275 }
276 }
277 listener->notifyOperationErased(&op);
278}
279
280/// An IR rewrite that can be committed (upon success) or rolled back (upon
281/// failure).
282///
283/// The dialect conversion keeps track of IR modifications (requested by the
284/// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
285/// are directly applied to the IR as the rewriter API is used, some are applied
286/// partially, and some are delayed until the `IRRewrite` objects are committed.
287class IRRewrite {
288public:
289 /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
290 /// Enum values are ordered, so that they can be used in `classof`: first all
291 /// block rewrites, then all operation rewrites.
292 enum class Kind {
293 // Block rewrites
294 CreateBlock,
295 EraseBlock,
296 InlineBlock,
297 MoveBlock,
298 BlockTypeConversion,
299 // Operation rewrites
300 MoveOperation,
301 ModifyOperation,
302 ReplaceOperation,
303 CreateOperation,
304 UnresolvedMaterialization,
305 // Value rewrites
306 ReplaceValue
307 };
308
309 virtual ~IRRewrite() = default;
310
311 /// Roll back the rewrite. Operations may be erased during rollback.
312 virtual void rollback() = 0;
313
314 /// Commit the rewrite. At this point, it is certain that the dialect
315 /// conversion will succeed. All IR modifications, except for operation/block
316 /// erasure, must be performed through the given rewriter.
317 ///
318 /// Instead of erasing operations/blocks, they should merely be unlinked
319 /// commit phase and finally be erased during the cleanup phase. This is
320 /// because internal dialect conversion state (such as `mapping`) may still
321 /// be using them.
322 ///
323 /// Any IR modification that was already performed before the commit phase
324 /// (e.g., insertion of an op) must be communicated to the listener that may
325 /// be attached to the given rewriter.
326 virtual void commit(RewriterBase &rewriter) {}
327
328 /// Cleanup operations/blocks. Cleanup is called after commit.
329 virtual void cleanup(RewriterBase &rewriter) {}
330
331 Kind getKind() const { return kind; }
332
333 static bool classof(const IRRewrite *rewrite) { return true; }
334
335protected:
336 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
337 : kind(kind), rewriterImpl(rewriterImpl) {}
338
339 const ConversionConfig &getConfig() const;
340
341 const Kind kind;
342 ConversionPatternRewriterImpl &rewriterImpl;
343};
344
345/// A block rewrite.
346class BlockRewrite : public IRRewrite {
347public:
348 /// Return the block that this rewrite operates on.
349 Block *getBlock() const { return block; }
350
351 static bool classof(const IRRewrite *rewrite) {
352 return rewrite->getKind() >= Kind::CreateBlock &&
353 rewrite->getKind() <= Kind::BlockTypeConversion;
354 }
355
356protected:
357 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
358 Block *block)
359 : IRRewrite(kind, rewriterImpl), block(block) {}
360
361 // The block that this rewrite operates on.
362 Block *block;
363};
364
365/// A value rewrite.
366class ValueRewrite : public IRRewrite {
367public:
368 /// Return the value that this rewrite operates on.
369 Value getValue() const { return value; }
370
371 static bool classof(const IRRewrite *rewrite) {
372 return rewrite->getKind() == Kind::ReplaceValue;
373 }
374
375protected:
376 ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
377 Value value)
378 : IRRewrite(kind, rewriterImpl), value(value) {}
379
380 // The value that this rewrite operates on.
381 Value value;
382};
383
384/// Creation of a block. Block creations are immediately reflected in the IR.
385/// There is no extra work to commit the rewrite. During rollback, the newly
386/// created block is erased.
387class CreateBlockRewrite : public BlockRewrite {
388public:
389 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
390 : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
391
392 static bool classof(const IRRewrite *rewrite) {
393 return rewrite->getKind() == Kind::CreateBlock;
394 }
395
396 void commit(RewriterBase &rewriter) override {
397 // The block was already created and inserted. Just inform the listener.
398 if (auto *listener = rewriter.getListener())
399 listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
400 }
401
402 void rollback() override {
403 // Unlink all of the operations within this block, they will be deleted
404 // separately.
405 auto &blockOps = block->getOperations();
406 while (!blockOps.empty())
407 blockOps.remove(blockOps.begin());
408 block->dropAllUses();
409 if (block->getParent())
410 block->erase();
411 else
412 delete block;
413 }
414};
415
416/// Erasure of a block. Block erasures are partially reflected in the IR. Erased
417/// blocks are immediately unlinked, but only erased during cleanup. This makes
418/// it easier to rollback a block erasure: the block is simply inserted into its
419/// original location.
420class EraseBlockRewrite : public BlockRewrite {
421public:
422 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
423 : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
424 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
425
426 static bool classof(const IRRewrite *rewrite) {
427 return rewrite->getKind() == Kind::EraseBlock;
428 }
429
430 ~EraseBlockRewrite() override {
431 assert(!block &&
432 "rewrite was neither rolled back nor committed/cleaned up");
433 }
434
435 void rollback() override {
436 // The block (owned by this rewrite) was not actually erased yet. It was
437 // just unlinked. Put it back into its original position.
438 assert(block && "expected block");
439 auto &blockList = region->getBlocks();
440 Region::iterator before = insertBeforeBlock
441 ? Region::iterator(insertBeforeBlock)
442 : blockList.end();
443 blockList.insert(before, block);
444 block = nullptr;
445 }
446
447 void commit(RewriterBase &rewriter) override {
448 assert(block && "expected block");
449
450 // Notify the listener that the block and its contents are being erased.
451 if (auto *listener =
452 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
453 notifyIRErased(listener, *block);
454 }
455
456 void cleanup(RewriterBase &rewriter) override {
457 // Erase the contents of the block.
458 for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
459 rewriter.eraseOp(&op);
460 assert(block->empty() && "expected empty block");
461
462 // Erase the block.
463 block->dropAllDefinedValueUses();
464 delete block;
465 block = nullptr;
466 }
467
468private:
469 // The region in which this block was previously contained.
470 Region *region;
471
472 // The original successor of this block before it was unlinked. "nullptr" if
473 // this block was the only block in the region.
474 Block *insertBeforeBlock;
475};
476
477/// Inlining of a block. This rewrite is immediately reflected in the IR.
478/// Note: This rewrite represents only the inlining of the operations. The
479/// erasure of the inlined block is a separate rewrite.
480class InlineBlockRewrite : public BlockRewrite {
481public:
482 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
483 Block *sourceBlock, Block::iterator before)
484 : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
485 sourceBlock(sourceBlock),
486 firstInlinedInst(sourceBlock->empty() ? nullptr
487 : &sourceBlock->front()),
488 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
489 // If a listener is attached to the dialect conversion, ops must be moved
490 // one-by-one. When they are moved in bulk, notifications cannot be sent
491 // because the ops that used to be in the source block at the time of the
492 // inlining (before the "commit" phase) are unknown at the time when
493 // notifications are sent (which is during the "commit" phase).
494 assert(!getConfig().listener &&
495 "InlineBlockRewrite not supported if listener is attached");
496 }
497
498 static bool classof(const IRRewrite *rewrite) {
499 return rewrite->getKind() == Kind::InlineBlock;
500 }
501
502 void rollback() override {
503 // Put the operations from the destination block (owned by the rewrite)
504 // back into the source block.
505 if (firstInlinedInst) {
506 assert(lastInlinedInst && "expected operation");
507 sourceBlock->getOperations().splice(sourceBlock->begin(),
508 block->getOperations(),
509 Block::iterator(firstInlinedInst),
510 ++Block::iterator(lastInlinedInst));
511 }
512 }
513
514private:
515 // The block that originally contained the operations.
516 Block *sourceBlock;
517
518 // The first inlined operation.
519 Operation *firstInlinedInst;
520
521 // The last inlined operation.
522 Operation *lastInlinedInst;
523};
524
525/// Moving of a block. This rewrite is immediately reflected in the IR.
526class MoveBlockRewrite : public BlockRewrite {
527public:
528 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
529 Region *previousRegion, Region::iterator previousIt)
530 : BlockRewrite(Kind::MoveBlock, rewriterImpl, block),
531 region(previousRegion),
532 insertBeforeBlock(previousIt == previousRegion->end() ? nullptr
533 : &*previousIt) {}
534
535 static bool classof(const IRRewrite *rewrite) {
536 return rewrite->getKind() == Kind::MoveBlock;
537 }
538
539 void commit(RewriterBase &rewriter) override {
540 // The block was already moved. Just inform the listener.
541 if (auto *listener = rewriter.getListener()) {
542 // Note: `previousIt` cannot be passed because this is a delayed
543 // notification and iterators into past IR state cannot be represented.
544 listener->notifyBlockInserted(block, /*previous=*/region,
545 /*previousIt=*/{});
546 }
547 }
548
549 void rollback() override {
550 // Move the block back to its original position.
551 Region::iterator before =
552 insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
553 if (Region *currentParent = block->getParent()) {
554 // Block is still in a region, use cheap splice to move it back.
555 region->getBlocks().splice(before, currentParent->getBlocks(), block);
556 return;
557 }
558 // Block was orphaned by a prior rollback, can't splice.
559 region->getBlocks().insert(before, block);
560 }
561
562private:
563 // The region in which this block was previously contained.
564 Region *region;
565
566 // The original successor of this block before it was moved. "nullptr" if
567 // this block was the only block in the region.
568 Block *insertBeforeBlock;
569};
570
571/// Block type conversion. This rewrite is partially reflected in the IR.
572class BlockTypeConversionRewrite : public BlockRewrite {
573public:
574 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
575 Block *origBlock, Block *newBlock)
576 : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
577 newBlock(newBlock) {}
578
579 static bool classof(const IRRewrite *rewrite) {
580 return rewrite->getKind() == Kind::BlockTypeConversion;
581 }
582
583 Block *getOrigBlock() const { return block; }
584
585 Block *getNewBlock() const { return newBlock; }
586
587 void commit(RewriterBase &rewriter) override;
588
589 void rollback() override;
590
591private:
592 /// The new block that was created as part of this signature conversion.
593 Block *newBlock;
594};
595
596/// Replacing a value. This rewrite is not immediately reflected in the
597/// IR. An internal IR mapping is updated, but the actual replacement is delayed
598/// until the rewrite is committed.
599class ReplaceValueRewrite : public ValueRewrite {
600public:
601 ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
602 const TypeConverter *converter)
603 : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
604 converter(converter) {}
605
606 static bool classof(const IRRewrite *rewrite) {
607 return rewrite->getKind() == Kind::ReplaceValue;
608 }
609
610 void commit(RewriterBase &rewriter) override;
611
612 void rollback() override;
613
614private:
615 /// The current type converter when the value was replaced.
616 const TypeConverter *converter;
617};
618
619/// An operation rewrite.
620class OperationRewrite : public IRRewrite {
621public:
622 /// Return the operation that this rewrite operates on.
623 Operation *getOperation() const { return op; }
624
625 static bool classof(const IRRewrite *rewrite) {
626 return rewrite->getKind() >= Kind::MoveOperation &&
627 rewrite->getKind() <= Kind::UnresolvedMaterialization;
628 }
629
630protected:
631 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
632 Operation *op)
633 : IRRewrite(kind, rewriterImpl), op(op) {}
634
635 // The operation that this rewrite operates on.
636 Operation *op;
637};
638
639/// Moving of an operation. This rewrite is immediately reflected in the IR.
640class MoveOperationRewrite : public OperationRewrite {
641public:
642 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
643 Operation *op, OpBuilder::InsertPoint previous)
644 : OperationRewrite(Kind::MoveOperation, rewriterImpl, op),
645 block(previous.getBlock()),
646 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
647 ? nullptr
648 : &*previous.getPoint()) {}
649
650 static bool classof(const IRRewrite *rewrite) {
651 return rewrite->getKind() == Kind::MoveOperation;
652 }
653
654 void commit(RewriterBase &rewriter) override {
655 // The operation was already moved. Just inform the listener.
656 if (auto *listener = rewriter.getListener()) {
657 // Note: `previousIt` cannot be passed because this is a delayed
658 // notification and iterators into past IR state cannot be represented.
659 listener->notifyOperationInserted(
660 op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
661 /*insertPt=*/{}));
662 }
663 }
664
665 void rollback() override {
666 // Move the operation back to its original position.
667 Block::iterator before =
668 insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
669 block->getOperations().splice(before, op->getBlock()->getOperations(), op);
670 }
671
672private:
673 // The block in which this operation was previously contained.
674 Block *block;
675
676 // The original successor of this operation before it was moved. "nullptr"
677 // if this operation was the only operation in the region.
678 Operation *insertBeforeOp;
679};
680
681/// In-place modification of an op. This rewrite is immediately reflected in
682/// the IR. The previous state of the operation is stored in this object.
683class ModifyOperationRewrite : public OperationRewrite {
684public:
685 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
686 Operation *op)
687 : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
688 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
689 operands(op->operand_begin(), op->operand_end()),
690 successors(op->successor_begin(), op->successor_end()) {
691 if (PropertyRef prop = op->getPropertiesStorage()) {
692 // Make a copy of the properties.
693 propertiesStorage = operator new(op->getPropertiesStorageSize());
694 PropertyRef propCopy(name.getOpPropertiesTypeID(), propertiesStorage);
695 name.initOpProperties(propCopy, /*init=*/prop);
696 }
697 }
698
699 static bool classof(const IRRewrite *rewrite) {
700 return rewrite->getKind() == Kind::ModifyOperation;
701 }
702
703 ~ModifyOperationRewrite() override {
704 assert(!propertiesStorage &&
705 "rewrite was neither committed nor rolled back");
706 }
707
708 void commit(RewriterBase &rewriter) override {
709 // Notify the listener that the operation was modified in-place.
710 if (auto *listener =
711 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
712 listener->notifyOperationModified(op);
713
714 if (propertiesStorage) {
715 PropertyRef propCopy(name.getOpPropertiesTypeID(), propertiesStorage);
716 // Note: The operation may have been erased in the mean time, so
717 // OperationName must be stored in this object.
718 name.destroyOpProperties(propCopy);
719 operator delete(propertiesStorage);
720 propertiesStorage = nullptr;
721 }
722 }
723
724 void rollback() override {
725 op->setLoc(loc);
726 op->setAttrs(attrs);
727 op->setOperands(operands);
728 for (const auto &it : llvm::enumerate(successors))
729 op->setSuccessor(it.value(), it.index());
730 if (propertiesStorage) {
731 PropertyRef propCopy(name.getOpPropertiesTypeID(), propertiesStorage);
732 op->copyProperties(propCopy);
733 name.destroyOpProperties(propCopy);
734 operator delete(propertiesStorage);
735 propertiesStorage = nullptr;
736 }
737 }
738
739private:
740 OperationName name;
741 LocationAttr loc;
742 DictionaryAttr attrs;
743 SmallVector<Value, 8> operands;
744 SmallVector<Block *, 2> successors;
745 void *propertiesStorage = nullptr;
746};
747
748/// Replacing an operation. Erasing an operation is treated as a special case
749/// with "null" replacements. This rewrite is not immediately reflected in the
750/// IR. An internal IR mapping is updated, but values are not replaced and the
751/// original op is not erased until the rewrite is committed.
752class ReplaceOperationRewrite : public OperationRewrite {
753public:
754 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
755 Operation *op, const TypeConverter *converter)
756 : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
757 converter(converter) {}
758
759 static bool classof(const IRRewrite *rewrite) {
760 return rewrite->getKind() == Kind::ReplaceOperation;
761 }
762
763 void commit(RewriterBase &rewriter) override;
764
765 void rollback() override;
766
767 void cleanup(RewriterBase &rewriter) override;
768
769private:
770 /// An optional type converter that can be used to materialize conversions
771 /// between the new and old values if necessary.
772 const TypeConverter *converter;
773};
774
775class CreateOperationRewrite : public OperationRewrite {
776public:
777 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
778 Operation *op)
779 : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
780
781 static bool classof(const IRRewrite *rewrite) {
782 return rewrite->getKind() == Kind::CreateOperation;
783 }
784
785 void commit(RewriterBase &rewriter) override {
786 // The operation was already created and inserted. Just inform the listener.
787 if (auto *listener = rewriter.getListener())
788 listener->notifyOperationInserted(op, /*previous=*/{});
789 }
790
791 void rollback() override;
792};
793
/// Direction of a materialization.
enum MaterializationKind {
  /// Converts a value of an illegal type into a value of a legal type.
  Target = 0,

  /// Converts a value of a legal type back into a value of an illegal type.
  Source = 1
};
804
805/// Helper class that stores metadata about an unresolved materialization.
806class UnresolvedMaterializationInfo {
807public:
808 UnresolvedMaterializationInfo() = default;
809 UnresolvedMaterializationInfo(const TypeConverter *converter,
810 MaterializationKind kind, Type originalType)
811 : converterAndKind(converter, kind), originalType(originalType) {}
812
813 /// Return the type converter of this materialization (which may be null).
814 const TypeConverter *getConverter() const {
815 return converterAndKind.getPointer();
816 }
817
818 /// Return the kind of this materialization.
819 MaterializationKind getMaterializationKind() const {
820 return converterAndKind.getInt();
821 }
822
823 /// Return the original type of the SSA value.
824 Type getOriginalType() const { return originalType; }
825
826private:
827 /// The corresponding type converter to use when resolving this
828 /// materialization, and the kind of this materialization.
829 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
830 converterAndKind;
831
832 /// The original type of the SSA value. Only used for target
833 /// materializations.
834 Type originalType;
835};
836
837/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
838/// op. Unresolved materializations fold away or are replaced with
839/// source/target materializations at the end of the dialect conversion.
840class UnresolvedMaterializationRewrite : public OperationRewrite {
841public:
842 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
843 UnrealizedConversionCastOp op,
844 ValueVector mappedValues)
845 : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
846 mappedValues(std::move(mappedValues)) {}
847
848 static bool classof(const IRRewrite *rewrite) {
849 return rewrite->getKind() == Kind::UnresolvedMaterialization;
850 }
851
852 void rollback() override;
853
854 UnrealizedConversionCastOp getOperation() const {
855 return cast<UnrealizedConversionCastOp>(op);
856 }
857
858private:
859 /// The values in the conversion value mapping that are being replaced by the
860 /// results of this unresolved materialization.
861 ValueVector mappedValues;
862};
863} // namespace
864
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` that
/// operates on `op`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Operation *op) {
  for (auto &rewrite : rewrites)
    if (auto *typed = dyn_cast<RewriteTy>(rewrite.get()))
      if (typed->getOperation() == op)
        return true;
  return false;
}

/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` that
/// operates on `block`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Block *block) {
  for (auto &rewrite : rewrites)
    if (auto *typed = dyn_cast<RewriteTy>(rewrite.get()))
      if (typed->getBlock() == block)
        return true;
  return false;
}
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
886
887//===----------------------------------------------------------------------===//
888// ConversionPatternRewriterImpl
889//===----------------------------------------------------------------------===//
890namespace mlir {
891namespace detail {
893 explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
894 const ConversionConfig &config,
898
899 //===--------------------------------------------------------------------===//
900 // State Management
901 //===--------------------------------------------------------------------===//
902
903 /// Return the current state of the rewriter.
904 RewriterState getCurrentState();
905
906 /// Apply all requested operation rewrites. This method is invoked when the
907 /// conversion process succeeds.
908 void applyRewrites();
909
  /// Reset the state of the rewriter to a previously saved point. Optionally,
  /// the name of the pattern that triggered the rollback can be specified for
  /// debugging purposes.
913 void resetState(RewriterState state, StringRef patternName = "");
914
915 /// Append a rewrite. Rewrites are committed upon success and rolled back upon
916 /// failure.
917 template <typename RewriteTy, typename... Args>
918 void appendRewrite(Args &&...args) {
919 assert(config.allowPatternRollback && "appending rewrites is not allowed");
920 rewrites.push_back(
921 std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
922 }
923
  /// Undo the rewrites (motions, splits) one by one in reverse order until
  /// "numRewritesToKeep" rewrites remain. Optionally, the name of the pattern
  /// that triggered the rollback can be specified for debugging purposes.
927 void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
928
929 /// Remap the given values to those with potentially different types. Returns
930 /// success if the values could be remapped, failure otherwise. `valueDiagTag`
931 /// is the tag used when describing a value within a diagnostic, e.g.
932 /// "operand".
933 LogicalResult remapValues(StringRef valueDiagTag,
934 std::optional<Location> inputLoc, ValueRange values,
935 SmallVector<ValueVector> &remapped);
936
937 /// Return "true" if the given operation is ignored, and does not need to be
938 /// converted.
939 bool isOpIgnored(Operation *op) const;
940
941 /// Return "true" if the given operation was replaced or erased.
942 bool wasOpReplaced(Operation *op) const;
943
944 /// Lookup the most recently mapped values with the desired types in the
945 /// mapping, taking into account only replacements. Perform a best-effort
946 /// search for existing materializations with the desired types.
947 ///
948 /// If `skipPureTypeConversions` is "true", materializations that are pure
949 /// type conversions are not considered.
950 ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
951 bool skipPureTypeConversions = false) const;
952
953 /// Lookup the given value within the map, or return an empty vector if the
954 /// value is not mapped. If it is mapped, this follows the same behavior
955 /// as `lookupOrDefault`.
956 ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
957
958 //===--------------------------------------------------------------------===//
959 // IR Rewrites / Type Conversion
960 //===--------------------------------------------------------------------===//
961
962 /// Convert the types of block arguments within the given region.
963 FailureOr<Block *>
964 convertRegionTypes(Region *region, const TypeConverter &converter,
965 TypeConverter::SignatureConversion *entryConversion);
966
967 /// Apply the given signature conversion on the given block. The new block
968 /// containing the updated signature is returned. If no conversions were
969 /// necessary, e.g. if the block has no arguments, `block` is returned.
970 /// `converter` is used to generate any necessary cast operations that
971 /// translate between the origin argument types and those specified in the
972 /// signature conversion.
973 Block *applySignatureConversion(
974 Block *block, const TypeConverter *converter,
975 TypeConverter::SignatureConversion &signatureConversion);
976
977 /// Replace the results of the given operation with the given values and
978 /// erase the operation.
979 ///
980 /// There can be multiple replacement values for each result (1:N
981 /// replacement). If the replacement values are empty, the respective result
982 /// is dropped and a source materialization is built if the result still has
983 /// uses.
984 void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
985
986 /// Replace the uses of the given value with the given values. The specified
987 /// converter is used to build materializations (if necessary). If `functor`
988 /// is specified, only the uses that the functor returns "true" for are
989 /// replaced.
990 void replaceValueUses(Value from, ValueRange to,
991 const TypeConverter *converter,
992 function_ref<bool(OpOperand &)> functor = nullptr);
993
994 /// Erase the given block and its contents.
995 void eraseBlock(Block *block);
996
997 /// Inline the source block into the destination block before the given
998 /// iterator.
999 void inlineBlockBefore(Block *source, Block *dest, Block::iterator before);
1000
1001 //===--------------------------------------------------------------------===//
1002 // Materializations
1003 //===--------------------------------------------------------------------===//
1004
  /// Build an unresolved materialization operation given a range of output
  /// types and a list of input operands, and return the results of the newly
  /// created cast op. The input types must differ from the output types;
  /// otherwise no materialization is needed.
  ///
  /// If `valuesToMap` is non-empty and pattern rollback is enabled, those
  /// values are mapped to the results of the unresolved materialization in
  /// the conversion value mapping.
1015 ///
1016 /// If `isPureTypeConversion` is "true", the materialization is created only
1017 /// to resolve a type mismatch. That means it is not a regular value
1018 /// replacement issued by the user. (Replacement values that are created
1019 /// "out of thin air" appear like unresolved materializations because they are
1020 /// unrealized_conversion_cast ops. However, they must be treated like
1021 /// regular value replacements.)
1022 ValueRange buildUnresolvedMaterialization(
1023 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1024 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1025 Type originalType, const TypeConverter *converter,
1026 bool isPureTypeConversion = true);
1027
1028 /// Find a replacement value for the given SSA value in the conversion value
1029 /// mapping. The replacement value must have the same type as the given SSA
1030 /// value. If there is no replacement value with the correct type, find the
1031 /// latest replacement value (regardless of the type) and build a source
1032 /// materialization.
1033 Value findOrBuildReplacementValue(Value value,
1034 const TypeConverter *converter);
1035
1036 //===--------------------------------------------------------------------===//
1037 // Rewriter Notification Hooks
1038 //===--------------------------------------------------------------------===//
1039
  /// Notifies that an op was inserted.
1041 void notifyOperationInserted(Operation *op,
1042 OpBuilder::InsertPoint previous) override;
1043
1044 /// Notifies that a block was inserted.
1045 void notifyBlockInserted(Block *block, Region *previous,
1046 Region::iterator previousIt) override;
1047
1048 /// Notifies that a pattern match failed for the given reason.
1049 void
1050 notifyMatchFailure(Location loc,
1051 function_ref<void(Diagnostic &)> reasonCallback) override;
1052
1053 //===--------------------------------------------------------------------===//
1054 // IR Erasure
1055 //===--------------------------------------------------------------------===//
1056
1057 /// A rewriter that keeps track of erased ops and blocks. It ensures that no
1058 /// operation or block is erased multiple times. This rewriter assumes that
1059 /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
1061 public:
1064 std::function<void(Operation *)> opErasedCallback = nullptr)
1065 : RewriterBase(context, /*listener=*/this),
1066 opErasedCallback(std::move(opErasedCallback)) {}
1067
1068 /// Erase the given op (unless it was already erased).
1069 void eraseOp(Operation *op) override {
1070 if (wasErased(op))
1071 return;
1072 op->dropAllUses();
1074 }
1075
1076 /// Erase the given block (unless it was already erased).
1077 void eraseBlock(Block *block) override {
1078 if (wasErased(block))
1079 return;
1080 assert(block->empty() && "expected empty block");
1081 block->dropAllDefinedValueUses();
1083 }
1084
1085 bool wasErased(void *ptr) const { return erased.contains(ptr); }
1086
1088 erased.insert(op);
1089 if (opErasedCallback)
1090 opErasedCallback(op);
1091 }
1092
1093 void notifyBlockErased(Block *block) override { erased.insert(block); }
1094
1095 private:
1096 /// Pointers to all erased operations and blocks.
1097 DenseSet<void *> erased;
1098
1099 /// A callback that is invoked when an operation is erased.
1100 std::function<void(Operation *)> opErasedCallback;
1101 };
1102
1103 //===--------------------------------------------------------------------===//
1104 // State
1105 //===--------------------------------------------------------------------===//
1106
1107 /// The rewriter that is used to perform the conversion.
1108 ConversionPatternRewriter &rewriter;
1109
  /// Mapping between replaced values that differ in type. This happens when
  /// replacing a value with one of a different type.
1112 ConversionValueMapping mapping;
1113
  /// Ordered list of IR rewrites (e.g., block type conversions, inlined
  /// blocks, unresolved materializations). This vector is maintained only if
  /// `allowPatternRollback` is set to "true". Otherwise, all IR rewrites are
  /// materialized immediately and no bookkeeping is needed.
1119
1120 /// A set of operations that should no longer be considered for legalization.
1121 /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1122 /// tracked separately.
1124
  /// A set of operations that were replaced/erased. Such ops are not erased
  /// immediately but only when the dialect conversion succeeds. In the
  /// meantime, they should no longer be considered for legalization and any
  /// attempt to modify/access them is invalid rewriter API usage.
1130
1131 /// A set of operations that were created by the current pattern.
1133
1134 /// A set of operations that were modified by the current pattern.
1136
1137 /// A list of unresolved materializations that were created by the current
1138 /// pattern.
1140
1141 /// A mapping for looking up metadata of unresolved materializations.
1144
1145 /// The current type converter, or nullptr if no type converter is currently
1146 /// active.
1148
1149 /// A mapping of regions to type converters that should be used when
1150 /// converting the arguments of blocks within that region.
1152
1153 /// Dialect conversion configuration.
1154 const ConversionConfig &config;
1155
1156 /// The operation converter to use for recursive legalization.
1158
1159 /// A set of erased operations. This set is utilized only if
1160 /// `allowPatternRollback` is set to "false". Conceptually, this set is
1161 /// similar to `replacedOps` (which is maintained when the flag is set to
1162 /// "true"). However, erasing from a DenseSet is more efficient than erasing
1163 /// from a SetVector.
1165
1166 /// A set of erased blocks. This set is utilized only if
1167 /// `allowPatternRollback` is set to "false".
1169
1170 /// A rewriter that notifies the listener (if any) about all IR
1171 /// modifications. This rewriter is utilized only if `allowPatternRollback`
1172 /// is set to "false". If the flag is set to "true", the listener is notified
1173 /// with a separate mechanism (e.g., in `IRRewrite::commit`).
1175
1176#ifndef NDEBUG
1177 /// A set of replaced values. This set is for debugging purposes only and it
1178 /// is maintained only if `allowPatternRollback` is set to "true".
1180
1181 /// A set of operations that have pending updates. This tracking isn't
1182 /// strictly necessary, and is thus only active during debug builds for extra
1183 /// verification.
1185
1186 /// A raw output stream used to prefix the debug log.
1187 llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
1188 llvm::dbgs()};
1189
1190 /// A logger used to emit diagnostics during the conversion process.
1191 llvm::ScopedPrinter logger{os};
1192 std::string logPrefix;
1193#endif
1194};
1195} // namespace detail
1196} // namespace mlir
1197
1198const ConversionConfig &IRRewrite::getConfig() const {
1199 return rewriterImpl.config;
1200}
1201
1202void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1203 // Inform the listener about all IR modifications that have already taken
1204 // place: References to the original block have been replaced with the new
1205 // block.
1206 if (auto *listener =
1207 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1208 for (Operation *op : getNewBlock()->getUsers())
1209 listener->notifyOperationModified(op);
1210}
1211
1212void BlockTypeConversionRewrite::rollback() {
1213 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1214}
1215
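/// Replace the uses of `from` with `repl`. If `functor` is provided, only the
/// uses for which it returns "true" are replaced. If `repl` is defined by an
/// operation, uses in that operation's block that do not come after it are
/// left untouched (see below).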
1217static void
1219 function_ref<bool(OpOperand &)> functor = nullptr) {
1220 if (isa<BlockArgument>(repl)) {
1221 // `repl` is a block argument. Directly replace all uses.
1222 if (functor) {
1223 rewriter.replaceUsesWithIf(from, repl, functor);
1224 } else {
1225 rewriter.replaceAllUsesWith(from, repl);
1226 }
1227 return;
1228 }
1229
1230 // If the replacement value is an operation, only replace those uses that:
1231 // - are in a different block than the replacement operation, or
1232 // - are in the same block but after the replacement operation.
1233 //
1234 // Example:
1235 // ^bb0(%arg0: i32):
1236 // %0 = "consumer"(%arg0) : (i32) -> (i32)
1237 // "another_consumer"(%arg0) : (i32) -> ()
1238 //
1239 // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
1240 // use in "another_consumer" but not the use in "consumer". When using the
1241 // normal RewriterBase API, this would typically be done with
1242 // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
1243 // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
1244 // it cannot be supported efficiently with `allowPatternRollback` set to
1245 // "true". Therefore, the conversion driver is trying to be smart and replaces
1246 // only those uses that do not lead to a dominance violation. E.g., the
1247 // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
1248 // behavior.
1249 //
1250 // TODO: As we move more and more towards `allowPatternRollback` set to
1251 // "false", we should remove this special handling, in order to align the
1252 // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
1253 Operation *replOp = repl.getDefiningOp();
1254 Block *replBlock = replOp->getBlock();
1255 rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
1256 Operation *user = operand.getOwner();
1257 bool result =
1258 user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1259 if (result && functor)
1260 result &= functor(operand);
1261 return result;
1262 });
1263}
1264
1265void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1266 Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
1267 if (!repl)
1268 return;
1269 performReplaceValue(rewriter, value, repl);
1270}
1271
1272void ReplaceValueRewrite::rollback() {
1273 rewriterImpl.mapping.erase({value});
1274#ifndef NDEBUG
1275 rewriterImpl.replacedValues.erase(value);
1276#endif // NDEBUG
1277}
1278
1279void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1280 auto *listener =
1281 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1282
1283 // Compute replacement values.
1284 SmallVector<Value> replacements =
1285 llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1286 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1287 });
1288
1289 // Notify the listener that the operation is about to be replaced.
1290 if (listener)
1291 listener->notifyOperationReplaced(op, replacements);
1292
1293 // Replace all uses with the new values.
1294 for (auto [result, newValue] :
1295 llvm::zip_equal(op->getResults(), replacements))
1296 if (newValue)
1297 rewriter.replaceAllUsesWith(result, newValue);
1298
1299 // The original op will be erased, so remove it from the set of unlegalized
1300 // ops.
1301 if (getConfig().unlegalizedOps)
1302 getConfig().unlegalizedOps->erase(op);
1303
1304 // Notify the listener that the operation and its contents are being erased.
1305 if (listener)
1306 notifyIRErased(listener, *op);
1307
1308 // Do not erase the operation yet. It may still be referenced in `mapping`.
1309 // Just unlink it for now and erase it during cleanup.
1310 if (!op->getBlock())
1311 llvm::reportFatalInternalError(
1312 "dialect conversion attempted to replace a root operation that has no "
1313 "parent block; the pass must ensure its target op is nested in a "
1314 "block");
1315 op->getBlock()->getOperations().remove(op);
1316}
1317
1318void ReplaceOperationRewrite::rollback() {
1319 for (auto result : op->getResults())
1320 rewriterImpl.mapping.erase({result});
1321}
1322
1323void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1324 rewriter.eraseOp(op);
1325}
1326
1327void CreateOperationRewrite::rollback() {
1328 for (Region &region : op->getRegions()) {
1329 while (!region.getBlocks().empty())
1330 region.getBlocks().remove(region.getBlocks().begin());
1331 }
1332 op->dropAllUses();
1333 op->erase();
1334}
1335
1336void UnresolvedMaterializationRewrite::rollback() {
1337 if (!mappedValues.empty())
1338 rewriterImpl.mapping.erase(mappedValues);
1339 rewriterImpl.unresolvedMaterializations.erase(getOperation());
1340 op->erase();
1341}
1342
1344 // Commit all rewrites. Use a new rewriter, so the modifications are not
1345 // tracked for rollback purposes etc.
1346 IRRewriter irRewriter(rewriter.getContext(), config.listener);
1347 // Note: New rewrites may be added during the "commit" phase and the
1348 // `rewrites` vector may reallocate.
1349 for (size_t i = 0; i < rewrites.size(); ++i)
1350 rewrites[i]->commit(irRewriter);
1351
1352 // Clean up all rewrites.
1353 SingleEraseRewriter eraseRewriter(
1354 rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
1355 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1356 unresolvedMaterializations.erase(castOp);
1357 });
1358 for (auto &rewrite : rewrites)
1359 rewrite->cleanup(eraseRewriter);
1360}
1361
1362//===----------------------------------------------------------------------===//
1363// State Management
1364//===----------------------------------------------------------------------===//
1365
1367 Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
1368 // Helper function that looks up a single value.
1369 auto lookup = [&](const ValueVector &values) -> ValueVector {
1370 assert(!values.empty() && "expected non-empty value vector");
1371
1372 // If the pattern rollback is enabled, use the mapping to look up the
1373 // values.
1374 if (config.allowPatternRollback)
1375 return mapping.lookup(values);
1376
1377 // Otherwise, look up values by examining the IR. All replacements have
1378 // already been materialized in IR.
1379 Operation *op = getCommonDefiningOp(values);
1380 if (!op)
1381 return {};
1382 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1383 if (!castOp)
1384 return {};
1385 if (!this->unresolvedMaterializations.contains(castOp))
1386 return {};
1387 if (castOp.getOutputs() != values)
1388 return {};
1389 return castOp.getInputs();
1390 };
1391
1392 // Helper function that looks up each value in `values` individually and then
1393 // composes the results. If that fails, it tries to look up the entire vector
1394 // at once.
1395 auto composedLookup = [&](const ValueVector &values) -> ValueVector {
1396 // If possible, replace each value with (one or multiple) mapped values.
1397 ValueVector next;
1398 for (Value v : values) {
1399 ValueVector r = lookup({v});
1400 if (!r.empty()) {
1401 llvm::append_range(next, r);
1402 } else {
1403 next.push_back(v);
1404 }
1405 }
1406 if (next != values) {
1407 // At least one value was replaced.
1408 return next;
1409 }
1410
1411 // Otherwise: Check if there is a mapping for the entire vector. Such
1412 // mappings are materializations. (N:M mapping are not supported for value
1413 // replacements.)
1414 //
1415 // Note: From a correctness point of view, materializations do not have to
1416 // be stored (and looked up) in the mapping. But for performance reasons,
1417 // we choose to reuse existing IR (when possible) instead of creating it
1418 // multiple times.
1419 ValueVector r = lookup(values);
1420 if (r.empty()) {
1421 // No mapping found: The lookup stops here.
1422 return {};
1423 }
1424 return r;
1425 };
1426
1427 // Try to find the deepest values that have the desired types. If there is no
1428 // such mapping, simply return the deepest values.
1429 ValueVector desiredValue;
1430 ValueVector current{from};
1431 ValueVector lastNonMaterialization{from};
1432 do {
1433 // Store the current value if the types match.
1434 bool match = TypeRange(ValueRange(current)) == desiredTypes;
1435 if (skipPureTypeConversions) {
1436 // Skip pure type conversions, if requested.
1437 bool pureConversion = isPureTypeConversion(current);
1438 match &= !pureConversion;
1439 // Keep track of the last mapped value that was not a pure type
1440 // conversion.
1441 if (!pureConversion)
1442 lastNonMaterialization = current;
1443 }
1444 if (match)
1445 desiredValue = current;
1446
1447 // Lookup next value in the mapping.
1448 ValueVector next = composedLookup(current);
1449 if (next.empty())
1450 break;
1451 current = std::move(next);
1452 } while (true);
1453
1454 // If the desired values were found use them, otherwise default to the leaf
1455 // values. (Skip pure type conversions, if requested.)
1456 if (!desiredTypes.empty())
1457 return desiredValue;
1458 if (skipPureTypeConversions)
1459 return lastNonMaterialization;
1460 return current;
1461}
1462
1465 TypeRange desiredTypes) const {
1466 ValueVector result = lookupOrDefault(from, desiredTypes);
1467 if (result == ValueVector{from} ||
1468 (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
1469 return {};
1470 return result;
1471}
1472
1474 return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1475}
1476
1478 StringRef patternName) {
1479 // Undo any rewrites.
1480 undoRewrites(state.numRewrites, patternName);
1481
1482 // Pop all of the recorded ignored operations that are no longer valid.
1483 while (ignoredOps.size() != state.numIgnoredOperations)
1484 ignoredOps.pop_back();
1485
1486 while (replacedOps.size() != state.numReplacedOps)
1487 replacedOps.pop_back();
1488}
1489
1490void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1491 StringRef patternName) {
1492 for (auto &rewrite :
1493 llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1494 rewrite->rollback();
1495 rewrites.resize(numRewritesToKeep);
1496}
1497
1499 StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
1500 SmallVector<ValueVector> &remapped) {
1501 remapped.reserve(llvm::size(values));
1502
1503 for (const auto &it : llvm::enumerate(values)) {
1504 Value operand = it.value();
1505 Type origType = operand.getType();
1506 Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1507
1508 if (!currentTypeConverter) {
1509 // The current pattern does not have a type converter. Pass the most
1510 // recently mapped values, excluding materializations. Materializations
1511 // are intentionally excluded because their presence may depend on other
1512 // patterns. Including materializations would make the lookup fragile
1513 // and unpredictable.
1514 remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
1515 /*skipPureTypeConversions=*/true));
1516 continue;
1517 }
1518
1519 // If there is no legal conversion, fail to match this pattern.
1520 SmallVector<Type, 1> legalTypes;
1521 if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
1522 notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1523 diag << "unable to convert type for " << valueDiagTag << " #"
1524 << it.index() << ", type was " << origType;
1525 });
1526 return failure();
1527 }
1528 // If a type is converted to 0 types, there is nothing to do.
1529 if (legalTypes.empty()) {
1530 remapped.push_back({});
1531 continue;
1532 }
1533
1534 ValueVector repl = lookupOrDefault(operand, legalTypes);
1535 if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1536 // Mapped values have the correct type or there is an existing
1537 // materialization. Or the operand is not mapped at all and has the
1538 // correct type.
1539 remapped.push_back(std::move(repl));
1540 continue;
1541 }
1542
1543 // Create a materialization for the most recently mapped values.
1544 repl = lookupOrDefault(operand, /*desiredTypes=*/{},
1545 /*skipPureTypeConversions=*/true);
1547 MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1548 /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1549 /*originalType=*/origType, currentTypeConverter);
1550 remapped.push_back(castValues);
1551 }
1552 return success();
1553}
1554
1556 // Check to see if this operation is ignored or was replaced.
1557 return wasOpReplaced(op) || ignoredOps.count(op);
1558}
1559
1561 // Check to see if this operation was replaced.
1562 return replacedOps.count(op) || erasedOps.count(op);
1563}
1564
1565//===----------------------------------------------------------------------===//
1566// Type Conversion
1567//===----------------------------------------------------------------------===//
1568
1570 Region *region, const TypeConverter &converter,
1571 TypeConverter::SignatureConversion *entryConversion) {
1572 regionToConverter[region] = &converter;
1573 if (region->empty())
1574 return nullptr;
1575
1576 // Convert the arguments of each non-entry block within the region.
1577 for (Block &block :
1578 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1579 // Compute the signature for the block with the provided converter.
1580 std::optional<TypeConverter::SignatureConversion> conversion =
1581 converter.convertBlockSignature(&block);
1582 if (!conversion)
1583 return failure();
1584 // Convert the block with the computed signature.
1585 applySignatureConversion(&block, &converter, *conversion);
1586 }
1587
1588 // Convert the entry block. If an entry signature conversion was provided,
1589 // use that one. Otherwise, compute the signature with the type converter.
1590 if (entryConversion)
1591 return applySignatureConversion(&region->front(), &converter,
1592 *entryConversion);
1593 std::optional<TypeConverter::SignatureConversion> conversion =
1594 converter.convertBlockSignature(&region->front());
1595 if (!conversion)
1596 return failure();
1597 return applySignatureConversion(&region->front(), &converter, *conversion);
1598}
1599
1601 Block *block, const TypeConverter *converter,
1602 TypeConverter::SignatureConversion &signatureConversion) {
1603#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1604 // A block cannot be converted multiple times.
1605 if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1606 llvm::reportFatalInternalError("block was already converted");
1607#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1608
1610
1611 // If no arguments are being changed or added, there is nothing to do.
1612 unsigned origArgCount = block->getNumArguments();
1613 auto convertedTypes = signatureConversion.getConvertedTypes();
1614 if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1615 return block;
1616
1617 // Compute the locations of all block arguments in the new block.
1618 SmallVector<Location> newLocs(convertedTypes.size(),
1619 rewriter.getUnknownLoc());
1620 for (unsigned i = 0; i < origArgCount; ++i) {
1621 auto inputMap = signatureConversion.getInputMapping(i);
1622 if (!inputMap || inputMap->replacedWithValues())
1623 continue;
1624 Location origLoc = block->getArgument(i).getLoc();
1625 for (unsigned j = 0; j < inputMap->size; ++j)
1626 newLocs[inputMap->inputNo + j] = origLoc;
1627 }
1628
1629 // Insert a new block with the converted block argument types and move all ops
1630 // from the old block to the new block.
1631 Block *newBlock =
1632 rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1633 convertedTypes, newLocs);
1634
1635 // If a listener is attached to the dialect conversion, ops cannot be moved
1636 // to the destination block in bulk ("fast path"). This is because at the time
1637 // the notifications are sent, it is unknown which ops were moved. Instead,
1638 // ops should be moved one-by-one ("slow path"), so that a separate
1639 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1640 // a bit more efficient, so we try to do that when possible.
1641 bool fastPath = !config.listener;
1642 if (fastPath) {
1643 if (config.allowPatternRollback)
1644 appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1645 newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1646 } else {
1647 while (!block->empty())
1648 rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1649 }
1650
1651 // Replace all uses of the old block with the new block.
1652 block->replaceAllUsesWith(newBlock);
1653
1654 for (unsigned i = 0; i != origArgCount; ++i) {
1655 BlockArgument origArg = block->getArgument(i);
1656 Type origArgType = origArg.getType();
1657
1658 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1659 signatureConversion.getInputMapping(i);
1660 if (!inputMap) {
1661 // This block argument was dropped and no replacement value was provided.
1662 // Materialize a replacement value "out of thin air".
1663 // Note: Materialization must be built here because we cannot find a
1664 // valid insertion point in the new block. (Will point to the old block.)
1665 Value mat =
1667 MaterializationKind::Source,
1668 OpBuilder::InsertPoint(newBlock, newBlock->begin()),
1669 origArg.getLoc(),
1670 /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1671 /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
1672 /*isPureTypeConversion=*/false)
1673 .front();
1674 replaceValueUses(origArg, mat, converter);
1675 continue;
1676 }
1677
1678 if (inputMap->replacedWithValues()) {
1679 // This block argument was dropped and replacement values were provided.
1680 assert(inputMap->size == 0 &&
1681 "invalid to provide a replacement value when the argument isn't "
1682 "dropped");
1683 replaceValueUses(origArg, inputMap->replacementValues, converter);
1684 continue;
1685 }
1686
1687 // This is a 1->1+ mapping.
1688 auto replArgs =
1689 newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1690 replaceValueUses(origArg, replArgs, converter);
1691 }
1692
1693 if (config.allowPatternRollback)
1694 appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1695
1696 // Erase the old block. (It is just unlinked for now and will be erased during
1697 // cleanup.)
1698 rewriter.eraseBlock(block);
1699
1700 return newBlock;
1701}
1702
1703//===----------------------------------------------------------------------===//
1704// Materializations
1705//===----------------------------------------------------------------------===//
1706
1707/// Build an unresolved materialization operation given an output type and set
1708/// of input operands.
1710 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1711 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1712 Type originalType, const TypeConverter *converter,
1713 bool isPureTypeConversion) {
1714 assert((!originalType || kind == MaterializationKind::Target) &&
1715 "original type is valid only for target materializations");
1716 assert(TypeRange(inputs) != outputTypes &&
1717 "materialization is not necessary");
1718
1719 // Create an unresolved materialization. We use a new OpBuilder to avoid
1720 // tracking the materialization like we do for other operations.
1721 OpBuilder builder(outputTypes.front().getContext());
1722 builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1723 UnrealizedConversionCastOp convertOp =
1724 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1725 if (config.attachDebugMaterializationKind) {
1726 StringRef kindStr =
1727 kind == MaterializationKind::Source ? "source" : "target";
1728 convertOp->setAttr("__kind__", builder.getStringAttr(kindStr));
1729 }
1731 convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
1732
1733 // Register the materialization.
1734 unresolvedMaterializations[convertOp] =
1735 UnresolvedMaterializationInfo(converter, kind, originalType);
1736 if (config.allowPatternRollback) {
1737 if (!valuesToMap.empty())
1738 mapping.map(valuesToMap, convertOp.getResults());
1740 std::move(valuesToMap));
1741 } else {
1742 patternMaterializations.insert(convertOp);
1743 }
1744 return convertOp.getResults();
1745}
1746
1748 Value value, const TypeConverter *converter) {
1749 assert(config.allowPatternRollback &&
1750 "this code path is valid only in rollback mode");
1751
1752 // Try to find a replacement value with the same type in the conversion value
1753 // mapping. This includes cached materializations. We try to reuse those
1754 // instead of generating duplicate IR.
1755 ValueVector repl = lookupOrNull(value, value.getType());
1756 if (!repl.empty())
1757 return repl.front();
1758
1759 // Check if the value is dead. No replacement value is needed in that case.
1760 // This is an approximate check that may have false negatives but does not
1761 // require computing and traversing an inverse mapping. (We may end up
1762 // building source materializations that are never used and that fold away.)
1763 if (llvm::all_of(value.getUsers(),
1764 [&](Operation *op) { return replacedOps.contains(op); }) &&
1765 !mapping.isMappedTo(value))
1766 return Value();
1767
1768 // No replacement value was found. Get the latest replacement value
1769 // (regardless of the type) and build a source materialization to the
1770 // original type.
1771 repl = lookupOrNull(value);
1772
1773 // Compute the insertion point of the materialization.
1775 if (repl.empty()) {
1776 // The source materialization has no inputs. Insert it right before the
1777 // value that it is replacing.
1778 ip = computeInsertPoint(value);
1779 } else {
1780 // Compute the "earliest" insertion point at which all values in `repl` are
1781 // defined. It is important to emit the materialization at that location
1782 // because the same materialization may be reused in a different context.
1783 // (That's because materializations are cached in the conversion value
1784 // mapping.) The insertion point of the materialization must be valid for
1785 // all future users that may be created later in the conversion process.
1786 ip = computeInsertPoint(repl);
1787 }
1789 MaterializationKind::Source, ip, value.getLoc(),
1790 /*valuesToMap=*/repl, /*inputs=*/repl,
1791 /*outputTypes=*/value.getType(),
1792 /*originalType=*/Type(), converter,
1793 /*isPureTypeConversion=*/!repl.empty())
1794 .front();
1795 return castValue;
1796}
1797
1798//===----------------------------------------------------------------------===//
1799// Rewriter Notification Hooks
1800//===----------------------------------------------------------------------===//
1801
1803 Operation *op, OpBuilder::InsertPoint previous) {
1804 // If no previous insertion point is provided, the op used to be detached.
1805 bool wasDetached = !previous.isSet();
1806 LLVM_DEBUG({
1807 logger.startLine() << "** Insert : '" << op->getName() << "' (" << op
1808 << ")";
1809 if (wasDetached)
1810 logger.getOStream() << " (was detached)";
1811 logger.getOStream() << "\n";
1812 });
1813
1814 // In rollback mode, it is easier to misuse the API, so perform extra error
1815 // checking.
1816 assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
1817 "attempting to insert into a block within a replaced/erased op");
1818
1819 // In "no rollback" mode, the listener is always notified immediately.
1820 if (!config.allowPatternRollback && config.listener)
1821 config.listener->notifyOperationInserted(op, previous);
1822
1823 if (wasDetached) {
1824 // If the op was detached, it is most likely a newly created op. Add it the
1825 // set of newly created ops, so that it will be legalized. If this op is
1826 // not a newly created op, it will be legalized a second time, which is
1827 // inefficient but harmless.
1828 patternNewOps.insert(op);
1829
1830 if (config.allowPatternRollback) {
1831 // TODO: If the same op is inserted multiple times from a detached
1832 // state, the rollback mechanism may erase the same op multiple times.
1833 // This is a bug in the rollback-based dialect conversion driver.
1835 } else {
1836 // In "no rollback" mode, there is an extra data structure for tracking
1837 // erased operations that must be kept up to date.
1838 erasedOps.erase(op);
1839 }
1840 return;
1841 }
1842
1843 // The op was moved from one place to another.
1844 if (config.allowPatternRollback)
1846}
1847
1848/// Given that `fromRange` is about to be replaced with `toRange`, compute
1849/// replacement values with the types of `fromRange`.
1850static SmallVector<Value>
1852 const SmallVector<SmallVector<Value>> &toRange,
1853 const TypeConverter *converter) {
1854 assert(!impl.config.allowPatternRollback &&
1855 "this code path is valid only in 'no rollback' mode");
1856 SmallVector<Value> repls;
1857 for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1858 if (from.use_empty()) {
1859 // The replaced value is dead. No replacement value is needed.
1860 repls.push_back(Value());
1861 continue;
1862 }
1863
1864 if (to.empty()) {
1865 // The replaced value is dropped. Materialize a replacement value "out of
1866 // thin air".
1867 Value srcMat = impl.buildUnresolvedMaterialization(
1868 MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
1869 /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1870 /*outputTypes=*/from.getType(), /*originalType=*/Type(),
1871 converter)[0];
1872 repls.push_back(srcMat);
1873 continue;
1874 }
1875
1876 if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
1877 // The replacement value already has the correct type. Use it directly.
1878 repls.push_back(to[0]);
1879 continue;
1880 }
1881
1882 // The replacement value has the wrong type. Build a source materialization
1883 // to the original type.
1884 // TODO: This is a bit inefficient. We should try to reuse existing
1885 // materializations if possible. This would require an extension of the
1886 // `lookupOrDefault` API.
1887 Value srcMat = impl.buildUnresolvedMaterialization(
1888 MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
1889 /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
1890 /*originalType=*/Type(), converter)[0];
1891 repls.push_back(srcMat);
1892 }
1893
1894 return repls;
1895}
1896
1898 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1899 assert(newValues.size() == op->getNumResults() &&
1900 "incorrect number of replacement values");
1901 LLVM_DEBUG({
1902 logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
1903 << ")\n";
1905 // If the user-provided replacement types are different from the
1906 // legalized types, as per the current type converter, print a note.
1907 // In most cases, the replacement types are expected to match the types
1908 // produced by the type converter, so this could indicate a bug in the
1909 // user code.
1910 for (auto [result, repls] :
1911 llvm::zip_equal(op->getResults(), newValues)) {
1912 Type resultType = result.getType();
1913 auto logProlog = [&, repls = repls]() {
1914 logger.startLine() << " Note: Replacing op result of type "
1915 << resultType << " with value(s) of type (";
1916 llvm::interleaveComma(repls, logger.getOStream(), [&](Value v) {
1917 logger.getOStream() << v.getType();
1918 });
1919 logger.getOStream() << ")";
1920 };
1921 SmallVector<Type> convertedTypes;
1922 if (failed(currentTypeConverter->convertTypes(resultType,
1923 convertedTypes))) {
1924 logProlog();
1925 logger.getOStream() << ", but the type converter failed to legalize "
1926 "the original type.\n";
1927 continue;
1928 }
1929 if (TypeRange(convertedTypes) != TypeRange(ValueRange(repls))) {
1930 logProlog();
1931 logger.getOStream() << ", but the legalized type(s) is/are (";
1932 llvm::interleaveComma(convertedTypes, logger.getOStream(),
1933 [&](Type t) { logger.getOStream() << t; });
1934 logger.getOStream() << ")\n";
1935 }
1936 }
1937 }
1938 });
1939
1940 if (!config.allowPatternRollback) {
1941 // Pattern rollback is not allowed: materialize all IR changes immediately.
1943 *this, op->getResults(), newValues, currentTypeConverter);
1944 // Update internal data structures, so that there are no dangling pointers
1945 // to erased IR.
1946 op->walk([&](Operation *op) {
1947 erasedOps.insert(op);
1948 ignoredOps.remove(op);
1949 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1950 unresolvedMaterializations.erase(castOp);
1951 patternMaterializations.erase(castOp);
1952 }
1953 // The original op will be erased, so remove it from the set of
1954 // unlegalized ops.
1955 if (config.unlegalizedOps)
1956 config.unlegalizedOps->erase(op);
1957 });
1958 op->walk([&](Block *block) { erasedBlocks.insert(block); });
1959 // Replace the op with the replacement values and notify the listener.
1960 notifyingRewriter.replaceOp(op, repls);
1961 return;
1962 }
1963
1964 assert(!ignoredOps.contains(op) && "operation was already replaced");
1965#ifndef NDEBUG
1966 for (Value v : op->getResults())
1967 assert(!replacedValues.contains(v) &&
1968 "attempting to replace a value that was already replaced");
1969#endif // NDEBUG
1970
1971 // Check if replaced op is an unresolved materialization, i.e., an
1972 // unrealized_conversion_cast op that was created by the conversion driver.
1973 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1974 // Make sure that the user does not mess with unresolved materializations
1975 // that were inserted by the conversion driver. We keep track of these
1976 // ops in internal data structures.
1977 assert(!unresolvedMaterializations.contains(castOp) &&
1978 "attempting to replace/erase an unresolved materialization");
1979 }
1980
1981 // Create mappings for each of the new result values.
1982 for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults()))
1983 mapping.map(static_cast<Value>(result), std::move(repl));
1984
1986 // Mark this operation and all nested ops as replaced.
1987 op->walk([&](Operation *op) { replacedOps.insert(op); });
1988}
1989
1991 Value from, ValueRange to, const TypeConverter *converter,
1992 function_ref<bool(OpOperand &)> functor) {
1993 LLVM_DEBUG({
1994 logger.startLine() << "** Replace Value : '" << from << "'";
1995 if (auto blockArg = dyn_cast<BlockArgument>(from)) {
1996 if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
1997 logger.getOStream() << " (in region of '" << parentOp->getName()
1998 << "' (" << parentOp << ")";
1999 } else {
2000 logger.getOStream() << " (unlinked block)";
2001 }
2002 }
2003 if (functor) {
2004 logger.getOStream() << ", conditional replacement";
2005 }
2006 });
2007
2008 if (!config.allowPatternRollback) {
2009 SmallVector<Value> toConv = llvm::to_vector(to);
2010 SmallVector<Value> repls =
2011 getReplacementValues(*this, from, {toConv}, converter);
2012 IRRewriter r(from.getContext());
2013 Value repl = repls.front();
2014 if (!repl)
2015 return;
2016
2017 performReplaceValue(r, from, repl, functor);
2018 return;
2019 }
2020
2021#ifndef NDEBUG
2022 // Make sure that a value is not replaced multiple times. In rollback mode,
2023 // `replaceAllUsesWith` replaces not only all current uses of the given value,
2024 // but also all future uses that may be introduced by future pattern
2025 // applications. Therefore, it does not make sense to call
2026 // `replaceAllUsesWith` multiple times with the same value. Doing so would
2027 // overwrite the mapping and mess with the internal state of the dialect
2028 // conversion driver.
2029 assert(!replacedValues.contains(from) &&
2030 "attempting to replace a value that was already replaced");
2031 assert(!wasOpReplaced(from.getDefiningOp()) &&
2032 "attempting to replace a op result that was already replaced");
2033 replacedValues.insert(from);
2034#endif // NDEBUG
2035
2036 if (functor)
2037 llvm::reportFatalInternalError(
2038 "conditional value replacement is not supported in rollback mode");
2039 mapping.map(from, to);
2040 appendRewrite<ReplaceValueRewrite>(from, converter);
2041}
2042
2044 if (!config.allowPatternRollback) {
2045 // Pattern rollback is not allowed: materialize all IR changes immediately.
2046 // Update internal data structures, so that there are no dangling pointers
2047 // to erased IR.
2048 block->walk([&](Operation *op) {
2049 erasedOps.insert(op);
2050 ignoredOps.remove(op);
2051 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
2052 unresolvedMaterializations.erase(castOp);
2053 patternMaterializations.erase(castOp);
2054 }
2055 // The original op will be erased, so remove it from the set of
2056 // unlegalized ops.
2057 if (config.unlegalizedOps)
2058 config.unlegalizedOps->erase(op);
2059 });
2060 block->walk([&](Block *block) { erasedBlocks.insert(block); });
2061 // Erase the block and notify the listener.
2062 notifyingRewriter.eraseBlock(block);
2063 return;
2064 }
2065
2066 assert(!wasOpReplaced(block->getParentOp()) &&
2067 "attempting to erase a block within a replaced/erased op");
2069
2070 // Unlink the block from its parent region. The block is kept in the rewrite
2071 // object and will be actually destroyed when rewrites are applied. This
2072 // allows us to keep the operations in the block live and undo the removal by
2073 // re-inserting the block.
2074 block->getParent()->getBlocks().remove(block);
2075
2076 // Mark all nested ops as erased.
2077 block->walk([&](Operation *op) { replacedOps.insert(op); });
2078}
2079
2081 Block *block, Region *previous, Region::iterator previousIt) {
2082 // If no previous insertion point is provided, the block used to be detached.
2083 bool wasDetached = !previous;
2084 Operation *newParentOp = block->getParentOp();
2085 LLVM_DEBUG(
2086 {
2087 Operation *parent = newParentOp;
2088 if (parent) {
2089 logger.startLine() << "** Insert Block into : '" << parent->getName()
2090 << "' (" << parent << ")";
2091 } else {
2092 logger.startLine()
2093 << "** Insert Block into detached Region (nullptr parent op)";
2094 }
2095 if (wasDetached)
2096 logger.getOStream() << " (was detached)";
2097 logger.getOStream() << "\n";
2098 });
2099
2100 // In rollback mode, it is easier to misuse the API, so perform extra error
2101 // checking.
2102 assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) &&
2103 "attempting to insert into a region within a replaced/erased op");
2104 (void)newParentOp;
2105
2106 // In "no rollback" mode, the listener is always notified immediately.
2107 if (!config.allowPatternRollback && config.listener)
2108 config.listener->notifyBlockInserted(block, previous, previousIt);
2109
2110 if (wasDetached) {
2111 // If the block was detached, it is most likely a newly created block.
2112 if (config.allowPatternRollback) {
2113 // TODO: If the same block is inserted multiple times from a detached
2114 // state, the rollback mechanism may erase the same block multiple times.
2115 // This is a bug in the rollback-based dialect conversion driver.
2117 } else {
2118 // In "no rollback" mode, there is an extra data structure for tracking
2119 // erased blocks that must be kept up to date.
2120 erasedBlocks.erase(block);
2121 }
2122 return;
2123 }
2124
2125 // The block was moved from one place to another.
2126 if (config.allowPatternRollback)
2127 appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
2128}
2129
2131 Block *dest,
2132 Block::iterator before) {
2133 appendRewrite<InlineBlockRewrite>(dest, source, before);
2134}
2135
2137 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
2138 LLVM_DEBUG({
2140 reasonCallback(diag);
2141 logger.startLine() << "** Failure : " << diag.str() << "\n";
2142 if (config.notifyCallback)
2143 config.notifyCallback(diag);
2144 });
2145}
2146
2147//===----------------------------------------------------------------------===//
2148// ConversionPatternRewriter
2149//===----------------------------------------------------------------------===//
2150
2151ConversionPatternRewriter::ConversionPatternRewriter(
2152 MLIRContext *ctx, const ConversionConfig &config,
2153 OperationConverter &opConverter)
2155 *this, config, opConverter)) {
2156 setListener(impl.get());
2157}
2158
// Out-of-line defaulted destructor. NOTE(review): presumably defined here (not
// in the header) so that `impl` is destroyed where
// ConversionPatternRewriterImpl is a complete type — confirm against the header.
2159ConversionPatternRewriter::~ConversionPatternRewriter() = default;
2160
/// Returns the conversion configuration stored in the rewriter implementation.
2161const ConversionConfig &ConversionPatternRewriter::getConfig() const {
2162 return impl->config;
2163}
2164
2165void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
2166 assert(op && newOp && "expected non-null op");
2167 replaceOp(op, newOp->getResults());
2168}
2169
2170void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
2171 assert(op->getNumResults() == newValues.size() &&
2172 "incorrect # of replacement values");
2173
2174 // If the current insertion point is before the erased operation, we adjust
2175 // the insertion point to be after the operation.
2176 if (getInsertionPoint() == op->getIterator())
2178
2179 SmallVector<SmallVector<Value>> newVals =
2180 llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2181 return v ? SmallVector<Value>{v} : SmallVector<Value>();
2182 });
2183 impl->replaceOp(op, std::move(newVals));
2184}
2185
2186void ConversionPatternRewriter::replaceOpWithMultiple(
2187 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2188 assert(op->getNumResults() == newValues.size() &&
2189 "incorrect # of replacement values");
2190
2191 // If the current insertion point is before the erased operation, we adjust
2192 // the insertion point to be after the operation.
2193 if (getInsertionPoint() == op->getIterator())
2195
2196 impl->replaceOp(op, std::move(newValues));
2197}
2198
2199void ConversionPatternRewriter::eraseOp(Operation *op) {
2200 LLVM_DEBUG({
2201 impl->logger.startLine()
2202 << "** Erase : '" << op->getName() << "'(" << op << ")\n";
2203 });
2204
2205 // If the current insertion point is before the erased operation, we adjust
2206 // the insertion point to be after the operation.
2207 if (getInsertionPoint() == op->getIterator())
2209
2210 SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
2211 impl->replaceOp(op, std::move(nullRepls));
2212}
2213
/// Erases `block`. All work, including mode-specific bookkeeping, is
/// delegated to the rewriter implementation.
2214void ConversionPatternRewriter::eraseBlock(Block *block) {
2215 impl->eraseBlock(block);
2216}
2217
2218Block *ConversionPatternRewriter::applySignatureConversion(
2219 Block *block, TypeConverter::SignatureConversion &conversion,
2220 const TypeConverter *converter) {
2221 assert(!impl->wasOpReplaced(block->getParentOp()) &&
2222 "attempting to apply a signature conversion to a block within a "
2223 "replaced/erased op");
2224 return impl->applySignatureConversion(block, converter, conversion);
2225}
2226
2227FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
2228 Region *region, const TypeConverter &converter,
2229 TypeConverter::SignatureConversion *entryConversion) {
2230 assert(!impl->wasOpReplaced(region->getParentOp()) &&
2231 "attempting to apply a signature conversion to a block within a "
2232 "replaced/erased op");
2233 return impl->convertRegionTypes(region, converter, entryConversion);
2234}
2235
2236void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
2237 impl->replaceValueUses(from, to, impl->currentTypeConverter);
2238}
2239
2240void ConversionPatternRewriter::replaceUsesWithIf(
2241 Value from, ValueRange to, function_ref<bool(OpOperand &)> functor,
2242 bool *allUsesReplaced) {
2243 assert(!allUsesReplaced &&
2244 "allUsesReplaced is not supported in a dialect conversion");
2245 impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
2246}
2247
2248Value ConversionPatternRewriter::getRemappedValue(Value key) {
2249 SmallVector<ValueVector> remappedValues;
2250 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
2251 remappedValues)))
2252 return nullptr;
2253 assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
2254 return remappedValues.front().front();
2255}
2256
2257LogicalResult
2258ConversionPatternRewriter::getRemappedValues(ValueRange keys,
2259 SmallVectorImpl<Value> &results) {
2260 if (keys.empty())
2261 return success();
2262 SmallVector<ValueVector> remapped;
2263 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
2264 remapped)))
2265 return failure();
2266 for (const auto &values : remapped) {
2267 assert(values.size() == 1 && "1:N conversion not supported");
2268 results.push_back(values.front());
2269 }
2270 return success();
2271}
2272
2273void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
2274 Block::iterator before,
2275 ValueRange argValues) {
2276#ifndef NDEBUG
2277 assert(argValues.size() == source->getNumArguments() &&
2278 "incorrect # of argument replacement values");
2279 assert(!impl->wasOpReplaced(source->getParentOp()) &&
2280 "attempting to inline a block from a replaced/erased op");
2281 assert(!impl->wasOpReplaced(dest->getParentOp()) &&
2282 "attempting to inline a block into a replaced/erased op");
2283 auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
2284 // The source block will be deleted, so it should not have any users (i.e.,
2285 // there should be no predecessors).
2286 assert(llvm::all_of(source->getUsers(), opIgnored) &&
2287 "expected 'source' to have no predecessors");
2288#endif // NDEBUG
2289
2290 // If a listener is attached to the dialect conversion, ops cannot be moved
2291 // to the destination block in bulk ("fast path"). This is because at the time
2292 // the notifications are sent, it is unknown which ops were moved. Instead,
2293 // ops should be moved one-by-one ("slow path"), so that a separate
2294 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
2295 // a bit more efficient, so we try to do that when possible.
2296 bool fastPath = !getConfig().listener;
2297
2298 if (fastPath && impl->config.allowPatternRollback)
2299 impl->inlineBlockBefore(source, dest, before);
2300
2301 // Replace all uses of block arguments.
2302 for (auto it : llvm::zip(source->getArguments(), argValues))
2303 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2304
2305 if (fastPath) {
2306 // Move all ops at once.
2307 dest->getOperations().splice(before, source->getOperations());
2308 } else {
2309 // Move op by op.
2310 while (!source->empty())
2311 moveOpBefore(&source->front(), dest, before);
2312 }
2313
2314 // If the current insertion point is within the source block, adjust the
2315 // insertion point to the destination block.
2316 if (getInsertionBlock() == source)
2317 setInsertionPoint(dest, getInsertionPoint());
2318
2319 // Erase the source block.
2320 eraseBlock(source);
2321}
2322
2323void ConversionPatternRewriter::startOpModification(Operation *op) {
2324 if (!impl->config.allowPatternRollback) {
2325 // Pattern rollback is not allowed: no extra bookkeeping is needed.
2327 return;
2328 }
2329 assert(!impl->wasOpReplaced(op) &&
2330 "attempting to modify a replaced/erased op");
2331#ifndef NDEBUG
2332 impl->pendingRootUpdates.insert(op);
2333#endif
2334 impl->appendRewrite<ModifyOperationRewrite>(op);
2335}
2336
2337void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
2338 impl->patternModifiedOps.insert(op);
2339 if (!impl->config.allowPatternRollback) {
2341 if (getConfig().listener)
2342 getConfig().listener->notifyOperationModified(op);
2343 return;
2344 }
2345
2346 // There is nothing to do here, we only need to track the operation at the
2347 // start of the update.
2348#ifndef NDEBUG
2349 assert(!impl->wasOpReplaced(op) &&
2350 "attempting to modify a replaced/erased op");
2351 assert(impl->pendingRootUpdates.erase(op) &&
2352 "operation did not have a pending in-place update");
2353#endif
2354}
2355
2356void ConversionPatternRewriter::cancelOpModification(Operation *op) {
2357 if (!impl->config.allowPatternRollback) {
2359 return;
2360 }
2361#ifndef NDEBUG
2362 assert(impl->pendingRootUpdates.erase(op) &&
2363 "operation did not have a pending in-place update");
2364#endif
2365 // Erase the last update for this operation.
2366 auto it = llvm::find_if(
2367 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
2368 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2369 return modifyRewrite && modifyRewrite->getOperation() == op;
2370 });
2371 assert(it != impl->rewrites.rend() && "no root update started on op");
2372 (*it)->rollback();
2373 int updateIdx = std::prev(impl->rewrites.rend()) - it;
2374 impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2375}
2376
/// Returns the internal implementation object that holds the state of the
/// ongoing conversion (configuration, mappings, recorded rewrites).
2377detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
2378 return *impl;
2379}
2380
2381//===----------------------------------------------------------------------===//
2382// ConversionPattern
2383//===----------------------------------------------------------------------===//
2384
2385FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
2386 ArrayRef<ValueRange> operands) const {
2387 SmallVector<Value> oneToOneOperands;
2388 oneToOneOperands.reserve(operands.size());
2389 for (ValueRange operand : operands) {
2390 if (operand.size() != 1)
2391 return failure();
2392
2393 oneToOneOperands.push_back(operand.front());
2394 }
2395 return std::move(oneToOneOperands);
2396}
2397
2398LogicalResult
2399ConversionPattern::matchAndRewrite(Operation *op,
2400 PatternRewriter &rewriter) const {
2401 auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
2402 auto &rewriterImpl = dialectRewriter.getImpl();
2403
2404 // Track the current conversion pattern type converter in the rewriter.
2405 llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
2406 getTypeConverter());
2407
2408 // Remap the operands of the operation.
2409 SmallVector<ValueVector> remapped;
2410 if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
2411 op->getOperands(), remapped))) {
2412 return failure();
2413 }
2414 SmallVector<ValueRange> remappedAsRange =
2415 llvm::to_vector_of<ValueRange>(remapped);
2416 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2417}
2418
2419//===----------------------------------------------------------------------===//
2420// OperationLegalizer
2421//===----------------------------------------------------------------------===//
2422
2423namespace {
2424/// A set of rewrite patterns that can be used to legalize a given operation.
2425using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2426
2427/// This class defines a recursive operation legalizer.
2428class OperationLegalizer {
2429public:
2430 using LegalizationAction = ConversionTarget::LegalizationAction;
2431
2432 OperationLegalizer(ConversionPatternRewriter &rewriter,
2433 const ConversionTarget &targetInfo,
2434 const FrozenRewritePatternSet &patterns);
2435
2436 /// Returns true if the given operation is known to be illegal on the target.
2437 bool isIllegal(Operation *op) const;
2438
2439 /// Attempt to legalize the given operation. Returns success if the operation
2440 /// was legalized, failure otherwise.
2441 LogicalResult legalize(Operation *op);
2442
2443 /// Returns the conversion target in use by the legalizer.
2444 const ConversionTarget &getTarget() { return target; }
2445
2446private:
2447 /// Attempt to legalize the given operation by folding it.
2448 LogicalResult legalizeWithFold(Operation *op);
2449
2450 /// Attempt to legalize the given operation by applying a pattern. Returns
2451 /// success if the operation was legalized, failure otherwise.
2452 LogicalResult legalizeWithPattern(Operation *op);
2453
2454 /// Return true if the given pattern may be applied to the given operation,
2455 /// false otherwise.
2456 bool canApplyPattern(Operation *op, const Pattern &pattern);
2457
2458 /// Legalize the resultant IR after successfully applying the given pattern.
2459 LogicalResult
2460 legalizePatternResult(Operation *op, const Pattern &pattern,
2461 const RewriterState &curState,
2462 const SetVector<Operation *> &newOps,
2463 const SetVector<Operation *> &modifiedOps);
2464
2465 LogicalResult
2466 legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
2467 LogicalResult
2468 legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
2469
2470 //===--------------------------------------------------------------------===//
2471 // Cost Model
2472 //===--------------------------------------------------------------------===//
2473
2474 /// Build an optimistic legalization graph given the provided patterns. This
2475 /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
2476 /// patterns for operations that are not directly legal, but may be
2477 /// transitively legal for the current target given the provided patterns.
2478 void buildLegalizationGraph(
2479 LegalizationPatterns &anyOpLegalizerPatterns,
2481
2482 /// Compute the benefit of each node within the computed legalization graph.
2483 /// This orders the patterns within 'legalizerPatterns' based upon two
2484 /// criteria:
2485 /// 1) Prefer patterns that have the lowest legalization depth, i.e.
2486 /// represent the more direct mapping to the target.
2487 /// 2) When comparing patterns with the same legalization depth, prefer the
2488 /// pattern with the highest PatternBenefit. This allows for users to
2489 /// prefer specific legalizations over others.
2490 void computeLegalizationGraphBenefit(
2491 LegalizationPatterns &anyOpLegalizerPatterns,
2493
2494 /// Compute the legalization depth when legalizing an operation of the given
2495 /// type.
2496 unsigned computeOpLegalizationDepth(
2497 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2499
2500 /// Apply the conversion cost model to the given set of patterns, and return
2501 /// the smallest legalization depth of any of the patterns. See
2502 /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
2503 unsigned applyCostModelToPatterns(
2504 LegalizationPatterns &patterns,
2505 DenseMap<OperationName, unsigned> &minOpPatternDepth,
2507
2508 /// The current set of patterns that have been applied.
2509 SmallPtrSet<const Pattern *, 8> appliedPatterns;
2510
2511 /// The rewriter to use when converting operations.
2512 ConversionPatternRewriter &rewriter;
2513
2514 /// The legalization information provided by the target.
2515 const ConversionTarget &target;
2516
2517 /// The pattern applicator to use for conversions.
2518 PatternApplicator applicator;
2519};
2520} // namespace
2521
2522OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2523 const ConversionTarget &targetInfo,
2524 const FrozenRewritePatternSet &patterns)
2525 : rewriter(rewriter), target(targetInfo), applicator(patterns) {
2526 // The set of patterns that can be applied to illegal operations to transform
2527 // them into legal ones.
2529 LegalizationPatterns anyOpLegalizerPatterns;
2530
2531 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2532 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2533}
2534
2535bool OperationLegalizer::isIllegal(Operation *op) const {
2536 return target.isIllegal(op);
2537}
2538
2539LogicalResult OperationLegalizer::legalize(Operation *op) {
2540#ifndef NDEBUG
2541 const char *logLineComment =
2542 "//===-------------------------------------------===//\n";
2543
2544 auto &logger = rewriter.getImpl().logger;
2545#endif
2546
2547 // Check to see if the operation is ignored and doesn't need to be converted.
2548 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2549
2550 LLVM_DEBUG({
2551 logger.getOStream() << "\n";
2552 logger.startLine() << logLineComment;
2553 logger.startLine() << "Legalizing operation : ";
2554 // Do not print the operation name if the operation is ignored. Ignored ops
2555 // may have been erased and should not be accessed. The pointer can be
2556 // printed safely.
2557 if (!isIgnored)
2558 logger.getOStream() << "'" << op->getName() << "' ";
2559 logger.getOStream() << "(" << op << ") {\n";
2560 logger.indent();
2561
2562 // If the operation has no regions, just print it here.
2563 if (!isIgnored && op->getNumRegions() == 0) {
2564 logger.startLine() << OpWithFlags(op,
2565 OpPrintingFlags().printGenericOpForm())
2566 << "\n";
2567 }
2568 });
2569
2570 if (isIgnored) {
2571 LLVM_DEBUG({
2572 logSuccess(logger, "operation marked 'ignored' during conversion");
2573 logger.startLine() << logLineComment;
2574 });
2575 return success();
2576 }
2577
2578 // Check if this operation is legal on the target.
2579 if (auto legalityInfo = target.isLegal(op)) {
2580 LLVM_DEBUG({
2581 logSuccess(
2582 logger, "operation marked legal by the target{0}",
2583 legalityInfo->isRecursivelyLegal
2584 ? "; NOTE: operation is recursively legal; skipping internals"
2585 : "");
2586 logger.startLine() << logLineComment;
2587 });
2588
2589 // If this operation is recursively legal, mark its children as ignored so
2590 // that we don't consider them for legalization.
2591 if (legalityInfo->isRecursivelyLegal) {
2592 op->walk([&](Operation *nested) {
2593 if (op != nested)
2594 rewriter.getImpl().ignoredOps.insert(nested);
2595 });
2596 }
2597
2598 return success();
2599 }
2600
2601 // If the operation is not legal, try to fold it in-place if the folding mode
2602 // is 'BeforePatterns'. 'Never' will skip this.
2603 const ConversionConfig &config = rewriter.getConfig();
2604 if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2605 if (succeeded(legalizeWithFold(op))) {
2606 LLVM_DEBUG({
2607 logSuccess(logger, "operation was folded");
2608 logger.startLine() << logLineComment;
2609 });
2610 return success();
2611 }
2612 }
2613
2614 // Otherwise, we need to apply a legalization pattern to this operation.
2615 if (succeeded(legalizeWithPattern(op))) {
2616 LLVM_DEBUG({
2617 logSuccess(logger, "");
2618 logger.startLine() << logLineComment;
2619 });
2620 return success();
2621 }
2622
2623 // If the operation can't be legalized via patterns, try to fold it in-place
2624 // if the folding mode is 'AfterPatterns'.
2625 if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2626 if (succeeded(legalizeWithFold(op))) {
2627 LLVM_DEBUG({
2628 logSuccess(logger, "operation was folded");
2629 logger.startLine() << logLineComment;
2630 });
2631 return success();
2632 }
2633 }
2634
2635 LLVM_DEBUG({
2636 logFailure(logger, "no matched legalization pattern");
2637 logger.startLine() << logLineComment;
2638 });
2639 return failure();
2640}
2641
/// Helper function that moves and returns the given object. Also resets the
/// original object, so that it is in a valid, empty state again.
///
/// `obj` is moved into the return value and a value-initialized `T` is then
/// assigned back into it, which is exactly what `std::exchange` does.
template <typename T>
static T moveAndReset(T &obj) {
  return std::exchange(obj, T());
}
2650
2651LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2652 auto &rewriterImpl = rewriter.getImpl();
2653 LLVM_DEBUG({
2654 rewriterImpl.logger.startLine() << "* Fold {\n";
2655 rewriterImpl.logger.indent();
2656 });
2657
2658 // Clear pattern state, so that the next pattern application starts with a
2659 // clean slate. (The op/block sets are populated by listener notifications.)
2660 llvm::scope_exit cleanup([&]() {
2661 rewriterImpl.patternNewOps.clear();
2662 rewriterImpl.patternModifiedOps.clear();
2663 });
2664
2665 // Upon failure, undo all changes made by the folder.
2666 RewriterState curState = rewriterImpl.getCurrentState();
2667
2668 // Try to fold the operation.
2669 StringRef opName = op->getName().getStringRef();
2670 SmallVector<Value, 2> replacementValues;
2671 SmallVector<Operation *, 2> newOps;
2672 rewriter.setInsertionPoint(op);
2673 rewriter.startOpModification(op);
2674 if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
2675 LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2676 rewriter.cancelOpModification(op);
2677 return failure();
2678 }
2679 rewriter.finalizeOpModification(op);
2680
2681 // An empty list of replacement values indicates that the fold was in-place.
2682 // As the operation changed, a new legalization needs to be attempted.
2683 if (replacementValues.empty())
2684 return legalize(op);
2685
2686 // Insert a replacement for 'op' with the folded replacement values.
2687 rewriter.replaceOp(op, replacementValues);
2688
2689 // Recursively legalize any new constant operations.
2690 for (Operation *newOp : newOps) {
2691 if (failed(legalize(newOp))) {
2692 LLVM_DEBUG(logFailure(rewriterImpl.logger,
2693 "failed to legalize generated constant '{0}'",
2694 newOp->getName()));
2695 if (!rewriter.getConfig().allowPatternRollback) {
2696 // Rolling back a folder is like rolling back a pattern.
2697 llvm::reportFatalInternalError(
2698 "op '" + opName +
2699 "' folder rollback of IR modifications requested");
2700 }
2701 rewriterImpl.resetState(
2702 curState, std::string(op->getName().getStringRef()) + " folder");
2703 return failure();
2704 }
2705 }
2706
2707 LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2708 return success();
2709}
2710
2711/// Report a fatal error indicating that newly produced or modified IR could
2712/// not be legalized.
2713static void
2715 const SetVector<Operation *> &newOps,
2716 const SetVector<Operation *> &modifiedOps) {
2717 auto newOpNames = llvm::map_range(
2718 newOps, [](Operation *op) { return op->getName().getStringRef(); });
2719 auto modifiedOpNames = llvm::map_range(
2720 modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
2721 llvm::reportFatalInternalError("pattern '" + pattern.getDebugName() +
2722 "' produced IR that could not be legalized. " +
2723 "new ops: {" + llvm::join(newOpNames, ", ") +
2724 "}, " + "modified ops: {" +
2725 llvm::join(modifiedOpNames, ", ") + "}");
2726}
2727
2728LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2729 auto &rewriterImpl = rewriter.getImpl();
2730 const ConversionConfig &config = rewriter.getConfig();
2731
2732#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2733 Operation *checkOp;
2734 std::optional<OperationFingerPrint> topLevelFingerPrint;
2735 if (!rewriterImpl.config.allowPatternRollback) {
2736 // The op may be getting erased, so we have to check the parent op.
2737 // (In rare cases, a pattern may even erase the parent op, which will cause
2738 // a crash here. Expensive checks are "best effort".) Skip the check if the
2739 // op does not have a parent op.
2740 if ((checkOp = op->getParentOp())) {
2741 if (!op->getContext()->isMultithreadingEnabled()) {
2742 topLevelFingerPrint = OperationFingerPrint(checkOp);
2743 } else {
2744 // Another thread may be modifying a sibling operation. Therefore, the
2745 // fingerprinting mechanism of the parent op works only in
2746 // single-threaded mode.
2747 LLVM_DEBUG({
2748 rewriterImpl.logger.startLine()
2749 << "WARNING: Multi-threadeding is enabled. Some dialect "
2750 "conversion expensive checks are skipped in multithreading "
2751 "mode!\n";
2752 });
2753 }
2754 }
2755 }
2756#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2757
2758 // Functor that returns if the given pattern may be applied.
2759 auto canApply = [&](const Pattern &pattern) {
2760 bool canApply = canApplyPattern(op, pattern);
2761 if (canApply && config.listener)
2762 config.listener->notifyPatternBegin(pattern, op);
2763 return canApply;
2764 };
2765
2766 // Functor that cleans up the rewriter state after a pattern failed to match.
2767 RewriterState curState = rewriterImpl.getCurrentState();
2768 auto onFailure = [&](const Pattern &pattern) {
2769 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2770 if (!rewriterImpl.config.allowPatternRollback) {
2771 // Erase all unresolved materializations.
2772 for (auto op : rewriterImpl.patternMaterializations) {
2773 rewriterImpl.unresolvedMaterializations.erase(op);
2774 op.erase();
2775 }
2776 rewriterImpl.patternMaterializations.clear();
2777#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2778 // Expensive pattern check that can detect API violations.
2779 if (checkOp && topLevelFingerPrint) {
2780 OperationFingerPrint fingerPrintAfterPattern(checkOp);
2781 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2782 llvm::reportFatalInternalError(
2783 "pattern '" + pattern.getDebugName() +
2784 "' returned failure but IR did change");
2785 }
2786#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2787 }
2788 rewriterImpl.patternNewOps.clear();
2789 rewriterImpl.patternModifiedOps.clear();
2790 LLVM_DEBUG({
2791 logFailure(rewriterImpl.logger, "pattern failed to match");
2792 if (rewriterImpl.config.notifyCallback) {
2793 Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
2794 diag << "Failed to apply pattern \"" << pattern.getDebugName()
2795 << "\" on op:\n"
2796 << *op;
2797 rewriterImpl.config.notifyCallback(diag);
2798 }
2799 });
2800 if (config.listener)
2801 config.listener->notifyPatternEnd(pattern, failure());
2802 rewriterImpl.resetState(curState, pattern.getDebugName());
2803 appliedPatterns.erase(&pattern);
2804 };
2805
2806 // Functor that performs additional legalization when a pattern is
2807 // successfully applied.
2808 auto onSuccess = [&](const Pattern &pattern) {
2809 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2810 if (!rewriterImpl.config.allowPatternRollback) {
2811 // Eagerly erase unused materializations.
2812 for (auto op : rewriterImpl.patternMaterializations) {
2813 if (op->use_empty()) {
2814 rewriterImpl.unresolvedMaterializations.erase(op);
2815 op.erase();
2816 }
2817 }
2818 rewriterImpl.patternMaterializations.clear();
2819 }
2820 SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
2821 SetVector<Operation *> modifiedOps =
2822 moveAndReset(rewriterImpl.patternModifiedOps);
2823 auto result =
2824 legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
2825 appliedPatterns.erase(&pattern);
2826 if (failed(result)) {
2827 if (!rewriterImpl.config.allowPatternRollback)
2828 reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);
2829 rewriterImpl.resetState(curState, pattern.getDebugName());
2830 }
2831 if (config.listener)
2832 config.listener->notifyPatternEnd(pattern, result);
2833 return result;
2834 };
2835
2836 // Try to match and rewrite a pattern on this operation.
2837 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2838 onSuccess);
2839}
2840
2841bool OperationLegalizer::canApplyPattern(Operation *op,
2842 const Pattern &pattern) {
2843 LLVM_DEBUG({
2844 auto &os = rewriter.getImpl().logger;
2845 os.getOStream() << "\n";
2846 os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2847 llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2848 os.getOStream() << ")' {\n";
2849 os.indent();
2850 });
2851
2852 // Ensure that we don't cycle by not allowing the same pattern to be
2853 // applied twice in the same recursion stack if it is not known to be safe.
2854 if (!pattern.hasBoundedRewriteRecursion() &&
2855 !appliedPatterns.insert(&pattern).second) {
2856 LLVM_DEBUG(
2857 logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2858 return false;
2859 }
2860 return true;
2861}
2862
2863LogicalResult OperationLegalizer::legalizePatternResult(
2864 Operation *op, const Pattern &pattern, const RewriterState &curState,
2865 const SetVector<Operation *> &newOps,
2866 const SetVector<Operation *> &modifiedOps) {
2867 [[maybe_unused]] auto &impl = rewriter.getImpl();
2868 assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2869
2870#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2871 if (impl.config.allowPatternRollback) {
2872 // Check that the root was either replaced or updated in place.
2873 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2874 auto replacedRoot = [&] {
2875 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2876 };
2877 auto updatedRootInPlace = [&] {
2878 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2879 };
2880 if (!replacedRoot() && !updatedRootInPlace())
2881 llvm::reportFatalInternalError(
2882 "expected pattern to replace the root operation "
2883 "or modify it in place");
2884 }
2885#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2886
2887 // Legalize each of the actions registered during application.
2888 if (failed(legalizePatternRootUpdates(modifiedOps)) ||
2889 failed(legalizePatternCreatedOperations(newOps))) {
2890 return failure();
2891 }
2892
2893 LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2894 return success();
2895}
2896
2897LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2898 const SetVector<Operation *> &newOps) {
2899 for (Operation *op : newOps) {
2900 if (failed(legalize(op))) {
2901 LLVM_DEBUG(logFailure(rewriter.getImpl().logger,
2902 "failed to legalize generated operation '{0}'({1})",
2903 op->getName(), op));
2904 return failure();
2905 }
2906 }
2907 return success();
2908}
2909
2910LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2911 const SetVector<Operation *> &modifiedOps) {
2912 for (Operation *op : modifiedOps) {
2913 if (failed(legalize(op))) {
2914 LLVM_DEBUG(
2915 logFailure(rewriter.getImpl().logger,
2916 "failed to legalize operation updated in-place '{0}'",
2917 op->getName()));
2918 return failure();
2919 }
2920 }
2921 return success();
2922}
2923
2924//===----------------------------------------------------------------------===//
2925// Cost Model
2926//===----------------------------------------------------------------------===//
2927
2928void OperationLegalizer::buildLegalizationGraph(
2929 LegalizationPatterns &anyOpLegalizerPatterns,
2931 // A mapping between an operation and a set of operations that can be used to
2932 // generate it.
2934 // A mapping between an operation and any currently invalid patterns it has.
2936 // A worklist of patterns to consider for legality.
2937 SetVector<const Pattern *> patternWorklist;
2938
2939 // Build the mapping from operations to the parent ops that may generate them.
2940 applicator.walkAllPatterns([&](const Pattern &pattern) {
2941 std::optional<OperationName> root = pattern.getRootKind();
2942
2943 // If the pattern has no specific root, we can't analyze the relationship
2944 // between the root op and generated operations. Given that, add all such
2945 // patterns to the legalization set.
2946 if (!root) {
2947 anyOpLegalizerPatterns.push_back(&pattern);
2948 return;
2949 }
2950
2951 // Skip operations that are always known to be legal.
2952 if (target.getOpAction(*root) == LegalizationAction::Legal)
2953 return;
2954
2955 // Add this pattern to the invalid set for the root op and record this root
2956 // as a parent for any generated operations.
2957 invalidPatterns[*root].insert(&pattern);
2958 for (auto op : pattern.getGeneratedOps())
2959 parentOps[op].insert(*root);
2960
2961 // Add this pattern to the worklist.
2962 patternWorklist.insert(&pattern);
2963 });
2964
2965 // If there are any patterns that don't have a specific root kind, we can't
2966 // make direct assumptions about what operations will never be legalized.
2967 // Note: Technically we could, but it would require an analysis that may
2968 // recurse into itself. It would be better to perform this kind of filtering
2969 // at a higher level than here anyways.
2970 if (!anyOpLegalizerPatterns.empty()) {
2971 for (const Pattern *pattern : patternWorklist)
2972 legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2973 return;
2974 }
2975
2976 while (!patternWorklist.empty()) {
2977 auto *pattern = patternWorklist.pop_back_val();
2978
2979 // Check to see if any of the generated operations are invalid.
2980 if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2981 std::optional<LegalizationAction> action = target.getOpAction(op);
2982 return !legalizerPatterns.count(op) &&
2983 (!action || action == LegalizationAction::Illegal);
2984 }))
2985 continue;
2986
2987 // Otherwise, if all of the generated operation are valid, this op is now
2988 // legal so add all of the child patterns to the worklist.
2989 legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2990 invalidPatterns[*pattern->getRootKind()].erase(pattern);
2991
2992 // Add any invalid patterns of the parent operations to see if they have now
2993 // become legal.
2994 for (auto op : parentOps[*pattern->getRootKind()])
2995 patternWorklist.set_union(invalidPatterns[op]);
2996 }
2997}
2998
2999void OperationLegalizer::computeLegalizationGraphBenefit(
3000 LegalizationPatterns &anyOpLegalizerPatterns,
3002 // The smallest pattern depth, when legalizing an operation.
3003 DenseMap<OperationName, unsigned> minOpPatternDepth;
3004
3005 // For each operation that is transitively legal, compute a cost for it.
3006 for (auto &opIt : legalizerPatterns)
3007 if (!minOpPatternDepth.count(opIt.first))
3008 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
3009 legalizerPatterns);
3010
3011 // Apply the cost model to the patterns that can match any operation. Those
3012 // with a specific operation type are already resolved when computing the op
3013 // legalization depth.
3014 if (!anyOpLegalizerPatterns.empty())
3015 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
3016 legalizerPatterns);
3017
3018 // Apply a cost model to the pattern applicator. We order patterns first by
3019 // depth then benefit. `legalizerPatterns` contains per-op patterns by
3020 // decreasing benefit.
3021 applicator.applyCostModel([&](const Pattern &pattern) {
3022 ArrayRef<const Pattern *> orderedPatternList;
3023 if (std::optional<OperationName> rootName = pattern.getRootKind())
3024 orderedPatternList = legalizerPatterns[*rootName];
3025 else
3026 orderedPatternList = anyOpLegalizerPatterns;
3027
3028 // If the pattern is not found, then it was removed and cannot be matched.
3029 auto *it = llvm::find(orderedPatternList, &pattern);
3030 if (it == orderedPatternList.end())
3032
3033 // Patterns found earlier in the list have higher benefit.
3034 return PatternBenefit(std::distance(it, orderedPatternList.end()));
3035 });
3036}
3037
3038unsigned OperationLegalizer::computeOpLegalizationDepth(
3039 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
3041 // Check for existing depth.
3042 auto depthIt = minOpPatternDepth.find(op);
3043 if (depthIt != minOpPatternDepth.end())
3044 return depthIt->second;
3045
3046 // If a mapping for this operation does not exist, then this operation
3047 // is always legal. Return 0 as the depth for a directly legal operation.
3048 auto opPatternsIt = legalizerPatterns.find(op);
3049 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3050 return 0u;
3051
3052 // Record this initial depth in case we encounter this op again when
3053 // recursively computing the depth.
3054 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
3055
3056 // Apply the cost model to the operation patterns, and update the minimum
3057 // depth.
3058 unsigned minDepth = applyCostModelToPatterns(
3059 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3060 minOpPatternDepth[op] = minDepth;
3061 return minDepth;
3062}
3063
3064unsigned OperationLegalizer::applyCostModelToPatterns(
3065 LegalizationPatterns &patterns,
3066 DenseMap<OperationName, unsigned> &minOpPatternDepth,
3068 unsigned minDepth = std::numeric_limits<unsigned>::max();
3069
3070 // Compute the depth for each pattern within the set.
3071 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3072 patternsByDepth.reserve(patterns.size());
3073 for (const Pattern *pattern : patterns) {
3074 unsigned depth = 1;
3075 for (auto generatedOp : pattern->getGeneratedOps()) {
3076 unsigned generatedOpDepth = computeOpLegalizationDepth(
3077 generatedOp, minOpPatternDepth, legalizerPatterns);
3078 depth = std::max(depth, generatedOpDepth + 1);
3079 }
3080 patternsByDepth.emplace_back(pattern, depth);
3081
3082 // Update the minimum depth of the pattern list.
3083 minDepth = std::min(minDepth, depth);
3084 }
3085
3086 // If the operation only has one legalization pattern, there is no need to
3087 // sort them.
3088 if (patternsByDepth.size() == 1)
3089 return minDepth;
3090
3091 // Sort the patterns by those likely to be the most beneficial.
3092 llvm::stable_sort(patternsByDepth,
3093 [](const std::pair<const Pattern *, unsigned> &lhs,
3094 const std::pair<const Pattern *, unsigned> &rhs) {
3095 // First sort by the smaller pattern legalization
3096 // depth.
3097 if (lhs.second != rhs.second)
3098 return lhs.second < rhs.second;
3099
3100 // Then sort by the larger pattern benefit.
3101 auto lhsBenefit = lhs.first->getBenefit();
3102 auto rhsBenefit = rhs.first->getBenefit();
3103 return lhsBenefit > rhsBenefit;
3104 });
3105
3106 // Update the legalization pattern to use the new sorted list.
3107 patterns.clear();
3108 for (auto &patternIt : patternsByDepth)
3109 patterns.push_back(patternIt.first);
3110 return minDepth;
3111}
3112
3113//===----------------------------------------------------------------------===//
3114// Reconcile Unrealized Casts
3115//===----------------------------------------------------------------------===//
3116
3117/// Try to reconcile all given UnrealizedConversionCastOps and store the
3118/// left-over ops in `remainingCastOps` (if provided). See documentation in
3119/// DialectConversion.h for more details.
3120/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the
3121/// algorithm may visit an operand (or user) which is a cast op, but will not
3122/// try to reconcile it if not in the filtered set.
3123template <typename RangeT>
3125 RangeT castOps,
3126 function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3128 // A worklist of cast ops to process.
3129 SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
3130
3131 // Helper function that return the unrealized_conversion_cast op that
3132 // defines all inputs of the given op (in the same order). Return "nullptr"
3133 // if there is no such op.
3134 auto getInputCast =
3135 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3136 if (castOp.getInputs().empty())
3137 return {};
3138 auto inputCastOp =
3139 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3140 if (!inputCastOp)
3141 return {};
3142 if (inputCastOp.getOutputs() != castOp.getInputs())
3143 return {};
3144 return inputCastOp;
3145 };
3146
3147 // Process ops in the worklist bottom-to-top.
3148 while (!worklist.empty()) {
3149 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3150
3151 // Traverse the chain of input cast ops to see if an op with the same
3152 // input types can be found.
3153 UnrealizedConversionCastOp nextCast = castOp;
3154 while (nextCast) {
3155 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3156 if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
3157 return v.getDefiningOp() == castOp;
3158 })) {
3159 // Ran into a cycle.
3160 break;
3161 }
3162
3163 // Found a cast where the input types match the output types of the
3164 // matched op. We can directly use those inputs.
3165 castOp.replaceAllUsesWith(nextCast.getInputs());
3166 break;
3167 }
3168 nextCast = getInputCast(nextCast);
3169 }
3170 }
3171
3172 // A set of all alive cast ops. I.e., ops whose results are (transitively)
3173 // used by an op that is not a cast op.
3174 DenseSet<Operation *> liveOps;
3175
3176 // Helper function that marks the given op and transitively reachable input
3177 // cast ops as alive.
3178 auto markOpLive = [&](Operation *rootOp) {
3179 SmallVector<Operation *> worklist;
3180 worklist.push_back(rootOp);
3181 while (!worklist.empty()) {
3182 Operation *op = worklist.pop_back_val();
3183 if (liveOps.insert(op).second) {
3184 // Successfully inserted: process reachable input cast ops.
3185 for (Value v : op->getOperands())
3186 if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
3187 if (isCastOpOfInterestFn(castOp))
3188 worklist.push_back(castOp);
3189 }
3190 }
3191 };
3192
3193 // Find all alive cast ops.
3194 for (UnrealizedConversionCastOp op : castOps) {
3195 // The op may have been marked live already as being an operand of another
3196 // live cast op.
3197 if (liveOps.contains(op.getOperation()))
3198 continue;
3199 // If any of the users is not a cast op, mark the current op (and its
3200 // input ops) as live.
3201 if (llvm::any_of(op->getUsers(), [&](Operation *user) {
3202 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3203 return !castOp || !isCastOpOfInterestFn(castOp);
3204 }))
3205 markOpLive(op);
3206 }
3207
3208 // Erase all dead cast ops.
3209 for (UnrealizedConversionCastOp op : castOps) {
3210 if (liveOps.contains(op)) {
3211 // Op is alive and was not erased. Add it to the remaining cast ops.
3212 if (remainingCastOps)
3213 remainingCastOps->push_back(op);
3214 continue;
3215 }
3216
3217 // Op is dead. Erase it.
3218 op->dropAllUses();
3219 op->erase();
3220 }
3221}
3222
3224 ArrayRef<UnrealizedConversionCastOp> castOps,
3225 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3226 // Set of all cast ops for faster lookups.
3227 DenseSet<UnrealizedConversionCastOp> castOpSet;
3228 for (UnrealizedConversionCastOp op : castOps)
3229 castOpSet.insert(op);
3230 reconcileUnrealizedCasts(castOpSet, remainingCastOps);
3231}
3232
3234 const DenseSet<UnrealizedConversionCastOp> &castOps,
3235 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3237 llvm::make_range(castOps.begin(), castOps.end()),
3238 [&](UnrealizedConversionCastOp castOp) {
3239 return castOps.contains(castOp);
3240 },
3241 remainingCastOps);
3242}
3243
3244namespace mlir {
3247 &castOps,
3250 castOps.keys(),
3251 [&](UnrealizedConversionCastOp castOp) {
3252 return castOps.contains(castOp);
3253 },
3254 remainingCastOps);
3255}
3256} // namespace mlir
3257
3258//===----------------------------------------------------------------------===//
3259// OperationConverter
3260//===----------------------------------------------------------------------===//
3261
3262namespace mlir {
3263// This class converts operations to a given conversion target via a set of
3264// rewrite patterns. The conversion behaves differently depending on the
3265// conversion mode.
3268 const FrozenRewritePatternSet &patterns,
3269 const ConversionConfig &config,
3270 OpConversionMode mode)
3271 : rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
3272 mode(mode) {}
3273
3274 /// Applies the conversion to the given operations (and their nested
3275 /// operations).
3276 LogicalResult applyConversion(ArrayRef<Operation *> ops);
3277
3278 /// Legalizes the given operations (and their nested operations) to the
3279 /// conversion target.
3280 template <typename Fn>
3281 LogicalResult legalizeOperations(ArrayRef<Operation *> ops, Fn onFailure,
3282 bool isRecursiveLegalization = false);
3284 bool isRecursiveLegalization = false) {
3285 return legalizeOperations(
3286 ops, /*onFailure=*/[&]() {}, isRecursiveLegalization);
3287 }
3288
3289 /// Converts a single operation. If `isRecursiveLegalization` is "true", the
3290 /// conversion is a recursive legalization request, triggered from within a
3291 /// pattern. In that case, do not emit errors because there will be another
3292 /// attempt at legalizing the operation later (via the regular pre-order
3293 /// legalization mechanism).
3294 LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);
3295
3296 const ConversionTarget &getTarget() { return opLegalizer.getTarget(); }
3297
3298private:
3299 /// The rewriter to use when converting operations.
3300 ConversionPatternRewriter rewriter;
3301
3302 /// The legalizer to use when converting operations.
3303 OperationLegalizer opLegalizer;
3304
3305 /// The conversion mode to use when legalizing operations.
3306 OpConversionMode mode;
3307};
3308} // namespace mlir
3309
3311 bool isRecursiveLegalization) {
3312 const ConversionConfig &config = rewriter.getConfig();
3313 auto emitFailedToLegalizeDiag = [&](bool wasExplicitlyIllegal) {
3315 << "failed to legalize operation '"
3316 << op->getName() << "'";
3317 if (wasExplicitlyIllegal)
3318 diag << " that was explicitly marked illegal";
3319 diag << ": " << OpWithFlags(op, OpPrintingFlags().skipRegions());
3320 };
3321
3322 // Legalize the given operation.
3323 if (failed(opLegalizer.legalize(op))) {
3324 // Handle the case of a failed conversion for each of the different modes.
3325 // Full conversions expect all operations to be converted.
3326 if (mode == OpConversionMode::Full) {
3327 if (!isRecursiveLegalization)
3328 emitFailedToLegalizeDiag(/*wasExplicitlyIllegal=*/false);
3329 return failure();
3330 }
3331 // Partial conversions allow conversions to fail iff the operation was not
3332 // explicitly marked as illegal. If the user provided a `unlegalizedOps`
3333 // set, non-legalizable ops are added to that set.
3334 if (mode == OpConversionMode::Partial) {
3335 if (opLegalizer.isIllegal(op)) {
3336 if (!isRecursiveLegalization)
3337 emitFailedToLegalizeDiag(/*wasExplicitlyIllegal=*/true);
3338 return failure();
3339 }
3340 if (config.unlegalizedOps && !isRecursiveLegalization)
3341 config.unlegalizedOps->insert(op);
3342 }
3343 } else if (mode == OpConversionMode::Analysis) {
3344 // Analysis conversions don't fail if any operations fail to legalize,
3345 // they are only interested in the operations that were successfully
3346 // legalized.
3347 if (config.legalizableOps && !isRecursiveLegalization)
3348 config.legalizableOps->insert(op);
3349 }
3350 return success();
3351}
3352
3353static LogicalResult
3355 UnrealizedConversionCastOp op,
3356 const UnresolvedMaterializationInfo &info) {
3357 assert(!op.use_empty() &&
3358 "expected that dead materializations have already been DCE'd");
3359 Operation::operand_range inputOperands = op.getOperands();
3360
3361 // Try to materialize the conversion.
3362 if (const TypeConverter *converter = info.getConverter()) {
3363 rewriter.setInsertionPoint(op);
3364 SmallVector<Value> newMaterialization;
3365 switch (info.getMaterializationKind()) {
3366 case MaterializationKind::Target:
3367 newMaterialization = converter->materializeTargetConversion(
3368 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3369 info.getOriginalType());
3370 break;
3371 case MaterializationKind::Source:
3372 assert(op->getNumResults() == 1 && "expected single result");
3373 Value sourceMat = converter->materializeSourceConversion(
3374 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3375 if (sourceMat)
3376 newMaterialization.push_back(sourceMat);
3377 break;
3378 }
3379 if (!newMaterialization.empty()) {
3380#ifndef NDEBUG
3381 ValueRange newMaterializationRange(newMaterialization);
3382 assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
3383 "materialization callback produced value of incorrect type");
3384#endif // NDEBUG
3385 rewriter.replaceOp(op, newMaterialization);
3386 return success();
3387 }
3388 }
3389
3390 InFlightDiagnostic diag = op->emitError()
3391 << "failed to legalize unresolved materialization "
3392 "from ("
3393 << inputOperands.getTypes() << ") to ("
3394 << op.getResultTypes()
3395 << ") that remained live after conversion";
3396 diag.attachNote(op->getUsers().begin()->getLoc())
3397 << "see existing live user here: " << *op->getUsers().begin();
3398 return failure();
3399}
3400
3401template <typename Fn>
3402LogicalResult
3404 bool isRecursiveLegalization) {
3405 const ConversionTarget &target = opLegalizer.getTarget();
3406
3407 // Compute the set of operations and blocks to convert.
3408 SmallVector<Operation *> toConvert;
3409 for (Operation *op : ops) {
3411 [&](Operation *op) {
3412 toConvert.push_back(op);
3413 // Don't check this operation's children for conversion if the
3414 // operation is recursively legal.
3415 auto legalityInfo = target.isLegal(op);
3416 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3417 return WalkResult::skip();
3418 return WalkResult::advance();
3419 });
3420 }
3421 for (Operation *op : toConvert) {
3422 if (failed(convert(op, isRecursiveLegalization))) {
3423 // Failed to convert an operation.
3424 onFailure();
3425 return failure();
3426 }
3427 }
3428 return success();
3429}
3430
3431LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
3432 return impl->opConverter.legalizeOperations(op,
3433 /*isRecursiveLegalization=*/true);
3434}
3435
3436LogicalResult ConversionPatternRewriter::legalize(Region *r) {
3437 // Fast path: If the region is empty, there is nothing to legalize.
3438 if (r->empty())
3439 return success();
3440
3441 // Gather a list of all operations to legalize. This is done before
3442 // converting the entry block signature because unrealized_conversion_cast
3443 // ops should not be included.
3445 for (Block &b : *r)
3446 for (Operation &op : b)
3447 ops.push_back(&op);
3448
3449 // If the current pattern runs with a type converter, convert the entry block
3450 // signature.
3451 if (const TypeConverter *converter = impl->currentTypeConverter) {
3452 std::optional<TypeConverter::SignatureConversion> conversion =
3453 converter->convertBlockSignature(&r->front());
3454 if (!conversion)
3455 return failure();
3456 applySignatureConversion(&r->front(), *conversion, converter);
3457 }
3458
3459 // Legalize all operations in the region. This includes all nested
3460 // operations.
3461 return impl->opConverter.legalizeOperations(ops,
3462 /*isRecursiveLegalization=*/true);
3463}
3464
3466 // Convert each operation and discard rewrites on failure.
3467 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
3468
3469 LogicalResult status = legalizeOperations(ops, /*onFailure=*/[&]() {
3470 // Dialect conversion failed.
3471 if (rewriterImpl.config.allowPatternRollback) {
3472 // Rollback is allowed: restore the original IR.
3473 rewriterImpl.undoRewrites();
3474 } else {
3475 // Rollback is not allowed: apply all modifications that have been
3476 // performed so far.
3477 rewriterImpl.applyRewrites();
3478 }
3479 });
3480 if (failed(status))
3481 return failure();
3482
3483 // After a successful conversion, apply rewrites.
3484 rewriterImpl.applyRewrites();
3485
3486 // Reconcile all UnrealizedConversionCastOps that were inserted by the
3487 // dialect conversion frameworks. (Not the ones that were inserted by
3488 // patterns.)
3490 &materializations = rewriterImpl.unresolvedMaterializations;
3492 reconcileUnrealizedCasts(materializations, &remainingCastOps);
3493
3494 // Drop markers.
3495 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3496 castOp->removeAttr(kPureTypeConversionMarker);
3497
3498 // Try to legalize all unresolved materializations.
3499 if (rewriter.getConfig().buildMaterializations) {
3500 // Use a new rewriter, so the modifications are not tracked for rollback
3501 // purposes etc.
3502 IRRewriter irRewriter(rewriterImpl.rewriter.getContext(),
3503 rewriter.getConfig().listener);
3504 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3505 auto it = materializations.find(castOp);
3506 assert(it != materializations.end() && "inconsistent state");
3507 if (failed(legalizeUnresolvedMaterialization(irRewriter, castOp,
3508 it->second)))
3509 return failure();
3510 }
3511 }
3512
3513 return success();
3514}
3515
3516//===----------------------------------------------------------------------===//
3517// Type Conversion
3518//===----------------------------------------------------------------------===//
3519
3520void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
3521 ArrayRef<Type> types) {
3522 assert(!types.empty() && "expected valid types");
3523 remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
3524 addInputs(types);
3525}
3526
3527void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
3528 assert(!types.empty() &&
3529 "1->0 type remappings don't need to be added explicitly");
3530 argTypes.append(types.begin(), types.end());
3531}
3532
3533void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
3534 unsigned newInputNo,
3535 unsigned newInputCount) {
3536 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3537 assert(newInputCount != 0 && "expected valid input count");
3538 remappedInputs[origInputNo] =
3539 InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
3540}
3541
3542void TypeConverter::SignatureConversion::remapInput(
3543 unsigned origInputNo, ArrayRef<Value> replacements) {
3544 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3545 remappedInputs[origInputNo] = InputMapping{
3546 origInputNo, /*size=*/0,
3547 SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3548}
3549
3550/// Internal implementation of the type conversion.
3551/// This is used with either a Type or a Value as the first argument.
3552/// - we can cache the context-free conversions until the last registered
3553/// context-aware conversion.
3554/// - we can't cache the result of type conversion happening after context-aware
3555/// conversions, because the type converter may return different results for the
3556/// same input type.
3557LogicalResult
3558TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3559 SmallVectorImpl<Type> &results) const {
3560 assert(typeOrValue && "expected non-null type");
3561 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3562 : cast<Type>(typeOrValue);
3563 {
3564 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3565 std::defer_lock);
3567 cacheReadLock.lock();
3568 auto existingIt = cachedDirectConversions.find(t);
3569 if (existingIt != cachedDirectConversions.end()) {
3570 if (existingIt->second)
3571 results.push_back(existingIt->second);
3572 return success(existingIt->second != nullptr);
3573 }
3574 auto multiIt = cachedMultiConversions.find(t);
3575 if (multiIt != cachedMultiConversions.end()) {
3576 results.append(multiIt->second.begin(), multiIt->second.end());
3577 return success();
3578 }
3579 }
3580 // Walk the added converters in reverse order to apply the most recently
3581 // registered first.
3582 size_t currentCount = results.size();
3583
3584 // We can cache the context-free conversions until the last registered
3585 // context-aware conversion. But only if we're processing a Value right now.
3586 auto isCacheable = [&](int index) {
3587 int numberOfConversionsUntilContextAware =
3588 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3589 return index < numberOfConversionsUntilContextAware;
3590 };
3591
3592 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3593 std::defer_lock);
3594
3595 for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3596 const ConversionCallbackFn &converter = indexedConverter.value();
3597 std::optional<LogicalResult> result = converter(typeOrValue, results);
3598 if (!result) {
3599 assert(results.size() == currentCount &&
3600 "failed type conversion should not change results");
3601 continue;
3602 }
3603 if (!isCacheable(indexedConverter.index()))
3604 return success();
3606 cacheWriteLock.lock();
3607 if (!succeeded(*result)) {
3608 assert(results.size() == currentCount &&
3609 "failed type conversion should not change results");
3610 cachedDirectConversions.try_emplace(t, nullptr);
3611 return failure();
3612 }
3613 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3614 if (newTypes.size() == 1)
3615 cachedDirectConversions.try_emplace(t, newTypes.front());
3616 else
3617 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3618 return success();
3619 }
3620 return failure();
3621}
3622
3623LogicalResult TypeConverter::convertType(Type t,
3624 SmallVectorImpl<Type> &results) const {
3625 return convertTypeImpl(t, results);
3626}
3627
3628LogicalResult TypeConverter::convertType(Value v,
3629 SmallVectorImpl<Type> &results) const {
3630 return convertTypeImpl(v, results);
3631}
3632
3633Type TypeConverter::convertType(Type t) const {
3634 // Use the multi-type result version to convert the type.
3635 SmallVector<Type, 1> results;
3636 if (failed(convertType(t, results)))
3637 return nullptr;
3638
3639 // Check to ensure that only one type was produced.
3640 return results.size() == 1 ? results.front() : nullptr;
3641}
3642
3643Type TypeConverter::convertType(Value v) const {
3644 // Use the multi-type result version to convert the type.
3645 SmallVector<Type, 1> results;
3646 if (failed(convertType(v, results)))
3647 return nullptr;
3648
3649 // Check to ensure that only one type was produced.
3650 return results.size() == 1 ? results.front() : nullptr;
3651}
3652
3653LogicalResult
3654TypeConverter::convertTypes(TypeRange types,
3655 SmallVectorImpl<Type> &results) const {
3656 for (Type type : types)
3657 if (failed(convertType(type, results)))
3658 return failure();
3659 return success();
3660}
3661
3662LogicalResult
3663TypeConverter::convertTypes(ValueRange values,
3664 SmallVectorImpl<Type> &results) const {
3665 for (Value value : values)
3666 if (failed(convertType(value, results)))
3667 return failure();
3668 return success();
3669}
3670
3671bool TypeConverter::isLegal(Type type) const {
3672 return convertType(type) == type;
3673}
3674
3675bool TypeConverter::isLegal(Value value) const {
3676 return convertType(value) == value.getType();
3677}
3678
3679bool TypeConverter::isLegal(Operation *op) const {
3680 return isLegal(op->getOperands()) && isLegal(op->getResults());
3681}
3682
3683bool TypeConverter::isLegal(Region *region) const {
3684 return llvm::all_of(
3685 *region, [this](Block &block) { return isLegal(block.getArguments()); });
3686}
3687
3688bool TypeConverter::isSignatureLegal(FunctionType ty) const {
3689 if (!isLegal(ty.getInputs()))
3690 return false;
3691 if (!isLegal(ty.getResults()))
3692 return false;
3693 return true;
3694}
3695
3696LogicalResult
3697TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
3698 SignatureConversion &result) const {
3699 // Try to convert the given input type.
3700 SmallVector<Type, 1> convertedTypes;
3701 if (failed(convertType(type, convertedTypes)))
3702 return failure();
3703
3704 // If this argument is being dropped, there is nothing left to do.
3705 if (convertedTypes.empty())
3706 return success();
3707
3708 // Otherwise, add the new inputs.
3709 result.addInputs(inputNo, convertedTypes);
3710 return success();
3711}
3712LogicalResult
3713TypeConverter::convertSignatureArgs(TypeRange types,
3714 SignatureConversion &result,
3715 unsigned origInputOffset) const {
3716 for (unsigned i = 0, e = types.size(); i != e; ++i)
3717 if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3718 return failure();
3719 return success();
3720}
3721LogicalResult
3722TypeConverter::convertSignatureArg(unsigned inputNo, Value value,
3723 SignatureConversion &result) const {
3724 // Try to convert the given input type.
3725 SmallVector<Type, 1> convertedTypes;
3726 if (failed(convertType(value, convertedTypes)))
3727 return failure();
3728
3729 // If this argument is being dropped, there is nothing left to do.
3730 if (convertedTypes.empty())
3731 return success();
3732
3733 // Otherwise, add the new inputs.
3734 result.addInputs(inputNo, convertedTypes);
3735 return success();
3736}
3737LogicalResult
3738TypeConverter::convertSignatureArgs(ValueRange values,
3739 SignatureConversion &result,
3740 unsigned origInputOffset) const {
3741 for (unsigned i = 0, e = values.size(); i != e; ++i)
3742 if (failed(convertSignatureArg(origInputOffset + i, values[i], result)))
3743 return failure();
3744 return success();
3745}
3746
3747Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
3748 Location loc, Type resultType,
3749 ValueRange inputs) const {
3750 for (const SourceMaterializationCallbackFn &fn :
3751 llvm::reverse(sourceMaterializations))
3752 if (Value result = fn(builder, resultType, inputs, loc))
3753 return result;
3754 return nullptr;
3755}
3756
3757Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
3758 Location loc, Type resultType,
3759 ValueRange inputs,
3760 Type originalType) const {
3761 SmallVector<Value> result = materializeTargetConversion(
3762 builder, loc, TypeRange(resultType), inputs, originalType);
3763 if (result.empty())
3764 return nullptr;
3765 assert(result.size() == 1 && "expected single result");
3766 return result.front();
3767}
3768
3769SmallVector<Value> TypeConverter::materializeTargetConversion(
3770 OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
3771 Type originalType) const {
3772 for (const TargetMaterializationCallbackFn &fn :
3773 llvm::reverse(targetMaterializations)) {
3774 SmallVector<Value> result =
3775 fn(builder, resultTypes, inputs, loc, originalType);
3776 if (result.empty())
3777 continue;
3778 assert(TypeRange(ValueRange(result)) == resultTypes &&
3779 "callback produced incorrect number of values or values with "
3780 "incorrect types");
3781 return result;
3782 }
3783 return {};
3784}
3785
3786std::optional<TypeConverter::SignatureConversion>
3787TypeConverter::convertBlockSignature(Block *block) const {
3788 SignatureConversion conversion(block->getNumArguments());
3789 if (failed(convertSignatureArgs(block->getArguments(), conversion)))
3790 return std::nullopt;
3791 return conversion;
3792}
3793
3794//===----------------------------------------------------------------------===//
3795// Type attribute conversion
3796//===----------------------------------------------------------------------===//
3797TypeConverter::AttributeConversionResult
3798TypeConverter::AttributeConversionResult::result(Attribute attr) {
3799 return AttributeConversionResult(attr, resultTag);
3800}
3801
3802TypeConverter::AttributeConversionResult
3803TypeConverter::AttributeConversionResult::na() {
3804 return AttributeConversionResult(nullptr, naTag);
3805}
3806
3807TypeConverter::AttributeConversionResult
3808TypeConverter::AttributeConversionResult::abort() {
3809 return AttributeConversionResult(nullptr, abortTag);
3810}
3811
3812bool TypeConverter::AttributeConversionResult::hasResult() const {
3813 return impl.getInt() == resultTag;
3814}
3815
3816bool TypeConverter::AttributeConversionResult::isNa() const {
3817 return impl.getInt() == naTag;
3818}
3819
3820bool TypeConverter::AttributeConversionResult::isAbort() const {
3821 return impl.getInt() == abortTag;
3822}
3823
3824Attribute TypeConverter::AttributeConversionResult::getResult() const {
3825 assert(hasResult() && "Cannot get result from N/A or abort");
3826 return impl.getPointer();
3827}
3828
3829std::optional<Attribute>
3830TypeConverter::convertTypeAttribute(Type type, Attribute attr) const {
3831 for (const TypeAttributeConversionCallbackFn &fn :
3832 llvm::reverse(typeAttributeConversions)) {
3833 AttributeConversionResult res = fn(type, attr);
3834 if (res.hasResult())
3835 return res.getResult();
3836 if (res.isAbort())
3837 return std::nullopt;
3838 }
3839 return std::nullopt;
3840}
3841
3842//===----------------------------------------------------------------------===//
3843// FunctionOpInterfaceSignatureConversion
3844//===----------------------------------------------------------------------===//
3845
3846static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3847 const TypeConverter &typeConverter,
3848 ConversionPatternRewriter &rewriter) {
3849 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3850 if (!type)
3851 return failure();
3852
3853 // Convert the function signature (inputs and results).
3854 TypeConverter::SignatureConversion funcConversion(type.getNumInputs());
3855 SmallVector<Type, 1> newResults;
3856 if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
3857 funcConversion)) ||
3858 failed(typeConverter.convertTypes(type.getResults(), newResults)))
3859 return failure();
3860
3861 // If the function has a body, apply a separate signature conversion to the
3862 // entry block. Some function ops (e.g., gpu.func) have extra block arguments
3863 // beyond the function type inputs (e.g., workgroup memory arguments that are
3864 // not part of the public signature). Use a distinct conversion sized for all
3865 // entry block arguments so that applySignatureConversion does not access
3866 // out-of-bounds mappings.
3867 if (!funcOp.getFunctionBody().empty()) {
3868 Block *entryBlock = &funcOp.getFunctionBody().front();
3869 unsigned numEntryBlockArgs = entryBlock->getNumArguments();
3870 unsigned numFuncTypeInputs = type.getNumInputs();
3871 TypeConverter::SignatureConversion blockConversion(numEntryBlockArgs);
3872 // Convert the function-type inputs the same way as for the function type.
3873 if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
3874 blockConversion)))
3875 return failure();
3876 // Add identity mappings for extra block args beyond the function type
3877 // inputs. These arguments are preserved as-is.
3878 for (unsigned i = numFuncTypeInputs; i < numEntryBlockArgs; ++i)
3879 blockConversion.addInputs(i, entryBlock->getArgument(i).getType());
3880 rewriter.applySignatureConversion(entryBlock, blockConversion,
3881 &typeConverter);
3882 }
3883
3884 auto newType = FunctionType::get(
3885 rewriter.getContext(), funcConversion.getConvertedTypes(), newResults);
3886
3887 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3888
3889 return success();
3890}
3891
3892/// Create a default conversion pattern that rewrites the type signature of a
3893/// FunctionOpInterface op. This only supports ops which use FunctionType to
3894/// represent their type.
3895namespace {
3896struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3897 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3898 MLIRContext *ctx,
3899 const TypeConverter &converter,
3900 PatternBenefit benefit)
3901 : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3902
3903 LogicalResult
3904 matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3905 ConversionPatternRewriter &rewriter) const override {
3906 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3907 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3908 }
3909};
3910
3911struct AnyFunctionOpInterfaceSignatureConversion
3912 : public OpInterfaceConversionPattern<FunctionOpInterface> {
3913 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3914
3915 LogicalResult
3916 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3917 ConversionPatternRewriter &rewriter) const override {
3918 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3919 }
3920};
3921} // namespace
3922
3923FailureOr<Operation *>
3924mlir::convertOpResultTypes(Operation *op, ValueRange operands,
3925 const TypeConverter &converter,
3926 ConversionPatternRewriter &rewriter) {
3927 assert(op && "Invalid op");
3928 Location loc = op->getLoc();
3929 if (converter.isLegal(op))
3930 return rewriter.notifyMatchFailure(loc, "op already legal");
3931
3932 OperationState newOp(loc, op->getName());
3933 newOp.addOperands(operands);
3934
3935 SmallVector<Type> newResultTypes;
3936 if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
3937 return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3938
3939 newOp.addTypes(newResultTypes);
3940 newOp.addAttributes(op->getAttrs());
3941 return rewriter.create(newOp);
3942}
3943
3944void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3945 StringRef functionLikeOpName, RewritePatternSet &patterns,
3946 const TypeConverter &converter, PatternBenefit benefit) {
3947 patterns.add<FunctionOpInterfaceSignatureConversion>(
3948 functionLikeOpName, patterns.getContext(), converter, benefit);
3949}
3950
3951void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3952 RewritePatternSet &patterns, const TypeConverter &converter,
3953 PatternBenefit benefit) {
3954 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3955 converter, patterns.getContext(), benefit);
3956}
3957
3958//===----------------------------------------------------------------------===//
3959// ConversionTarget
3960//===----------------------------------------------------------------------===//
3961
3962void ConversionTarget::setOpAction(OperationName op,
3963 LegalizationAction action) {
3964 legalOperations[op].action = action;
3965}
3966
3967void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3968 LegalizationAction action) {
3969 for (StringRef dialect : dialectNames)
3970 legalDialects[dialect] = action;
3971}
3972
3973auto ConversionTarget::getOpAction(OperationName op) const
3974 -> std::optional<LegalizationAction> {
3975 std::optional<LegalizationInfo> info = getOpInfo(op);
3976 return info ? info->action : std::optional<LegalizationAction>();
3977}
3978
3979auto ConversionTarget::isLegal(Operation *op) const
3980 -> std::optional<LegalOpDetails> {
3981 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3982 if (!info)
3983 return std::nullopt;
3984
3985 // Returns true if this operation instance is known to be legal.
3986 auto isOpLegal = [&] {
3987 // Handle dynamic legality either with the provided legality function.
3988 if (info->action == LegalizationAction::Dynamic) {
3989 std::optional<bool> result = info->legalityFn(op);
3990 if (result)
3991 return *result;
3992 }
3993
3994 // Otherwise, the operation is only legal if it was marked 'Legal'.
3995 return info->action == LegalizationAction::Legal;
3996 };
3997 if (!isOpLegal())
3998 return std::nullopt;
3999
4000 // This operation is legal, compute any additional legality information.
4001 LegalOpDetails legalityDetails;
4002 if (info->isRecursivelyLegal) {
4003 auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
4004 if (legalityFnIt != opRecursiveLegalityFns.end()) {
4005 legalityDetails.isRecursivelyLegal =
4006 legalityFnIt->second(op).value_or(true);
4007 } else {
4008 legalityDetails.isRecursivelyLegal = true;
4009 }
4010 }
4011 return legalityDetails;
4012}
4013
4014bool ConversionTarget::isIllegal(Operation *op) const {
4015 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
4016 if (!info)
4017 return false;
4018
4019 if (info->action == LegalizationAction::Dynamic) {
4020 std::optional<bool> result = info->legalityFn(op);
4021 if (!result)
4022 return false;
4023
4024 return !(*result);
4025 }
4026
4027 return info->action == LegalizationAction::Illegal;
4028}
4029
4030static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
4031 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
4032 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
4033 if (!oldCallback)
4034 return newCallback;
4035
4036 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
4037 Operation *op) -> std::optional<bool> {
4038 if (std::optional<bool> result = newCl(op))
4039 return *result;
4040
4041 return oldCl(op);
4042 };
4043 return chain;
4044}
4045
4046void ConversionTarget::setLegalityCallback(
4047 OperationName name, const DynamicLegalityCallbackFn &callback) {
4048 assert(callback && "expected valid legality callback");
4049 auto *infoIt = legalOperations.find(name);
4050 assert(infoIt != legalOperations.end() &&
4051 infoIt->second.action == LegalizationAction::Dynamic &&
4052 "expected operation to already be marked as dynamically legal");
4053 infoIt->second.legalityFn =
4054 composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
4055}
4056
4057void ConversionTarget::markOpRecursivelyLegal(
4058 OperationName name, const DynamicLegalityCallbackFn &callback) {
4059 auto *infoIt = legalOperations.find(name);
4060 assert(infoIt != legalOperations.end() &&
4061 infoIt->second.action != LegalizationAction::Illegal &&
4062 "expected operation to already be marked as legal");
4063 infoIt->second.isRecursivelyLegal = true;
4064 if (callback)
4065 opRecursiveLegalityFns[name] = composeLegalityCallbacks(
4066 std::move(opRecursiveLegalityFns[name]), callback);
4067 else
4068 opRecursiveLegalityFns.erase(name);
4069}
4070
4071void ConversionTarget::setLegalityCallback(
4072 ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
4073 assert(callback && "expected valid legality callback");
4074 for (StringRef dialect : dialects)
4075 dialectLegalityFns[dialect] = composeLegalityCallbacks(
4076 std::move(dialectLegalityFns[dialect]), callback);
4077}
4078
4079void ConversionTarget::setLegalityCallback(
4080 const DynamicLegalityCallbackFn &callback) {
4081 assert(callback && "expected valid legality callback");
4082 unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
4083}
4084
4085auto ConversionTarget::getOpInfo(OperationName op) const
4086 -> std::optional<LegalizationInfo> {
4087 // Check for info for this specific operation.
4088 const auto *it = legalOperations.find(op);
4089 if (it != legalOperations.end())
4090 return it->second;
4091 // Check for info for the parent dialect.
4092 auto dialectIt = legalDialects.find(op.getDialectNamespace());
4093 if (dialectIt != legalDialects.end()) {
4094 DynamicLegalityCallbackFn callback;
4095 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4096 if (dialectFn != dialectLegalityFns.end())
4097 callback = dialectFn->second;
4098 return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
4099 callback};
4100 }
4101 // Otherwise, check if we mark unknown operations as dynamic.
4102 if (unknownLegalityFn)
4103 return LegalizationInfo{LegalizationAction::Dynamic,
4104 /*isRecursivelyLegal=*/false, unknownLegalityFn};
4105 return std::nullopt;
4106}
4107
4108#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4109//===----------------------------------------------------------------------===//
4110// PDL Configuration
4111//===----------------------------------------------------------------------===//
4112
4113void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
4114 auto &rewriterImpl =
4115 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4116 rewriterImpl.currentTypeConverter = getTypeConverter();
4117}
4118
4119void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
4120 auto &rewriterImpl =
4121 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4122 rewriterImpl.currentTypeConverter = nullptr;
4123}
4124
4125/// Remap the given value using the rewriter and the type converter in the
4126/// provided config.
4127static FailureOr<SmallVector<Value>>
4128pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) {
4129 SmallVector<Value> mappedValues;
4130 if (failed(rewriter.getRemappedValues(values, mappedValues)))
4131 return failure();
4132 return std::move(mappedValues);
4133}
4134
4135void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
4136 patterns.getPDLPatterns().registerRewriteFunction(
4137 "convertValue",
4138 [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
4139 auto results = pdllConvertValues(
4140 static_cast<ConversionPatternRewriter &>(rewriter), value);
4141 if (failed(results))
4142 return failure();
4143 return results->front();
4144 });
4145 patterns.getPDLPatterns().registerRewriteFunction(
4146 "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
4147 return pdllConvertValues(
4148 static_cast<ConversionPatternRewriter &>(rewriter), values);
4149 });
4150 patterns.getPDLPatterns().registerRewriteFunction(
4151 "convertType",
4152 [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
4153 auto &rewriterImpl =
4154 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4155 if (const TypeConverter *converter =
4156 rewriterImpl.currentTypeConverter) {
4157 if (Type newType = converter->convertType(type))
4158 return newType;
4159 return failure();
4160 }
4161 return type;
4162 });
4163 patterns.getPDLPatterns().registerRewriteFunction(
4164 "convertTypes",
4165 [](PatternRewriter &rewriter,
4166 TypeRange types) -> FailureOr<SmallVector<Type>> {
4167 auto &rewriterImpl =
4168 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4169 const TypeConverter *converter = rewriterImpl.currentTypeConverter;
4170 if (!converter)
4171 return SmallVector<Type>(types);
4172
4173 SmallVector<Type> remappedTypes;
4174 if (failed(converter->convertTypes(types, remappedTypes)))
4175 return failure();
4176 return std::move(remappedTypes);
4177 });
4178}
4179#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
4180
4181//===----------------------------------------------------------------------===//
4182// Op Conversion Entry Points
4183//===----------------------------------------------------------------------===//
4184
4185/// This is the type of Action that is dispatched when a conversion is applied.
4187 : public tracing::ActionImpl<ApplyConversionAction> {
4188public:
4191 static constexpr StringLiteral tag = "apply-conversion";
4192 static constexpr StringLiteral desc =
4193 "Encapsulate the application of a dialect conversion";
4194
4195 void print(raw_ostream &os) const override { os << tag; }
4196};
4197
4199 const ConversionTarget &target,
4200 const FrozenRewritePatternSet &patterns,
4201 ConversionConfig config,
4202 OpConversionMode mode) {
4203 if (ops.empty())
4204 return success();
4205 MLIRContext *ctx = ops.front()->getContext();
4206 LogicalResult status = success();
4207 SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
4209 [&] {
4210 OperationConverter opConverter(ops.front()->getContext(), target,
4211 patterns, config, mode);
4212 status = opConverter.applyConversion(ops);
4213 },
4214 irUnits);
4215 return status;
4216}
4217
4218//===----------------------------------------------------------------------===//
4219// Partial Conversion
4220//===----------------------------------------------------------------------===//
4221
4222LogicalResult mlir::applyPartialConversion(
4223 ArrayRef<Operation *> ops, const ConversionTarget &target,
4224 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4225 return applyConversion(ops, target, patterns, config,
4226 OpConversionMode::Partial);
4227}
4228LogicalResult
4229mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
4230 const FrozenRewritePatternSet &patterns,
4231 ConversionConfig config) {
4232 return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
4233}
4234
4235//===----------------------------------------------------------------------===//
4236// Full Conversion
4237//===----------------------------------------------------------------------===//
4238
4239LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
4240 const ConversionTarget &target,
4241 const FrozenRewritePatternSet &patterns,
4242 ConversionConfig config) {
4243 return applyConversion(ops, target, patterns, config, OpConversionMode::Full);
4244}
4245LogicalResult mlir::applyFullConversion(Operation *op,
4246 const ConversionTarget &target,
4247 const FrozenRewritePatternSet &patterns,
4248 ConversionConfig config) {
4249 return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
4250}
4251
4252//===----------------------------------------------------------------------===//
4253// Analysis Conversion
4254//===----------------------------------------------------------------------===//
4255
4256/// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
4257/// op is a top-level module op (which is expected to be isolated from above),
4258/// return that op.
4260 // Check if there is a top-level operation within `ops`. If so, return that
4261 // op.
4262 for (Operation *op : ops) {
4263 if (!op->getParentOp()) {
4264#ifndef NDEBUG
4265 assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
4266 "expected top-level op to be isolated from above");
4267 for (Operation *other : ops)
4268 assert(op->isAncestor(other) &&
4269 "expected ops to have a common ancestor");
4270#endif // NDEBUG
4271 return op;
4272 }
4273 }
4274
4275 // No top-level op. Find a common ancestor.
4276 Operation *commonAncestor =
4277 ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
4278 for (Operation *op : ops.drop_front()) {
4279 while (!commonAncestor->isProperAncestor(op)) {
4280 commonAncestor =
4282 assert(commonAncestor &&
4283 "expected to find a common isolated from above ancestor");
4284 }
4285 }
4286
4287 return commonAncestor;
4288}
4289
4290LogicalResult mlir::applyAnalysisConversion(
4291 ArrayRef<Operation *> ops, ConversionTarget &target,
4292 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4293#ifndef NDEBUG
4294 if (config.legalizableOps)
4295 assert(config.legalizableOps->empty() && "expected empty set");
4296#endif // NDEBUG
4297
4298 // Clone closted common ancestor that is isolated from above.
4299 Operation *commonAncestor = findCommonAncestor(ops);
4300 IRMapping mapping;
4301 Operation *clonedAncestor = commonAncestor->clone(mapping);
4302 // Compute inverse IR mapping.
4303 DenseMap<Operation *, Operation *> inverseOperationMap;
4304 for (auto &it : mapping.getOperationMap())
4305 inverseOperationMap[it.second] = it.first;
4306
4307 // Convert the cloned operations. The original IR will remain unchanged.
4308 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4309 ops, [&](Operation *op) { return mapping.lookup(op); });
4310 LogicalResult status = applyConversion(opsToConvert, target, patterns, config,
4311 OpConversionMode::Analysis);
4312
4313 // Remap `legalizableOps`, so that they point to the original ops and not the
4314 // cloned ops.
4315 if (config.legalizableOps) {
4316 DenseSet<Operation *> originalLegalizableOps;
4317 for (Operation *op : *config.legalizableOps)
4318 originalLegalizableOps.insert(inverseOperationMap[op]);
4319 *config.legalizableOps = std::move(originalLegalizableOps);
4320 }
4321
4322 // Erase the cloned IR.
4323 clonedAncestor->erase();
4324 return status;
4325}
4326
4327LogicalResult
4328mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
4329 const FrozenRewritePatternSet &patterns,
4330 ConversionConfig config) {
4331 return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
4332}
return success()
static void setInsertionPointAfter(OpBuilder &b, Value value)
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
SmallVector< Value, 2 > ValueVector
A vector of SSA values, optimized for the most common case of one or two values.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl, function_ref< bool(OpOperand &)> functor=nullptr)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value > > &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
#define DEBUG_TYPE
b getContext())
static std::string diag(const llvm::Value &value)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition SCCP.cpp:67
This is the type of Action that is dispatched when a conversion is applied.
tracing::ActionImpl< ApplyConversionAction > Base
static constexpr StringLiteral desc
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
static constexpr StringLiteral tag
This class represents an argument of a Block.
Definition Value.h:306
Location getLoc() const
Return the location for this argument.
Definition Value.h:321
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition Block.h:318
BlockArgListType getArguments()
Definition Block.h:97
iterator end()
Definition Block.h:154
iterator begin()
Definition Block.h:153
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
UnitAttr getUnitAttr()
Definition Builders.cpp:102
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
MLIRContext * context
Definition Builders.h:204
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
Definition Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition Builders.h:329
Block::iterator getPoint() const
Definition Builders.h:342
bool isSet() const
Returns true if this insert point is set.
Definition Builders.h:339
Block * getBlock() const
Definition Builders.h:341
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:322
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Definition Builders.cpp:477
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition Builders.cpp:425
This class represents an operand of an operation.
Definition Value.h:254
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition Value.h:454
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1143
type_range getTypes() const
void destroyOpProperties(PropertyRef properties) const
This hooks destroy the op properties.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
TypeID getOpPropertiesTypeID() const
Return the TypeID of the op properties.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
PropertyRef getPropertiesStorage()
Return a generic (but typed) reference to the property type storage.
Definition Operation.h:927
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition Operation.h:244
void copyProperties(PropertyRef rhs)
Copy properties from an existing other properties object.
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:878
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
void dropAllUses()
Drop all uses of results of this operation.
Definition Operation.h:860
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:586
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:538
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:274
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:700
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
OperandRange operand_range
Definition Operation.h:397
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:703
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:289
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
result_range getResults()
Definition Operation.h:441
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
BlockListType::iterator iterator
Definition Region.h:52
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
CRTP Implementation of an action.
Definition Action.h:76
ArrayRef< IRUnit > irUnits
Set of IR units (operations, regions, blocks, values) that are associated with this action.
Definition Action.h:66
AttrTypeReplacer.
Kind
An enumeration of the kinds of predicates.
Definition Predicate.h:44
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
static void reconcileUnrealizedCasts(const DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
This iterator enumerates elements according to their dominance relationship.
Definition Iterators.h:52
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition Builders.h:310
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition Builders.h:300
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
const ConversionTarget & getTarget()
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, bool isRecursiveLegalization=false)
LogicalResult convert(Operation *op, bool isRecursiveLegalization=false)
Converts a single operation.
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, Fn onFailure, bool isRecursiveLegalization=false)
Legalizes the given operations (and their nested operations) to the conversion target.
LogicalResult applyConversion(ArrayRef< Operation * > ops)
Applies the conversion to the given operations (and their nested operations).
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
llvm::impl::raw_ldbg_ostream os
A raw output stream used to prefix the debug log.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
void replaceValueUses(Value from, ValueRange to, const TypeConverter *converter, function_ref< bool(OpOperand &)> functor=nullptr)
Replace the uses of the given value with the given values.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config, OperationConverter &opConverter)
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
void replaceOp(Operation *op, SmallVector< SmallVector< Value > > &&newValues)
Replace the results of the given operation with the given values and erase the operation.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
OperationConverter & opConverter
The operation converter to use for recursive legalization.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.