MLIR 22.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/Config/mlir-config.h"
11#include "mlir/IR/Block.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "mlir/IR/Dominance.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Iterators.h"
17#include "mlir/IR/Operation.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/FormatVariadic.h"
25#include "llvm/Support/SaveAndRestore.h"
26#include "llvm/Support/ScopedPrinter.h"
27#include <optional>
28#include <utility>
29
30using namespace mlir;
31using namespace mlir::detail;
32
33#define DEBUG_TYPE "dialect-conversion"
34
35/// A utility function to log a successful result for the given reason.
36template <typename... Args>
37static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
38 LLVM_DEBUG({
39 os.unindent();
40 os.startLine() << "} -> SUCCESS";
41 if (!fmt.empty())
42 os.getOStream() << " : "
43 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
44 os.getOStream() << "\n";
45 });
46}
47
48/// A utility function to log a failure result for the given reason.
49template <typename... Args>
50static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
51 LLVM_DEBUG({
52 os.unindent();
53 os.startLine() << "} -> FAILURE : "
54 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
55 << "\n";
56 });
57}
58
59/// Helper function that computes an insertion point where the given value is
60/// defined and can be used without a dominance violation.
62 Block *insertBlock = value.getParentBlock();
63 Block::iterator insertPt = insertBlock->begin();
64 if (OpResult inputRes = dyn_cast<OpResult>(value))
65 insertPt = ++inputRes.getOwner()->getIterator();
66 return OpBuilder::InsertPoint(insertBlock, insertPt);
67}
68
69/// Helper function that computes an insertion point where the given values are
70/// defined and can be used without a dominance violation.
72 assert(!vals.empty() && "expected at least one value");
73 DominanceInfo domInfo;
75 for (Value v : vals.drop_front()) {
76 // Choose the "later" insertion point.
78 if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
79 nextPt.getPoint())) {
80 // pt is before nextPt => choose nextPt.
81 pt = nextPt;
82 } else {
83#ifndef NDEBUG
84 // nextPt should be before pt => choose pt.
85 // If pt and nextPt have no dominance relationship, then there is no valid
86 // insertion point at which all given values are defined.
87 bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
88 pt.getBlock(), pt.getPoint());
89 assert(dom && "unable to find valid insertion point");
90#endif // NDEBUG
91 }
92 }
93 return pt;
94}
95
namespace {
/// Selects how the dialect conversion treats operations that cannot be
/// legalized.
enum OpConversionMode {
  /// Failed conversions are tolerated: illegal operations may remain in the IR
  /// alongside legal ones.
  Partial,

  /// The conversion only succeeds if every operation is legal for the given
  /// target.
  Full,

  /// Operations are only analyzed for legality; on success, no rewrites are
  /// actually applied to the operations.
  Analysis,
};
} // namespace
111
112//===----------------------------------------------------------------------===//
113// ConversionValueMapping
114//===----------------------------------------------------------------------===//
115
116/// A vector of SSA values, optimized for the most common case of a single
117/// value.
119
120namespace {
121
122/// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
123struct ValueVectorMapInfo {
124 static ValueVector getEmptyKey() { return ValueVector{Value()}; }
125 static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
126 static ::llvm::hash_code getHashValue(const ValueVector &val) {
127 return ::llvm::hash_combine_range(val);
128 }
129 static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
130 return LHS == RHS;
131 }
132};
133
134/// This class wraps a mapping from value vectors to value vectors. Note that
135/// `lookup` returns only the directly mapped values; it does not follow chains.
136struct ConversionValueMapping {
137 /// Return "true" if an SSA value is mapped to the given value. May return
138 /// false positives.
139 bool isMappedTo(Value value) const { return mappedTo.contains(value); }
140
141 /// Lookup a value in the mapping.
142 ValueVector lookup(const ValueVector &from) const;
143
144 template <typename T>
145 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
146
147 /// Map a value vector to the one provided.
148 template <typename OldVal, typename NewVal>
149 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
150 map(OldVal &&oldVal, NewVal &&newVal) {
151 LLVM_DEBUG({
152 ValueVector next(newVal);
153 while (true) {
154 assert(next != oldVal && "inserting cyclic mapping");
155 auto it = mapping.find(next);
156 if (it == mapping.end())
157 break;
158 next = it->second;
159 }
160 });
161 mappedTo.insert_range(newVal);
162
163 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
164 }
165
166 /// Map a value vector or single value to the one provided.
167 template <typename OldVal, typename NewVal>
168 std::enable_if_t<!IsValueVector<OldVal>::value ||
169 !IsValueVector<NewVal>::value>
170 map(OldVal &&oldVal, NewVal &&newVal) {
171 if constexpr (IsValueVector<OldVal>{}) {
172 map(std::forward<OldVal>(oldVal), ValueVector{newVal});
173 } else if constexpr (IsValueVector<NewVal>{}) {
174 map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
175 } else {
176 map(ValueVector{oldVal}, ValueVector{newVal});
177 }
178 }
179
180 void map(Value oldVal, SmallVector<Value> &&newVal) {
181 map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
182 }
183
184 /// Drop the last mapping for the given values.
185 void erase(const ValueVector &value) { mapping.erase(value); }
186
187private:
188 /// Current value mappings.
190
191 /// All SSA values that are mapped to. May contain false positives.
192 DenseSet<Value> mappedTo;
193};
194} // namespace
195
196/// Marker attribute for pure type conversions. I.e., mappings whose only
197/// purpose is to resolve a type mismatch. (In contrast, mappings that point to
198/// the replacement values of a "replaceOp" call, etc., are not pure type
199/// conversions.)
200static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
201
202/// Return the operation that defines all values in the vector. Return nullptr
203/// if the values are not defined by the same operation.
205 assert(!values.empty() && "expected non-empty value vector");
206 Operation *op = values.front().getDefiningOp();
207 for (Value v : llvm::drop_begin(values)) {
208 if (v.getDefiningOp() != op)
209 return nullptr;
210 }
211 return op;
212}
213
214/// A vector of values is a pure type conversion if all values are defined by
215/// the same operation and the operation has the `kPureTypeConversionMarker`
216/// attribute.
217static bool isPureTypeConversion(const ValueVector &values) {
218 assert(!values.empty() && "expected non-empty value vector");
219 Operation *op = getCommonDefiningOp(values);
220 return op && op->hasAttr(kPureTypeConversionMarker);
221}
222
223ValueVector ConversionValueMapping::lookup(const ValueVector &from) const {
224 auto it = mapping.find(from);
225 if (it == mapping.end()) {
226 // No mapping found: The lookup stops here.
227 return {};
228 }
229 return it->second;
230}
231
232//===----------------------------------------------------------------------===//
233// Rewriter and Translation State
234//===----------------------------------------------------------------------===//
235namespace {
/// Snapshot of the conversion rewriter's bookkeeping counters. Capturing one
/// makes it possible to later undo everything recorded after that point.
struct RewriterState {
  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
                unsigned numReplacedOps)
      : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
        numReplacedOps(numReplacedOps) {}

  /// Number of rewrites that had been recorded when the snapshot was taken.
  unsigned numRewrites;

  /// Number of ignored operations when the snapshot was taken.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure when the snapshot was taken.
  unsigned numReplacedOps;
};
253
254//===----------------------------------------------------------------------===//
255// IR rewrites
256//===----------------------------------------------------------------------===//
257
258static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
259
260/// Notify the listener that the given block and its contents are being erased.
261static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
262 for (Operation &op : b)
263 notifyIRErased(listener, op);
264 listener->notifyBlockErased(&b);
265}
266
267/// Notify the listener that the given operation and its contents are being
268/// erased.
269static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
270 for (Region &r : op.getRegions()) {
271 for (Block &b : r) {
272 notifyIRErased(listener, b);
273 }
274 }
275 listener->notifyOperationErased(&op);
276}
277
278/// An IR rewrite that can be committed (upon success) or rolled back (upon
279/// failure).
280///
281/// The dialect conversion keeps track of IR modifications (requested by the
282/// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
283/// are directly applied to the IR as the rewriter API is used, some are applied
284/// partially, and some are delayed until the `IRRewrite` objects are committed.
285class IRRewrite {
286public:
287 /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
288 /// Enum values are ordered, so that they can be used in `classof`: first all
289 /// block rewrites, then all operation rewrites.
290 enum class Kind {
291 // Block rewrites
292 CreateBlock,
293 EraseBlock,
294 InlineBlock,
295 MoveBlock,
296 BlockTypeConversion,
297 // Operation rewrites
298 MoveOperation,
299 ModifyOperation,
300 ReplaceOperation,
301 CreateOperation,
302 UnresolvedMaterialization,
303 // Value rewrites
304 ReplaceValue
305 };
306
307 virtual ~IRRewrite() = default;
308
309 /// Roll back the rewrite. Operations may be erased during rollback.
310 virtual void rollback() = 0;
311
312 /// Commit the rewrite. At this point, it is certain that the dialect
313 /// conversion will succeed. All IR modifications, except for operation/block
314 /// erasure, must be performed through the given rewriter.
315 ///
316 /// Instead of erasing operations/blocks, they should merely be unlinked
317 /// commit phase and finally be erased during the cleanup phase. This is
318 /// because internal dialect conversion state (such as `mapping`) may still
319 /// be using them.
320 ///
321 /// Any IR modification that was already performed before the commit phase
322 /// (e.g., insertion of an op) must be communicated to the listener that may
323 /// be attached to the given rewriter.
324 virtual void commit(RewriterBase &rewriter) {}
325
326 /// Cleanup operations/blocks. Cleanup is called after commit.
327 virtual void cleanup(RewriterBase &rewriter) {}
328
329 Kind getKind() const { return kind; }
330
331 static bool classof(const IRRewrite *rewrite) { return true; }
332
333protected:
334 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
335 : kind(kind), rewriterImpl(rewriterImpl) {}
336
337 const ConversionConfig &getConfig() const;
338
339 const Kind kind;
340 ConversionPatternRewriterImpl &rewriterImpl;
341};
342
343/// A block rewrite.
344class BlockRewrite : public IRRewrite {
345public:
346 /// Return the block that this rewrite operates on.
347 Block *getBlock() const { return block; }
348
349 static bool classof(const IRRewrite *rewrite) {
350 return rewrite->getKind() >= Kind::CreateBlock &&
351 rewrite->getKind() <= Kind::BlockTypeConversion;
352 }
353
354protected:
355 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
356 Block *block)
357 : IRRewrite(kind, rewriterImpl), block(block) {}
358
359 // The block that this rewrite operates on.
360 Block *block;
361};
362
363/// A value rewrite.
364class ValueRewrite : public IRRewrite {
365public:
366 /// Return the value that this rewrite operates on.
367 Value getValue() const { return value; }
368
369 static bool classof(const IRRewrite *rewrite) {
370 return rewrite->getKind() == Kind::ReplaceValue;
371 }
372
373protected:
374 ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
375 Value value)
376 : IRRewrite(kind, rewriterImpl), value(value) {}
377
378 // The value that this rewrite operates on.
379 Value value;
380};
381
382/// Creation of a block. Block creations are immediately reflected in the IR.
383/// There is no extra work to commit the rewrite. During rollback, the newly
384/// created block is erased.
385class CreateBlockRewrite : public BlockRewrite {
386public:
387 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
388 : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
389
390 static bool classof(const IRRewrite *rewrite) {
391 return rewrite->getKind() == Kind::CreateBlock;
392 }
393
394 void commit(RewriterBase &rewriter) override {
395 // The block was already created and inserted. Just inform the listener.
396 if (auto *listener = rewriter.getListener())
397 listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
398 }
399
400 void rollback() override {
401 // Unlink all of the operations within this block, they will be deleted
402 // separately.
403 auto &blockOps = block->getOperations();
404 while (!blockOps.empty())
405 blockOps.remove(blockOps.begin());
406 block->dropAllUses();
407 if (block->getParent())
408 block->erase();
409 else
410 delete block;
411 }
412};
413
414/// Erasure of a block. Block erasures are partially reflected in the IR. Erased
415/// blocks are immediately unlinked, but only erased during cleanup. This makes
416/// it easier to rollback a block erasure: the block is simply inserted into its
417/// original location.
418class EraseBlockRewrite : public BlockRewrite {
419public:
420 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
421 : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
422 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
423
424 static bool classof(const IRRewrite *rewrite) {
425 return rewrite->getKind() == Kind::EraseBlock;
426 }
427
428 ~EraseBlockRewrite() override {
429 assert(!block &&
430 "rewrite was neither rolled back nor committed/cleaned up");
431 }
432
433 void rollback() override {
434 // The block (owned by this rewrite) was not actually erased yet. It was
435 // just unlinked. Put it back into its original position.
436 assert(block && "expected block");
437 auto &blockList = region->getBlocks();
438 Region::iterator before = insertBeforeBlock
439 ? Region::iterator(insertBeforeBlock)
440 : blockList.end();
441 blockList.insert(before, block);
442 block = nullptr;
443 }
444
445 void commit(RewriterBase &rewriter) override {
446 assert(block && "expected block");
447
448 // Notify the listener that the block and its contents are being erased.
449 if (auto *listener =
450 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
451 notifyIRErased(listener, *block);
452 }
453
454 void cleanup(RewriterBase &rewriter) override {
455 // Erase the contents of the block.
456 for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
457 rewriter.eraseOp(&op);
458 assert(block->empty() && "expected empty block");
459
460 // Erase the block.
461 block->dropAllDefinedValueUses();
462 delete block;
463 block = nullptr;
464 }
465
466private:
467 // The region in which this block was previously contained.
468 Region *region;
469
470 // The original successor of this block before it was unlinked. "nullptr" if
471 // this block was the only block in the region.
472 Block *insertBeforeBlock;
473};
474
475/// Inlining of a block. This rewrite is immediately reflected in the IR.
476/// Note: This rewrite represents only the inlining of the operations. The
477/// erasure of the inlined block is a separate rewrite.
478class InlineBlockRewrite : public BlockRewrite {
479public:
480 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
481 Block *sourceBlock, Block::iterator before)
482 : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
483 sourceBlock(sourceBlock),
484 firstInlinedInst(sourceBlock->empty() ? nullptr
485 : &sourceBlock->front()),
486 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
487 // If a listener is attached to the dialect conversion, ops must be moved
488 // one-by-one. When they are moved in bulk, notifications cannot be sent
489 // because the ops that used to be in the source block at the time of the
490 // inlining (before the "commit" phase) are unknown at the time when
491 // notifications are sent (which is during the "commit" phase).
492 assert(!getConfig().listener &&
493 "InlineBlockRewrite not supported if listener is attached");
494 }
495
496 static bool classof(const IRRewrite *rewrite) {
497 return rewrite->getKind() == Kind::InlineBlock;
498 }
499
500 void rollback() override {
501 // Put the operations from the destination block (owned by the rewrite)
502 // back into the source block.
503 if (firstInlinedInst) {
504 assert(lastInlinedInst && "expected operation");
505 sourceBlock->getOperations().splice(sourceBlock->begin(),
506 block->getOperations(),
507 Block::iterator(firstInlinedInst),
508 ++Block::iterator(lastInlinedInst));
509 }
510 }
511
512private:
513 // The block that originally contained the operations.
514 Block *sourceBlock;
515
516 // The first inlined operation.
517 Operation *firstInlinedInst;
518
519 // The last inlined operation.
520 Operation *lastInlinedInst;
521};
522
523/// Moving of a block. This rewrite is immediately reflected in the IR.
524class MoveBlockRewrite : public BlockRewrite {
525public:
526 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
527 Region *previousRegion, Region::iterator previousIt)
528 : BlockRewrite(Kind::MoveBlock, rewriterImpl, block),
529 region(previousRegion),
530 insertBeforeBlock(previousIt == previousRegion->end() ? nullptr
531 : &*previousIt) {}
532
533 static bool classof(const IRRewrite *rewrite) {
534 return rewrite->getKind() == Kind::MoveBlock;
535 }
536
537 void commit(RewriterBase &rewriter) override {
538 // The block was already moved. Just inform the listener.
539 if (auto *listener = rewriter.getListener()) {
540 // Note: `previousIt` cannot be passed because this is a delayed
541 // notification and iterators into past IR state cannot be represented.
542 listener->notifyBlockInserted(block, /*previous=*/region,
543 /*previousIt=*/{});
544 }
545 }
546
547 void rollback() override {
548 // Move the block back to its original position.
549 Region::iterator before =
550 insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
551 region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
552 }
553
554private:
555 // The region in which this block was previously contained.
556 Region *region;
557
558 // The original successor of this block before it was moved. "nullptr" if
559 // this block was the only block in the region.
560 Block *insertBeforeBlock;
561};
562
563/// Block type conversion. This rewrite is partially reflected in the IR.
564class BlockTypeConversionRewrite : public BlockRewrite {
565public:
566 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
567 Block *origBlock, Block *newBlock)
568 : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
569 newBlock(newBlock) {}
570
571 static bool classof(const IRRewrite *rewrite) {
572 return rewrite->getKind() == Kind::BlockTypeConversion;
573 }
574
575 Block *getOrigBlock() const { return block; }
576
577 Block *getNewBlock() const { return newBlock; }
578
579 void commit(RewriterBase &rewriter) override;
580
581 void rollback() override;
582
583private:
584 /// The new block that was created as part of this signature conversion.
585 Block *newBlock;
586};
587
588/// Replacing a value. This rewrite is not immediately reflected in the
589/// IR. An internal IR mapping is updated, but the actual replacement is delayed
590/// until the rewrite is committed.
591class ReplaceValueRewrite : public ValueRewrite {
592public:
593 ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
594 const TypeConverter *converter)
595 : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
596 converter(converter) {}
597
598 static bool classof(const IRRewrite *rewrite) {
599 return rewrite->getKind() == Kind::ReplaceValue;
600 }
601
602 void commit(RewriterBase &rewriter) override;
603
604 void rollback() override;
605
606private:
607 /// The current type converter when the value was replaced.
608 const TypeConverter *converter;
609};
610
611/// An operation rewrite.
612class OperationRewrite : public IRRewrite {
613public:
614 /// Return the operation that this rewrite operates on.
615 Operation *getOperation() const { return op; }
616
617 static bool classof(const IRRewrite *rewrite) {
618 return rewrite->getKind() >= Kind::MoveOperation &&
619 rewrite->getKind() <= Kind::UnresolvedMaterialization;
620 }
621
622protected:
623 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
624 Operation *op)
625 : IRRewrite(kind, rewriterImpl), op(op) {}
626
627 // The operation that this rewrite operates on.
628 Operation *op;
629};
630
631/// Moving of an operation. This rewrite is immediately reflected in the IR.
632class MoveOperationRewrite : public OperationRewrite {
633public:
634 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
635 Operation *op, OpBuilder::InsertPoint previous)
636 : OperationRewrite(Kind::MoveOperation, rewriterImpl, op),
637 block(previous.getBlock()),
638 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
639 ? nullptr
640 : &*previous.getPoint()) {}
641
642 static bool classof(const IRRewrite *rewrite) {
643 return rewrite->getKind() == Kind::MoveOperation;
644 }
645
646 void commit(RewriterBase &rewriter) override {
647 // The operation was already moved. Just inform the listener.
648 if (auto *listener = rewriter.getListener()) {
649 // Note: `previousIt` cannot be passed because this is a delayed
650 // notification and iterators into past IR state cannot be represented.
651 listener->notifyOperationInserted(
652 op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
653 /*insertPt=*/{}));
654 }
655 }
656
657 void rollback() override {
658 // Move the operation back to its original position.
659 Block::iterator before =
660 insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
661 block->getOperations().splice(before, op->getBlock()->getOperations(), op);
662 }
663
664private:
665 // The block in which this operation was previously contained.
666 Block *block;
667
668 // The original successor of this operation before it was moved. "nullptr"
669 // if this operation was the only operation in the region.
670 Operation *insertBeforeOp;
671};
672
673/// In-place modification of an op. This rewrite is immediately reflected in
674/// the IR. The previous state of the operation is stored in this object.
675class ModifyOperationRewrite : public OperationRewrite {
676public:
677 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
678 Operation *op)
679 : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
680 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
681 operands(op->operand_begin(), op->operand_end()),
682 successors(op->successor_begin(), op->successor_end()) {
683 if (OpaqueProperties prop = op->getPropertiesStorage()) {
684 // Make a copy of the properties.
685 propertiesStorage = operator new(op->getPropertiesStorageSize());
686 OpaqueProperties propCopy(propertiesStorage);
687 name.initOpProperties(propCopy, /*init=*/prop);
688 }
689 }
690
691 static bool classof(const IRRewrite *rewrite) {
692 return rewrite->getKind() == Kind::ModifyOperation;
693 }
694
695 ~ModifyOperationRewrite() override {
696 assert(!propertiesStorage &&
697 "rewrite was neither committed nor rolled back");
698 }
699
700 void commit(RewriterBase &rewriter) override {
701 // Notify the listener that the operation was modified in-place.
702 if (auto *listener =
703 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
704 listener->notifyOperationModified(op);
705
706 if (propertiesStorage) {
707 OpaqueProperties propCopy(propertiesStorage);
708 // Note: The operation may have been erased in the mean time, so
709 // OperationName must be stored in this object.
710 name.destroyOpProperties(propCopy);
711 operator delete(propertiesStorage);
712 propertiesStorage = nullptr;
713 }
714 }
715
716 void rollback() override {
717 op->setLoc(loc);
718 op->setAttrs(attrs);
719 op->setOperands(operands);
720 for (const auto &it : llvm::enumerate(successors))
721 op->setSuccessor(it.value(), it.index());
722 if (propertiesStorage) {
723 OpaqueProperties propCopy(propertiesStorage);
724 op->copyProperties(propCopy);
725 name.destroyOpProperties(propCopy);
726 operator delete(propertiesStorage);
727 propertiesStorage = nullptr;
728 }
729 }
730
731private:
732 OperationName name;
733 LocationAttr loc;
734 DictionaryAttr attrs;
735 SmallVector<Value, 8> operands;
736 SmallVector<Block *, 2> successors;
737 void *propertiesStorage = nullptr;
738};
739
740/// Replacing an operation. Erasing an operation is treated as a special case
741/// with "null" replacements. This rewrite is not immediately reflected in the
742/// IR. An internal IR mapping is updated, but values are not replaced and the
743/// original op is not erased until the rewrite is committed.
744class ReplaceOperationRewrite : public OperationRewrite {
745public:
746 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
747 Operation *op, const TypeConverter *converter)
748 : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
749 converter(converter) {}
750
751 static bool classof(const IRRewrite *rewrite) {
752 return rewrite->getKind() == Kind::ReplaceOperation;
753 }
754
755 void commit(RewriterBase &rewriter) override;
756
757 void rollback() override;
758
759 void cleanup(RewriterBase &rewriter) override;
760
761private:
762 /// An optional type converter that can be used to materialize conversions
763 /// between the new and old values if necessary.
764 const TypeConverter *converter;
765};
766
767class CreateOperationRewrite : public OperationRewrite {
768public:
769 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
770 Operation *op)
771 : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
772
773 static bool classof(const IRRewrite *rewrite) {
774 return rewrite->getKind() == Kind::CreateOperation;
775 }
776
777 void commit(RewriterBase &rewriter) override {
778 // The operation was already created and inserted. Just inform the listener.
779 if (auto *listener = rewriter.getListener())
780 listener->notifyOperationInserted(op, /*previous=*/{});
781 }
782
783 void rollback() override;
784};
785
/// Direction of a materialization.
enum MaterializationKind {
  /// Converts a value of an illegal type into a legal one.
  Target,

  /// Converts a value of a legal type back into an illegal one.
  Source
};
796
797/// Helper class that stores metadata about an unresolved materialization.
798class UnresolvedMaterializationInfo {
799public:
800 UnresolvedMaterializationInfo() = default;
801 UnresolvedMaterializationInfo(const TypeConverter *converter,
802 MaterializationKind kind, Type originalType)
803 : converterAndKind(converter, kind), originalType(originalType) {}
804
805 /// Return the type converter of this materialization (which may be null).
806 const TypeConverter *getConverter() const {
807 return converterAndKind.getPointer();
808 }
809
810 /// Return the kind of this materialization.
811 MaterializationKind getMaterializationKind() const {
812 return converterAndKind.getInt();
813 }
814
815 /// Return the original type of the SSA value.
816 Type getOriginalType() const { return originalType; }
817
818private:
819 /// The corresponding type converter to use when resolving this
820 /// materialization, and the kind of this materialization.
821 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
822 converterAndKind;
823
824 /// The original type of the SSA value. Only used for target
825 /// materializations.
826 Type originalType;
827};
828
829/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
830/// op. Unresolved materializations fold away or are replaced with
831/// source/target materializations at the end of the dialect conversion.
832class UnresolvedMaterializationRewrite : public OperationRewrite {
833public:
834 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
835 UnrealizedConversionCastOp op,
836 ValueVector mappedValues)
837 : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
838 mappedValues(std::move(mappedValues)) {}
839
840 static bool classof(const IRRewrite *rewrite) {
841 return rewrite->getKind() == Kind::UnresolvedMaterialization;
842 }
843
844 void rollback() override;
845
846 UnrealizedConversionCastOp getOperation() const {
847 return cast<UnrealizedConversionCastOp>(op);
848 }
849
850private:
851 /// The values in the conversion value mapping that are being replaced by the
852 /// results of this unresolved materialization.
853 ValueVector mappedValues;
854};
855} // namespace
856
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` that
/// operates on the operation `op`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Operation *op) {
  auto matchesOp = [&](auto &rewrite) {
    auto *typed = dyn_cast<RewriteTy>(rewrite.get());
    return typed != nullptr && typed->getOperation() == op;
  };
  return llvm::any_of(std::forward<R>(rewrites), matchesOp);
}

/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` that
/// operates on the block `block`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Block *block) {
  auto matchesBlock = [&](auto &rewrite) {
    auto *typed = dyn_cast<RewriteTy>(rewrite.get());
    return typed != nullptr && typed->getBlock() == block;
  };
  return llvm::any_of(std::forward<R>(rewrites), matchesBlock);
}
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
878
879//===----------------------------------------------------------------------===//
880// ConversionPatternRewriterImpl
881//===----------------------------------------------------------------------===//
882namespace mlir {
883namespace detail {
885 explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
886 const ConversionConfig &config,
890
891 //===--------------------------------------------------------------------===//
892 // State Management
893 //===--------------------------------------------------------------------===//
894
895 /// Return the current state of the rewriter.
896 RewriterState getCurrentState();
897
898 /// Apply all requested operation rewrites. This method is invoked when the
899 /// conversion process succeeds.
900 void applyRewrites();
901
902 /// Reset the state of the rewriter to a previously saved point. Optionally,
903 /// the name of the pattern that triggered the rollback can specified for
904 /// debugging purposes.
905 void resetState(RewriterState state, StringRef patternName = "");
906
907 /// Append a rewrite. Rewrites are committed upon success and rolled back upon
908 /// failure.
909 template <typename RewriteTy, typename... Args>
910 void appendRewrite(Args &&...args) {
911 assert(config.allowPatternRollback && "appending rewrites is not allowed");
912 rewrites.push_back(
913 std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
914 }
915
916 /// Undo the rewrites (motions, splits) one by one in reverse order until
917 /// "numRewritesToKeep" rewrites remains. Optionally, the name of the pattern
918 /// that triggered the rollback can specified for debugging purposes.
919 void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
920
921 /// Remap the given values to those with potentially different types. Returns
922 /// success if the values could be remapped, failure otherwise. `valueDiagTag`
923 /// is the tag used when describing a value within a diagnostic, e.g.
924 /// "operand".
925 LogicalResult remapValues(StringRef valueDiagTag,
926 std::optional<Location> inputLoc, ValueRange values,
927 SmallVector<ValueVector> &remapped);
928
929 /// Return "true" if the given operation is ignored, and does not need to be
930 /// converted.
931 bool isOpIgnored(Operation *op) const;
932
933 /// Return "true" if the given operation was replaced or erased.
934 bool wasOpReplaced(Operation *op) const;
935
936 /// Lookup the most recently mapped values with the desired types in the
937 /// mapping, taking into account only replacements. Perform a best-effort
938 /// search for existing materializations with the desired types.
939 ///
940 /// If `skipPureTypeConversions` is "true", materializations that are pure
941 /// type conversions are not considered.
942 ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
943 bool skipPureTypeConversions = false) const;
944
945 /// Lookup the given value within the map, or return an empty vector if the
946 /// value is not mapped. If it is mapped, this follows the same behavior
947 /// as `lookupOrDefault`.
948 ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
949
950 //===--------------------------------------------------------------------===//
951 // IR Rewrites / Type Conversion
952 //===--------------------------------------------------------------------===//
953
954 /// Convert the types of block arguments within the given region.
955 FailureOr<Block *>
956 convertRegionTypes(Region *region, const TypeConverter &converter,
957 TypeConverter::SignatureConversion *entryConversion);
958
959 /// Apply the given signature conversion on the given block. The new block
960 /// containing the updated signature is returned. If no conversions were
961 /// necessary, e.g. if the block has no arguments, `block` is returned.
962 /// `converter` is used to generate any necessary cast operations that
963 /// translate between the origin argument types and those specified in the
964 /// signature conversion.
965 Block *applySignatureConversion(
966 Block *block, const TypeConverter *converter,
967 TypeConverter::SignatureConversion &signatureConversion);
968
969 /// Replace the results of the given operation with the given values and
970 /// erase the operation.
971 ///
972 /// There can be multiple replacement values for each result (1:N
973 /// replacement). If the replacement values are empty, the respective result
974 /// is dropped and a source materialization is built if the result still has
975 /// uses.
976 void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
977
978 /// Replace the uses of the given value with the given values. The specified
979 /// converter is used to build materializations (if necessary). If `functor`
980 /// is specified, only the uses that the functor returns "true" for are
981 /// replaced.
982 void replaceValueUses(Value from, ValueRange to,
983 const TypeConverter *converter,
984 function_ref<bool(OpOperand &)> functor = nullptr);
985
986 /// Erase the given block and its contents.
987 void eraseBlock(Block *block);
988
989 /// Inline the source block into the destination block before the given
990 /// iterator.
991 void inlineBlockBefore(Block *source, Block *dest, Block::iterator before);
992
993 //===--------------------------------------------------------------------===//
994 // Materializations
995 //===--------------------------------------------------------------------===//
996
997 /// Build an unresolved materialization operation given a range of output
998 /// types and a list of input operands. Returns the inputs if they their
999 /// types match the output types.
1000 ///
1001 /// If a cast op was built, it can optionally be returned with the `castOp`
1002 /// output argument.
1003 ///
1004 /// If `valuesToMap` is set to a non-null Value, then that value is mapped to
1005 /// the results of the unresolved materialization in the conversion value
1006 /// mapping.
1007 ///
1008 /// If `isPureTypeConversion` is "true", the materialization is created only
1009 /// to resolve a type mismatch. That means it is not a regular value
1010 /// replacement issued by the user. (Replacement values that are created
1011 /// "out of thin air" appear like unresolved materializations because they are
1012 /// unrealized_conversion_cast ops. However, they must be treated like
1013 /// regular value replacements.)
1014 ValueRange buildUnresolvedMaterialization(
1015 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1016 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1017 Type originalType, const TypeConverter *converter,
1018 bool isPureTypeConversion = true);
1019
1020 /// Find a replacement value for the given SSA value in the conversion value
1021 /// mapping. The replacement value must have the same type as the given SSA
1022 /// value. If there is no replacement value with the correct type, find the
1023 /// latest replacement value (regardless of the type) and build a source
1024 /// materialization.
1025 Value findOrBuildReplacementValue(Value value,
1026 const TypeConverter *converter);
1027
1028 //===--------------------------------------------------------------------===//
1029 // Rewriter Notification Hooks
1030 //===--------------------------------------------------------------------===//
1031
1032 //// Notifies that an op was inserted.
1033 void notifyOperationInserted(Operation *op,
1034 OpBuilder::InsertPoint previous) override;
1035
1036 /// Notifies that a block was inserted.
1037 void notifyBlockInserted(Block *block, Region *previous,
1038 Region::iterator previousIt) override;
1039
1040 /// Notifies that a pattern match failed for the given reason.
1041 void
1042 notifyMatchFailure(Location loc,
1043 function_ref<void(Diagnostic &)> reasonCallback) override;
1044
1045 //===--------------------------------------------------------------------===//
1046 // IR Erasure
1047 //===--------------------------------------------------------------------===//
1048
1049 /// A rewriter that keeps track of erased ops and blocks. It ensures that no
1050 /// operation or block is erased multiple times. This rewriter assumes that
1051 /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
1053 public:
1056 std::function<void(Operation *)> opErasedCallback = nullptr)
1057 : RewriterBase(context, /*listener=*/this),
1058 opErasedCallback(std::move(opErasedCallback)) {}
1059
1060 /// Erase the given op (unless it was already erased).
1061 void eraseOp(Operation *op) override {
1062 if (wasErased(op))
1063 return;
1064 op->dropAllUses();
1066 }
1067
1068 /// Erase the given block (unless it was already erased).
1069 void eraseBlock(Block *block) override {
1070 if (wasErased(block))
1071 return;
1072 assert(block->empty() && "expected empty block");
1073 block->dropAllDefinedValueUses();
1075 }
1076
1077 bool wasErased(void *ptr) const { return erased.contains(ptr); }
1078
1080 erased.insert(op);
1081 if (opErasedCallback)
1082 opErasedCallback(op);
1083 }
1084
1085 void notifyBlockErased(Block *block) override { erased.insert(block); }
1086
1087 private:
1088 /// Pointers to all erased operations and blocks.
1089 DenseSet<void *> erased;
1090
1091 /// A callback that is invoked when an operation is erased.
1092 std::function<void(Operation *)> opErasedCallback;
1093 };
1094
1095 //===--------------------------------------------------------------------===//
1096 // State
1097 //===--------------------------------------------------------------------===//
1098
1099 /// The rewriter that is used to perform the conversion.
1100 ConversionPatternRewriter &rewriter;
1101
1102 // Mapping between replaced values that differ in type. This happens when
1103 // replacing a value with one of a different type.
1104 ConversionValueMapping mapping;
1105
1106 /// Ordered list of block operations (creations, splits, motions).
1107 /// This vector is maintained only if `allowPatternRollback` is set to
1108 /// "true". Otherwise, all IR rewrites are materialized immediately and no
1109 /// bookkeeping is needed.
1111
1112 /// A set of operations that should no longer be considered for legalization.
1113 /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1114 /// tracked separately.
1116
1117 /// A set of operations that were replaced/erased. Such ops are not erased
1118 /// immediately but only when the dialect conversion succeeds. In the mean
1119 /// time, they should no longer be considered for legalization and any attempt
1120 /// to modify/access them is invalid rewriter API usage.
1122
1123 /// A set of operations that were created by the current pattern.
1125
1126 /// A set of operations that were modified by the current pattern.
1128
1129 /// A list of unresolved materializations that were created by the current
1130 /// pattern.
1132
1133 /// A mapping for looking up metadata of unresolved materializations.
1136
1137 /// The current type converter, or nullptr if no type converter is currently
1138 /// active.
1140
1141 /// A mapping of regions to type converters that should be used when
1142 /// converting the arguments of blocks within that region.
1144
1145 /// Dialect conversion configuration.
1146 const ConversionConfig &config;
1147
1148 /// The operation converter to use for recursive legalization.
1150
1151 /// A set of erased operations. This set is utilized only if
1152 /// `allowPatternRollback` is set to "false". Conceptually, this set is
1153 /// similar to `replacedOps` (which is maintained when the flag is set to
1154 /// "true"). However, erasing from a DenseSet is more efficient than erasing
1155 /// from a SetVector.
1157
1158 /// A set of erased blocks. This set is utilized only if
1159 /// `allowPatternRollback` is set to "false".
1161
1162 /// A rewriter that notifies the listener (if any) about all IR
1163 /// modifications. This rewriter is utilized only if `allowPatternRollback`
1164 /// is set to "false". If the flag is set to "true", the listener is notified
1165 /// with a separate mechanism (e.g., in `IRRewrite::commit`).
1167
1168#ifndef NDEBUG
1169 /// A set of replaced values. This set is for debugging purposes only and it
1170 /// is maintained only if `allowPatternRollback` is set to "true".
1172
1173 /// A set of operations that have pending updates. This tracking isn't
1174 /// strictly necessary, and is thus only active during debug builds for extra
1175 /// verification.
1177
1178 /// A raw output stream used to prefix the debug log.
1179 llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
1180 llvm::dbgs()};
1181
1182 /// A logger used to emit diagnostics during the conversion process.
1183 llvm::ScopedPrinter logger{os};
1184 std::string logPrefix;
1185#endif
1186};
1187} // namespace detail
1188} // namespace mlir
1189
1190const ConversionConfig &IRRewrite::getConfig() const {
1191 return rewriterImpl.config;
1192}
1193
1194void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1195 // Inform the listener about all IR modifications that have already taken
1196 // place: References to the original block have been replaced with the new
1197 // block.
1198 if (auto *listener =
1199 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1200 for (Operation *op : getNewBlock()->getUsers())
1201 listener->notifyOperationModified(op);
1202}
1203
1204void BlockTypeConversionRewrite::rollback() {
1205 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1206}
1207
1208/// Replace all uses of `from` with `repl`.
1209static void
1211 function_ref<bool(OpOperand &)> functor = nullptr) {
1212 if (isa<BlockArgument>(repl)) {
1213 // `repl` is a block argument. Directly replace all uses.
1214 if (functor) {
1215 rewriter.replaceUsesWithIf(from, repl, functor);
1216 } else {
1217 rewriter.replaceAllUsesWith(from, repl);
1218 }
1219 return;
1220 }
1221
1222 // If the replacement value is an operation, only replace those uses that:
1223 // - are in a different block than the replacement operation, or
1224 // - are in the same block but after the replacement operation.
1225 //
1226 // Example:
1227 // ^bb0(%arg0: i32):
1228 // %0 = "consumer"(%arg0) : (i32) -> (i32)
1229 // "another_consumer"(%arg0) : (i32) -> ()
1230 //
1231 // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
1232 // use in "another_consumer" but not the use in "consumer". When using the
1233 // normal RewriterBase API, this would typically be done with
1234 // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
1235 // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
1236 // it cannot be supported efficiently with `allowPatternRollback` set to
1237 // "true". Therefore, the conversion driver is trying to be smart and replaces
1238 // only those uses that do not lead to a dominance violation. E.g., the
1239 // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
1240 // behavior.
1241 //
1242 // TODO: As we move more and more towards `allowPatternRollback` set to
1243 // "false", we should remove this special handling, in order to align the
1244 // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
1245 Operation *replOp = repl.getDefiningOp();
1246 Block *replBlock = replOp->getBlock();
1247 rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
1248 Operation *user = operand.getOwner();
1249 bool result =
1250 user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1251 if (result && functor)
1252 result &= functor(operand);
1253 return result;
1254 });
1255}
1256
1257void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1258 Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
1259 if (!repl)
1260 return;
1261 performReplaceValue(rewriter, value, repl);
1262}
1263
1264void ReplaceValueRewrite::rollback() {
1265 rewriterImpl.mapping.erase({value});
1266#ifndef NDEBUG
1267 rewriterImpl.replacedValues.erase(value);
1268#endif // NDEBUG
1269}
1270
1271void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1272 auto *listener =
1273 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1274
1275 // Compute replacement values.
1276 SmallVector<Value> replacements =
1277 llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1278 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1279 });
1280
1281 // Notify the listener that the operation is about to be replaced.
1282 if (listener)
1283 listener->notifyOperationReplaced(op, replacements);
1284
1285 // Replace all uses with the new values.
1286 for (auto [result, newValue] :
1287 llvm::zip_equal(op->getResults(), replacements))
1288 if (newValue)
1289 rewriter.replaceAllUsesWith(result, newValue);
1290
1291 // The original op will be erased, so remove it from the set of unlegalized
1292 // ops.
1293 if (getConfig().unlegalizedOps)
1294 getConfig().unlegalizedOps->erase(op);
1295
1296 // Notify the listener that the operation and its contents are being erased.
1297 if (listener)
1298 notifyIRErased(listener, *op);
1299
1300 // Do not erase the operation yet. It may still be referenced in `mapping`.
1301 // Just unlink it for now and erase it during cleanup.
1302 op->getBlock()->getOperations().remove(op);
1303}
1304
1305void ReplaceOperationRewrite::rollback() {
1306 for (auto result : op->getResults())
1307 rewriterImpl.mapping.erase({result});
1308}
1309
1310void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1311 rewriter.eraseOp(op);
1312}
1313
1314void CreateOperationRewrite::rollback() {
1315 for (Region &region : op->getRegions()) {
1316 while (!region.getBlocks().empty())
1317 region.getBlocks().remove(region.getBlocks().begin());
1318 }
1319 op->dropAllUses();
1320 op->erase();
1321}
1322
1323void UnresolvedMaterializationRewrite::rollback() {
1324 if (!mappedValues.empty())
1325 rewriterImpl.mapping.erase(mappedValues);
1326 rewriterImpl.unresolvedMaterializations.erase(getOperation());
1327 op->erase();
1328}
1329
1331 // Commit all rewrites. Use a new rewriter, so the modifications are not
1332 // tracked for rollback purposes etc.
1333 IRRewriter irRewriter(rewriter.getContext(), config.listener);
1334 // Note: New rewrites may be added during the "commit" phase and the
1335 // `rewrites` vector may reallocate.
1336 for (size_t i = 0; i < rewrites.size(); ++i)
1337 rewrites[i]->commit(irRewriter);
1338
1339 // Clean up all rewrites.
1340 SingleEraseRewriter eraseRewriter(
1341 rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
1342 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1343 unresolvedMaterializations.erase(castOp);
1344 });
1345 for (auto &rewrite : rewrites)
1346 rewrite->cleanup(eraseRewriter);
1347}
1348
1349//===----------------------------------------------------------------------===//
1350// State Management
1351//===----------------------------------------------------------------------===//
1352
1354 Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
1355 // Helper function that looks up a single value.
1356 auto lookup = [&](const ValueVector &values) -> ValueVector {
1357 assert(!values.empty() && "expected non-empty value vector");
1358
1359 // If the pattern rollback is enabled, use the mapping to look up the
1360 // values.
1361 if (config.allowPatternRollback)
1362 return mapping.lookup(values);
1363
1364 // Otherwise, look up values by examining the IR. All replacements have
1365 // already been materialized in IR.
1366 Operation *op = getCommonDefiningOp(values);
1367 if (!op)
1368 return {};
1369 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1370 if (!castOp)
1371 return {};
1372 if (!this->unresolvedMaterializations.contains(castOp))
1373 return {};
1374 if (castOp.getOutputs() != values)
1375 return {};
1376 return castOp.getInputs();
1377 };
1378
1379 // Helper function that looks up each value in `values` individually and then
1380 // composes the results. If that fails, it tries to look up the entire vector
1381 // at once.
1382 auto composedLookup = [&](const ValueVector &values) -> ValueVector {
1383 // If possible, replace each value with (one or multiple) mapped values.
1384 ValueVector next;
1385 for (Value v : values) {
1386 ValueVector r = lookup({v});
1387 if (!r.empty()) {
1388 llvm::append_range(next, r);
1389 } else {
1390 next.push_back(v);
1391 }
1392 }
1393 if (next != values) {
1394 // At least one value was replaced.
1395 return next;
1396 }
1397
1398 // Otherwise: Check if there is a mapping for the entire vector. Such
1399 // mappings are materializations. (N:M mapping are not supported for value
1400 // replacements.)
1401 //
1402 // Note: From a correctness point of view, materializations do not have to
1403 // be stored (and looked up) in the mapping. But for performance reasons,
1404 // we choose to reuse existing IR (when possible) instead of creating it
1405 // multiple times.
1406 ValueVector r = lookup(values);
1407 if (r.empty()) {
1408 // No mapping found: The lookup stops here.
1409 return {};
1410 }
1411 return r;
1412 };
1413
1414 // Try to find the deepest values that have the desired types. If there is no
1415 // such mapping, simply return the deepest values.
1416 ValueVector desiredValue;
1417 ValueVector current{from};
1418 ValueVector lastNonMaterialization{from};
1419 do {
1420 // Store the current value if the types match.
1421 bool match = TypeRange(ValueRange(current)) == desiredTypes;
1422 if (skipPureTypeConversions) {
1423 // Skip pure type conversions, if requested.
1424 bool pureConversion = isPureTypeConversion(current);
1425 match &= !pureConversion;
1426 // Keep track of the last mapped value that was not a pure type
1427 // conversion.
1428 if (!pureConversion)
1429 lastNonMaterialization = current;
1430 }
1431 if (match)
1432 desiredValue = current;
1433
1434 // Lookup next value in the mapping.
1435 ValueVector next = composedLookup(current);
1436 if (next.empty())
1437 break;
1438 current = std::move(next);
1439 } while (true);
1440
1441 // If the desired values were found use them, otherwise default to the leaf
1442 // values. (Skip pure type conversions, if requested.)
1443 if (!desiredTypes.empty())
1444 return desiredValue;
1445 if (skipPureTypeConversions)
1446 return lastNonMaterialization;
1447 return current;
1448}
1449
1452 TypeRange desiredTypes) const {
1453 ValueVector result = lookupOrDefault(from, desiredTypes);
1454 if (result == ValueVector{from} ||
1455 (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
1456 return {};
1457 return result;
1458}
1459
1461 return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1462}
1463
1465 StringRef patternName) {
1466 // Undo any rewrites.
1467 undoRewrites(state.numRewrites, patternName);
1468
1469 // Pop all of the recorded ignored operations that are no longer valid.
1470 while (ignoredOps.size() != state.numIgnoredOperations)
1471 ignoredOps.pop_back();
1472
1473 while (replacedOps.size() != state.numReplacedOps)
1474 replacedOps.pop_back();
1475}
1476
1477void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1478 StringRef patternName) {
1479 for (auto &rewrite :
1480 llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1481 rewrite->rollback();
1482 rewrites.resize(numRewritesToKeep);
1483}
1484
1486 StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
1487 SmallVector<ValueVector> &remapped) {
1488 remapped.reserve(llvm::size(values));
1489
1490 for (const auto &it : llvm::enumerate(values)) {
1491 Value operand = it.value();
1492 Type origType = operand.getType();
1493 Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1494
1495 if (!currentTypeConverter) {
1496 // The current pattern does not have a type converter. Pass the most
1497 // recently mapped values, excluding materializations. Materializations
1498 // are intentionally excluded because their presence may depend on other
1499 // patterns. Including materializations would make the lookup fragile
1500 // and unpredictable.
1501 remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
1502 /*skipPureTypeConversions=*/true));
1503 continue;
1504 }
1505
1506 // If there is no legal conversion, fail to match this pattern.
1507 SmallVector<Type, 1> legalTypes;
1508 if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
1509 notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1510 diag << "unable to convert type for " << valueDiagTag << " #"
1511 << it.index() << ", type was " << origType;
1512 });
1513 return failure();
1514 }
1515 // If a type is converted to 0 types, there is nothing to do.
1516 if (legalTypes.empty()) {
1517 remapped.push_back({});
1518 continue;
1519 }
1520
1521 ValueVector repl = lookupOrDefault(operand, legalTypes);
1522 if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1523 // Mapped values have the correct type or there is an existing
1524 // materialization. Or the operand is not mapped at all and has the
1525 // correct type.
1526 remapped.push_back(std::move(repl));
1527 continue;
1528 }
1529
1530 // Create a materialization for the most recently mapped values.
1531 repl = lookupOrDefault(operand, /*desiredTypes=*/{},
1532 /*skipPureTypeConversions=*/true);
1534 MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1535 /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1536 /*originalType=*/origType, currentTypeConverter);
1537 remapped.push_back(castValues);
1538 }
1539 return success();
1540}
1541
1543 // Check to see if this operation is ignored or was replaced.
1544 return wasOpReplaced(op) || ignoredOps.count(op);
1545}
1546
1548 // Check to see if this operation was replaced.
1549 return replacedOps.count(op) || erasedOps.count(op);
1550}
1551
1552//===----------------------------------------------------------------------===//
1553// Type Conversion
1554//===----------------------------------------------------------------------===//
1555
1557 Region *region, const TypeConverter &converter,
1558 TypeConverter::SignatureConversion *entryConversion) {
1559 regionToConverter[region] = &converter;
1560 if (region->empty())
1561 return nullptr;
1562
1563 // Convert the arguments of each non-entry block within the region.
1564 for (Block &block :
1565 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1566 // Compute the signature for the block with the provided converter.
1567 std::optional<TypeConverter::SignatureConversion> conversion =
1568 converter.convertBlockSignature(&block);
1569 if (!conversion)
1570 return failure();
1571 // Convert the block with the computed signature.
1572 applySignatureConversion(&block, &converter, *conversion);
1573 }
1574
1575 // Convert the entry block. If an entry signature conversion was provided,
1576 // use that one. Otherwise, compute the signature with the type converter.
1577 if (entryConversion)
1578 return applySignatureConversion(&region->front(), &converter,
1579 *entryConversion);
1580 std::optional<TypeConverter::SignatureConversion> conversion =
1581 converter.convertBlockSignature(&region->front());
1582 if (!conversion)
1583 return failure();
1584 return applySignatureConversion(&region->front(), &converter, *conversion);
1585}
1586
1588 Block *block, const TypeConverter *converter,
1589 TypeConverter::SignatureConversion &signatureConversion) {
1590#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1591 // A block cannot be converted multiple times.
1592 if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1593 llvm::report_fatal_error("block was already converted");
1594#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1595
1597
1598 // If no arguments are being changed or added, there is nothing to do.
1599 unsigned origArgCount = block->getNumArguments();
1600 auto convertedTypes = signatureConversion.getConvertedTypes();
1601 if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1602 return block;
1603
1604 // Compute the locations of all block arguments in the new block.
1605 SmallVector<Location> newLocs(convertedTypes.size(),
1606 rewriter.getUnknownLoc());
1607 for (unsigned i = 0; i < origArgCount; ++i) {
1608 auto inputMap = signatureConversion.getInputMapping(i);
1609 if (!inputMap || inputMap->replacedWithValues())
1610 continue;
1611 Location origLoc = block->getArgument(i).getLoc();
1612 for (unsigned j = 0; j < inputMap->size; ++j)
1613 newLocs[inputMap->inputNo + j] = origLoc;
1614 }
1615
1616 // Insert a new block with the converted block argument types and move all ops
1617 // from the old block to the new block.
1618 Block *newBlock =
1619 rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1620 convertedTypes, newLocs);
1621
1622 // If a listener is attached to the dialect conversion, ops cannot be moved
1623 // to the destination block in bulk ("fast path"). This is because at the time
1624 // the notifications are sent, it is unknown which ops were moved. Instead,
1625 // ops should be moved one-by-one ("slow path"), so that a separate
1626 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1627 // a bit more efficient, so we try to do that when possible.
1628 bool fastPath = !config.listener;
1629 if (fastPath) {
1630 if (config.allowPatternRollback)
1631 appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1632 newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1633 } else {
1634 while (!block->empty())
1635 rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1636 }
1637
1638 // Replace all uses of the old block with the new block.
1639 block->replaceAllUsesWith(newBlock);
1640
1641 for (unsigned i = 0; i != origArgCount; ++i) {
1642 BlockArgument origArg = block->getArgument(i);
1643 Type origArgType = origArg.getType();
1644
1645 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1646 signatureConversion.getInputMapping(i);
1647 if (!inputMap) {
1648 // This block argument was dropped and no replacement value was provided.
1649 // Materialize a replacement value "out of thin air".
1650 // Note: Materialization must be built here because we cannot find a
1651 // valid insertion point in the new block. (Will point to the old block.)
1652 Value mat =
1654 MaterializationKind::Source,
1655 OpBuilder::InsertPoint(newBlock, newBlock->begin()),
1656 origArg.getLoc(),
1657 /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1658 /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
1659 /*isPureTypeConversion=*/false)
1660 .front();
1661 replaceValueUses(origArg, mat, converter);
1662 continue;
1663 }
1664
1665 if (inputMap->replacedWithValues()) {
1666 // This block argument was dropped and replacement values were provided.
1667 assert(inputMap->size == 0 &&
1668 "invalid to provide a replacement value when the argument isn't "
1669 "dropped");
1670 replaceValueUses(origArg, inputMap->replacementValues, converter);
1671 continue;
1672 }
1673
1674 // This is a 1->1+ mapping.
1675 auto replArgs =
1676 newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1677 replaceValueUses(origArg, replArgs, converter);
1678 }
1679
1680 if (config.allowPatternRollback)
1681 appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1682
1683 // Erase the old block. (It is just unlinked for now and will be erased during
1684 // cleanup.)
1685 rewriter.eraseBlock(block);
1686
1687 return newBlock;
1688}
1689
1690//===----------------------------------------------------------------------===//
1691// Materializations
1692//===----------------------------------------------------------------------===//
1693
1694/// Build an unresolved materialization operation given an output type and set
1695/// of input operands.
///
/// Creates an `UnrealizedConversionCastOp` from `inputs` to `outputTypes` at
/// insertion point `ip`, records it in `unresolvedMaterializations`, and
/// returns the cast results. In rollback mode, a mapping from `valuesToMap`
/// (when non-empty) to the cast results is recorded; otherwise the cast is
/// added to `patternMaterializations`. `originalType` may only be set for
/// target materializations (asserted below).
/// NOTE(review): `outputTypes.front()` is called unguarded, so `outputTypes`
/// is presumably required to be non-empty -- confirm with callers.
/// NOTE(review): listing lines 1696 (presumably the declarator), 1717 and
/// 1726 are absent from this extract, so the code below is not complete as
/// shown.
1697 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1698 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1699 Type originalType, const TypeConverter *converter,
1700 bool isPureTypeConversion) {
1701 assert((!originalType || kind == MaterializationKind::Target) &&
1702 "original type is valid only for target materializations");
1703 assert(TypeRange(inputs) != outputTypes &&
1704 "materialization is not necessary");
1705
1706 // Create an unresolved materialization. We use a new OpBuilder to avoid
1707 // tracking the materialization like we do for other operations.
1708 OpBuilder builder(outputTypes.front().getContext());
1709 builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1710 UnrealizedConversionCastOp convertOp =
1711 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1712 if (config.attachDebugMaterializationKind) {
1713 StringRef kindStr =
1714 kind == MaterializationKind::Source ? "source" : "target";
1715 convertOp->setAttr("__kind__", builder.getStringAttr(kindStr));
1716 }
1718 convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
1719
1720 // Register the materialization.
1721 unresolvedMaterializations[convertOp] =
1722 UnresolvedMaterializationInfo(converter, kind, originalType);
1723 if (config.allowPatternRollback) {
1724 if (!valuesToMap.empty())
1725 mapping.map(valuesToMap, convertOp.getResults());
1727 std::move(valuesToMap));
1728 } else {
1729 patternMaterializations.insert(convertOp);
1730 }
1731 return convertOp.getResults();
1732}
1733
/// Returns a replacement for `value` that has the same type as `value`.
/// Lookup order:
///  1. A value of that type already present in the conversion mapping
///     (including cached materializations) is reused.
///  2. A null Value is returned when every user of `value` has been replaced
///     and nothing maps to it. This is an approximate dead-value check.
///  3. Otherwise a source materialization is built from the latest mapped
///     replacement (or from no inputs, if there is none) back to the type of
///     `value`.
/// Only valid in rollback mode (asserted).
/// NOTE(review): listing lines 1734, 1761 and 1775 are absent from this
/// extract. They presumably hold the declarator, the declaration of `ip`, and
/// the start of the statement defining `castValue`, so the code below is not
/// complete as shown.
1735 Value value, const TypeConverter *converter) {
1736 assert(config.allowPatternRollback &&
1737 "this code path is valid only in rollback mode");
1738
1739 // Try to find a replacement value with the same type in the conversion value
1740 // mapping. This includes cached materializations. We try to reuse those
1741 // instead of generating duplicate IR.
1742 ValueVector repl = lookupOrNull(value, value.getType());
1743 if (!repl.empty())
1744 return repl.front();
1745
1746 // Check if the value is dead. No replacement value is needed in that case.
1747 // This is an approximate check that may have false negatives but does not
1748 // require computing and traversing an inverse mapping. (We may end up
1749 // building source materializations that are never used and that fold away.)
1750 if (llvm::all_of(value.getUsers(),
1751 [&](Operation *op) { return replacedOps.contains(op); }) &&
1752 !mapping.isMappedTo(value))
1753 return Value();
1754
1755 // No replacement value was found. Get the latest replacement value
1756 // (regardless of the type) and build a source materialization to the
1757 // original type.
1758 repl = lookupOrNull(value);
1759
1760 // Compute the insertion point of the materialization.
1762 if (repl.empty()) {
1763 // The source materialization has no inputs. Insert it right before the
1764 // value that it is replacing.
1765 ip = computeInsertPoint(value);
1766 } else {
1767 // Compute the "earliest" insertion point at which all values in `repl` are
1768 // defined. It is important to emit the materialization at that location
1769 // because the same materialization may be reused in a different context.
1770 // (That's because materializations are cached in the conversion value
1771 // mapping.) The insertion point of the materialization must be valid for
1772 // all future users that may be created later in the conversion process.
1773 ip = computeInsertPoint(repl);
1774 }
1776 MaterializationKind::Source, ip, value.getLoc(),
1777 /*valuesToMap=*/repl, /*inputs=*/repl,
1778 /*outputTypes=*/value.getType(),
1779 /*originalType=*/Type(), converter,
1780 /*isPureTypeConversion=*/!repl.empty())
1781 .front();
1782 return castValue;
1783}
1784
1785//===----------------------------------------------------------------------===//
1786// Rewriter Notification Hooks
1787//===----------------------------------------------------------------------===//
1788
/// Listener hook invoked when `op` is inserted. `previous` is the op's prior
/// insertion point. If it is unset, the op was detached (most likely newly
/// created) and is added to `patternNewOps` so that it gets legalized.
/// In "no rollback" mode, the configured listener is notified immediately,
/// and a previously-detached op is also removed from `erasedOps`.
/// NOTE(review): listing lines 1789, 1821 and 1832 are absent from this
/// extract. In particular, the `if (config.allowPatternRollback)` on listing
/// line 1831 has no visible body, so the code below is not complete as shown.
1790 Operation *op, OpBuilder::InsertPoint previous) {
1791 // If no previous insertion point is provided, the op used to be detached.
1792 bool wasDetached = !previous.isSet();
1793 LLVM_DEBUG({
1794 logger.startLine() << "** Insert : '" << op->getName() << "' (" << op
1795 << ")";
1796 if (wasDetached)
1797 logger.getOStream() << " (was detached)";
1798 logger.getOStream() << "\n";
1799 });
1800
1801 // In rollback mode, it is easier to misuse the API, so perform extra error
1802 // checking.
1803 assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
1804 "attempting to insert into a block within a replaced/erased op");
1805
1806 // In "no rollback" mode, the listener is always notified immediately.
1807 if (!config.allowPatternRollback && config.listener)
1808 config.listener->notifyOperationInserted(op, previous);
1809
1810 if (wasDetached) {
1811 // If the op was detached, it is most likely a newly created op. Add it to the
1812 // set of newly created ops, so that it will be legalized. If this op is
1813 // not a newly created op, it will be legalized a second time, which is
1814 // inefficient but harmless.
1815 patternNewOps.insert(op);
1816
1817 if (config.allowPatternRollback) {
1818 // TODO: If the same op is inserted multiple times from a detached
1819 // state, the rollback mechanism may erase the same op multiple times.
1820 // This is a bug in the rollback-based dialect conversion driver.
1822 } else {
1823 // In "no rollback" mode, there is an extra data structure for tracking
1824 // erased operations that must be kept up to date.
1825 erasedOps.erase(op);
1826 }
1827 return;
1828 }
1829
1830 // The op was moved from one place to another.
1831 if (config.allowPatternRollback)
1833}
1834
1835/// Given that `fromRange` is about to be replaced with `toRange`, compute
1836/// replacement values with the types of `fromRange`.
///
/// Returns one value per element of `fromRange`:
///  - a null Value if the element has no uses;
///  - a source materialization with no inputs if its replacement is empty
///    (the value is dropped);
///  - the replacement itself if it is a single value of the same type;
///  - otherwise, a source materialization from the replacement values.
/// Only valid in "no rollback" mode (asserted).
/// NOTE(review): listing line 1838 is absent from this extract. It
/// presumably holds the function name and its leading parameters (`impl`,
/// `fromRange`).
1837static SmallVector<Value>
1839 const SmallVector<SmallVector<Value>> &toRange,
1840 const TypeConverter *converter) {
1841 assert(!impl.config.allowPatternRollback &&
1842 "this code path is valid only in 'no rollback' mode");
1843 SmallVector<Value> repls;
1844 for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1845 if (from.use_empty()) {
1846 // The replaced value is dead. No replacement value is needed.
1847 repls.push_back(Value());
1848 continue;
1849 }
1850
1851 if (to.empty()) {
1852 // The replaced value is dropped. Materialize a replacement value "out of
1853 // thin air".
1854 Value srcMat = impl.buildUnresolvedMaterialization(
1855 MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
1856 /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1857 /*outputTypes=*/from.getType(), /*originalType=*/Type(),
1858 converter)[0];
1859 repls.push_back(srcMat);
1860 continue;
1861 }
1862
1863 if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
1864 // The replacement value already has the correct type. Use it directly.
1865 repls.push_back(to[0]);
1866 continue;
1867 }
1868
1869 // The replacement value has the wrong type. Build a source materialization
1870 // to the original type.
1871 // TODO: This is a bit inefficient. We should try to reuse existing
1872 // materializations if possible. This would require an extension of the
1873 // `lookupOrDefault` API.
1874 Value srcMat = impl.buildUnresolvedMaterialization(
1875 MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
1876 /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
1877 /*originalType=*/Type(), converter)[0];
1878 repls.push_back(srcMat);
1879 }
1880
1881 return repls;
1882}
1883
/// Replaces `op` with `newValues`. There is one vector of values per op
/// result, and an empty vector means that result is dropped.
/// In "no rollback" mode, `op` and its nested ops/blocks are first removed
/// from the driver's bookkeeping, and the replacement is then performed
/// immediately through `notifyingRewriter`. In rollback mode, the results
/// are only recorded in the conversion mapping, and `op` (with all nested
/// ops) is marked as replaced.
/// NOTE(review): listing lines 1884, 1891, 1929 and 1972 are absent from this
/// extract. The closing brace on listing line 1924 suggests that 1891 opened
/// a scope, and `repls` (used on listing line 1947) is presumably declared on
/// 1929. The code below is not complete as shown.
1885 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1886 assert(newValues.size() == op->getNumResults() &&
1887 "incorrect number of replacement values");
1888 LLVM_DEBUG({
1889 logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
1890 << ")\n";
1892 // If the user-provided replacement types are different from the
1893 // legalized types, as per the current type converter, print a note.
1894 // In most cases, the replacement types are expected to match the types
1895 // produced by the type converter, so this could indicate a bug in the
1896 // user code.
1897 for (auto [result, repls] :
1898 llvm::zip_equal(op->getResults(), newValues)) {
1899 Type resultType = result.getType();
1900 auto logProlog = [&, repls = repls]() {
1901 logger.startLine() << " Note: Replacing op result of type "
1902 << resultType << " with value(s) of type (";
1903 llvm::interleaveComma(repls, logger.getOStream(), [&](Value v) {
1904 logger.getOStream() << v.getType();
1905 });
1906 logger.getOStream() << ")";
1907 };
1908 SmallVector<Type> convertedTypes;
1909 if (failed(currentTypeConverter->convertTypes(resultType,
1910 convertedTypes))) {
1911 logProlog();
1912 logger.getOStream() << ", but the type converter failed to legalize "
1913 "the original type.\n";
1914 continue;
1915 }
1916 if (TypeRange(convertedTypes) != TypeRange(ValueRange(repls))) {
1917 logProlog();
1918 logger.getOStream() << ", but the legalized type(s) is/are (";
1919 llvm::interleaveComma(convertedTypes, logger.getOStream(),
1920 [&](Type t) { logger.getOStream() << t; });
1921 logger.getOStream() << ")\n";
1922 }
1923 }
1924 }
1925 });
1926
1927 if (!config.allowPatternRollback) {
1928 // Pattern rollback is not allowed: materialize all IR changes immediately.
1930 *this, op->getResults(), newValues, currentTypeConverter);
1931 // Update internal data structures, so that there are no dangling pointers
1932 // to erased IR.
// NOTE(review): the walk callback's parameter shadows the outer `op`. The
// same bookkeeping walk also appears in the "no rollback" branch of the
// impl's eraseBlock; consider a shared helper.
1933 op->walk([&](Operation *op) {
1934 erasedOps.insert(op);
1935 ignoredOps.remove(op);
1936 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1937 unresolvedMaterializations.erase(castOp);
1938 patternMaterializations.erase(castOp);
1939 }
1940 // The original op will be erased, so remove it from the set of
1941 // unlegalized ops.
1942 if (config.unlegalizedOps)
1943 config.unlegalizedOps->erase(op);
1944 });
1945 op->walk([&](Block *block) { erasedBlocks.insert(block); });
1946 // Replace the op with the replacement values and notify the listener.
1947 notifyingRewriter.replaceOp(op, repls);
1948 return;
1949 }
1950
1951 assert(!ignoredOps.contains(op) && "operation was already replaced");
1952#ifndef NDEBUG
1953 for (Value v : op->getResults())
1954 assert(!replacedValues.contains(v) &&
1955 "attempting to replace a value that was already replaced");
1956#endif // NDEBUG
1957
1958 // Check if replaced op is an unresolved materialization, i.e., an
1959 // unrealized_conversion_cast op that was created by the conversion driver.
1960 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1961 // Make sure that the user does not mess with unresolved materializations
1962 // that were inserted by the conversion driver. We keep track of these
1963 // ops in internal data structures.
1964 assert(!unresolvedMaterializations.contains(castOp) &&
1965 "attempting to replace/erase an unresolved materialization");
1966 }
1967
1968 // Create mappings for each of the new result values.
1969 for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults()))
1970 mapping.map(static_cast<Value>(result), std::move(repl));
1971
1973 // Mark this operation and all nested ops as replaced.
1974 op->walk([&](Operation *op) { replacedOps.insert(op); });
1975}
1976
/// Replaces uses of `from` with `to`, using `converter` for any source
/// materialization that is needed. If `functor` is set, the replacement is
/// conditional (presumably only the uses it accepts are replaced -- see
/// performReplaceValue). Conditional replacement is rejected with a fatal
/// error in rollback mode.
/// In "no rollback" mode, the replacement is applied to the IR immediately,
/// and nothing is done if no replacement value is needed (dead `from`).
/// In rollback mode, the replacement is recorded in the conversion mapping
/// and a `ReplaceValueRewrite` is appended.
/// NOTE(review): listing line 1977 (presumably the declarator) is absent
/// from this extract.
1978 Value from, ValueRange to, const TypeConverter *converter,
1979 function_ref<bool(OpOperand &)> functor) {
// NOTE(review): unlike the other "** ..." log lines in this file, this one
// never emits a trailing "\n", so subsequent log output runs onto the same
// line. Its " (in region of '...' (<ptr>)" text also opens two parentheses
// but closes only one.
1980 LLVM_DEBUG({
1981 logger.startLine() << "** Replace Value : '" << from << "'";
1982 if (auto blockArg = dyn_cast<BlockArgument>(from)) {
1983 if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
1984 logger.getOStream() << " (in region of '" << parentOp->getName()
1985 << "' (" << parentOp << ")";
1986 } else {
1987 logger.getOStream() << " (unlinked block)";
1988 }
1989 }
1990 if (functor) {
1991 logger.getOStream() << ", conditional replacement";
1992 }
1993 });
1994
1995 if (!config.allowPatternRollback) {
1996 SmallVector<Value> toConv = llvm::to_vector(to);
1997 SmallVector<Value> repls =
1998 getReplacementValues(*this, from, {toConv}, converter);
1999 IRRewriter r(from.getContext());
2000 Value repl = repls.front();
2001 if (!repl)
2002 return;
2003
2004 performReplaceValue(r, from, repl, functor);
2005 return;
2006 }
2007
2008#ifndef NDEBUG
2009 // Make sure that a value is not replaced multiple times. In rollback mode,
2010 // `replaceAllUsesWith` replaces not only all current uses of the given value,
2011 // but also all future uses that may be introduced by future pattern
2012 // applications. Therefore, it does not make sense to call
2013 // `replaceAllUsesWith` multiple times with the same value. Doing so would
2014 // overwrite the mapping and mess with the internal state of the dialect
2015 // conversion driver.
2016 assert(!replacedValues.contains(from) &&
2017 "attempting to replace a value that was already replaced");
2018 assert(!wasOpReplaced(from.getDefiningOp()) &&
2019 "attempting to replace a op result that was already replaced");
2020 replacedValues.insert(from);
2021#endif // NDEBUG
2022
2023 if (functor)
2024 llvm::report_fatal_error(
2025 "conditional value replacement is not supported in rollback mode");
2026 mapping.map(from, to);
2027 appendRewrite<ReplaceValueRewrite>(from, converter);
2028}
2029
/// Erases `block`.
/// In "no rollback" mode, the block's nested ops and blocks are first
/// removed from the driver's bookkeeping, and the block is then erased
/// immediately through `notifyingRewriter`. In rollback mode, the block is
/// only unlinked from its region and its nested ops are marked as replaced,
/// so that the removal can be undone by re-inserting the block.
/// NOTE(review): listing lines 2030 (presumably the declarator) and 2055 are
/// absent from this extract, so the code below is not complete as shown.
2031 if (!config.allowPatternRollback) {
2032 // Pattern rollback is not allowed: materialize all IR changes immediately.
2033 // Update internal data structures, so that there are no dangling pointers
2034 // to erased IR.
2035 block->walk([&](Operation *op) {
2036 erasedOps.insert(op);
2037 ignoredOps.remove(op);
2038 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
2039 unresolvedMaterializations.erase(castOp);
2040 patternMaterializations.erase(castOp);
2041 }
2042 // The original op will be erased, so remove it from the set of
2043 // unlegalized ops.
2044 if (config.unlegalizedOps)
2045 config.unlegalizedOps->erase(op);
2046 });
2047 block->walk([&](Block *block) { erasedBlocks.insert(block); });
2048 // Erase the block and notify the listener.
2049 notifyingRewriter.eraseBlock(block);
2050 return;
2051 }
2052
2053 assert(!wasOpReplaced(block->getParentOp()) &&
2054 "attempting to erase a block within a replaced/erased op");
2056
2057 // Unlink the block from its parent region. The block is kept in the rewrite
2058 // object and will be actually destroyed when rewrites are applied. This
2059 // allows us to keep the operations in the block live and undo the removal by
2060 // re-inserting the block.
2061 block->getParent()->getBlocks().remove(block);
2062
2063 // Mark all nested ops as erased.
2064 block->walk([&](Operation *op) { replacedOps.insert(op); });
2065}
2066
/// Listener hook invoked when `block` is inserted. A null `previous` region
/// means the block was previously detached.
/// In "no rollback" mode, the configured listener is notified immediately,
/// and a previously-detached block is removed from `erasedBlocks`.
/// In rollback mode, a block moved from another region is recorded with a
/// `MoveBlockRewrite`.
/// NOTE(review): listing lines 2067 (presumably the declarator) and 2103 are
/// absent from this extract. As shown, the rollback-mode branch for detached
/// blocks contains only comments.
2068 Block *block, Region *previous, Region::iterator previousIt) {
2069 // If no previous insertion point is provided, the block used to be detached.
2070 bool wasDetached = !previous;
2071 Operation *newParentOp = block->getParentOp();
2072 LLVM_DEBUG(
2073 {
2074 Operation *parent = newParentOp;
2075 if (parent) {
2076 logger.startLine() << "** Insert Block into : '" << parent->getName()
2077 << "' (" << parent << ")";
2078 } else {
2079 logger.startLine()
2080 << "** Insert Block into detached Region (nullptr parent op)";
2081 }
2082 if (wasDetached)
2083 logger.getOStream() << " (was detached)";
2084 logger.getOStream() << "\n";
2085 });
2086
2087 // In rollback mode, it is easier to misuse the API, so perform extra error
2088 // checking.
2089 assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) &&
2090 "attempting to insert into a region within a replaced/erased op");
2091 (void)newParentOp;
2092
2093 // In "no rollback" mode, the listener is always notified immediately.
2094 if (!config.allowPatternRollback && config.listener)
2095 config.listener->notifyBlockInserted(block, previous, previousIt);
2096
2097 if (wasDetached) {
2098 // If the block was detached, it is most likely a newly created block.
2099 if (config.allowPatternRollback) {
2100 // TODO: If the same block is inserted multiple times from a detached
2101 // state, the rollback mechanism may erase the same block multiple times.
2102 // This is a bug in the rollback-based dialect conversion driver.
2104 } else {
2105 // In "no rollback" mode, there is an extra data structure for tracking
2106 // erased blocks that must be kept up to date.
2107 erasedBlocks.erase(block);
2108 }
2109 return;
2110 }
2111
2112 // The block was moved from one place to another.
2113 if (config.allowPatternRollback)
2114 appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
2115}
2116
/// Records the inlining of `source` into `dest` before `before`. The body
/// only appends an `InlineBlockRewrite`; it does not move any operations
/// here.
/// NOTE(review): listing line 2117 is absent from this extract. It presumably
/// holds the declarator, including the `source` parameter.
2118 Block *dest,
2119 Block::iterator before) {
2120 appendRewrite<InlineBlockRewrite>(dest, source, before);
2121}
2122
/// Reports a pattern match failure at `loc`. The reason is produced by
/// `reasonCallback`, written to the debug log, and forwarded to
/// `config.notifyCallback` when one is set.
/// NOTE(review): all of this happens inside LLVM_DEBUG. In NDEBUG builds, or
/// when the "dialect-conversion" debug type is not enabled, neither
/// `reasonCallback` nor `config.notifyCallback` is invoked -- confirm this is
/// intended.
/// NOTE(review): listing lines 2123 (presumably the declarator) and 2126
/// (presumably the declaration of `diag`) are absent from this extract.
2124 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
2125 LLVM_DEBUG({
2127 reasonCallback(diag);
2128 logger.startLine() << "** Failure : " << diag.str() << "\n";
2129 if (config.notifyCallback)
2130 config.notifyCallback(diag);
2131 });
2132}
2133
2134//===----------------------------------------------------------------------===//
2135// ConversionPatternRewriter
2136//===----------------------------------------------------------------------===//
2137
/// Creates a conversion rewriter for `ctx` that is driven by `opConverter`
/// and uses `config`. The implementation object is installed as this
/// rewriter's listener (`setListener(impl.get())`), so it receives this
/// rewriter's notifications.
/// NOTE(review): listing line 2141 is absent from this extract. It
/// presumably holds the start of the `impl` member initializer.
2138ConversionPatternRewriter::ConversionPatternRewriter(
2139 MLIRContext *ctx, const ConversionConfig &config,
2140 OperationConverter &opConverter)
2142 *this, config, opConverter)) {
2143 setListener(impl.get());
2144}
2145
2146ConversionPatternRewriter::~ConversionPatternRewriter() = default;
2147
2148const ConversionConfig &ConversionPatternRewriter::getConfig() const {
2149 return impl->config;
2150}
2151
2152void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
2153 assert(op && newOp && "expected non-null op");
2154 replaceOp(op, newOp->getResults());
2155}
2156
2157void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
2158 assert(op->getNumResults() == newValues.size() &&
2159 "incorrect # of replacement values");
2160
2161 // If the current insertion point is before the erased operation, we adjust
2162 // the insertion point to be after the operation.
2163 if (getInsertionPoint() == op->getIterator())
2165
2166 SmallVector<SmallVector<Value>> newVals =
2167 llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2168 return v ? SmallVector<Value>{v} : SmallVector<Value>();
2169 });
2170 impl->replaceOp(op, std::move(newVals));
2171}
2172
2173void ConversionPatternRewriter::replaceOpWithMultiple(
2174 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2175 assert(op->getNumResults() == newValues.size() &&
2176 "incorrect # of replacement values");
2177
2178 // If the current insertion point is before the erased operation, we adjust
2179 // the insertion point to be after the operation.
2180 if (getInsertionPoint() == op->getIterator())
2182
2183 impl->replaceOp(op, std::move(newValues));
2184}
2185
2186void ConversionPatternRewriter::eraseOp(Operation *op) {
2187 LLVM_DEBUG({
2188 impl->logger.startLine()
2189 << "** Erase : '" << op->getName() << "'(" << op << ")\n";
2190 });
2191
2192 // If the current insertion point is before the erased operation, we adjust
2193 // the insertion point to be after the operation.
2194 if (getInsertionPoint() == op->getIterator())
2196
2197 SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
2198 impl->replaceOp(op, std::move(nullRepls));
2199}
2200
2201void ConversionPatternRewriter::eraseBlock(Block *block) {
2202 impl->eraseBlock(block);
2203}
2204
2205Block *ConversionPatternRewriter::applySignatureConversion(
2206 Block *block, TypeConverter::SignatureConversion &conversion,
2207 const TypeConverter *converter) {
2208 assert(!impl->wasOpReplaced(block->getParentOp()) &&
2209 "attempting to apply a signature conversion to a block within a "
2210 "replaced/erased op");
2211 return impl->applySignatureConversion(block, converter, conversion);
2212}
2213
2214FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
2215 Region *region, const TypeConverter &converter,
2216 TypeConverter::SignatureConversion *entryConversion) {
2217 assert(!impl->wasOpReplaced(region->getParentOp()) &&
2218 "attempting to apply a signature conversion to a block within a "
2219 "replaced/erased op");
2220 return impl->convertRegionTypes(region, converter, entryConversion);
2221}
2222
2223void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
2224 impl->replaceValueUses(from, to, impl->currentTypeConverter);
2225}
2226
2227void ConversionPatternRewriter::replaceUsesWithIf(
2228 Value from, ValueRange to, function_ref<bool(OpOperand &)> functor,
2229 bool *allUsesReplaced) {
2230 assert(!allUsesReplaced &&
2231 "allUsesReplaced is not supported in a dialect conversion");
2232 impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
2233}
2234
2235Value ConversionPatternRewriter::getRemappedValue(Value key) {
2236 SmallVector<ValueVector> remappedValues;
2237 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
2238 remappedValues)))
2239 return nullptr;
2240 assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
2241 return remappedValues.front().front();
2242}
2243
2244LogicalResult
2245ConversionPatternRewriter::getRemappedValues(ValueRange keys,
2246 SmallVectorImpl<Value> &results) {
2247 if (keys.empty())
2248 return success();
2249 SmallVector<ValueVector> remapped;
2250 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
2251 remapped)))
2252 return failure();
2253 for (const auto &values : remapped) {
2254 assert(values.size() == 1 && "1:N conversion not supported");
2255 results.push_back(values.front());
2256 }
2257 return success();
2258}
2259
2260void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
2261 Block::iterator before,
2262 ValueRange argValues) {
2263#ifndef NDEBUG
2264 assert(argValues.size() == source->getNumArguments() &&
2265 "incorrect # of argument replacement values");
2266 assert(!impl->wasOpReplaced(source->getParentOp()) &&
2267 "attempting to inline a block from a replaced/erased op");
2268 assert(!impl->wasOpReplaced(dest->getParentOp()) &&
2269 "attempting to inline a block into a replaced/erased op");
2270 auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
2271 // The source block will be deleted, so it should not have any users (i.e.,
2272 // there should be no predecessors).
2273 assert(llvm::all_of(source->getUsers(), opIgnored) &&
2274 "expected 'source' to have no predecessors");
2275#endif // NDEBUG
2276
2277 // If a listener is attached to the dialect conversion, ops cannot be moved
2278 // to the destination block in bulk ("fast path"). This is because at the time
2279 // the notifications are sent, it is unknown which ops were moved. Instead,
2280 // ops should be moved one-by-one ("slow path"), so that a separate
2281 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
2282 // a bit more efficient, so we try to do that when possible.
2283 bool fastPath = !getConfig().listener;
2284
2285 if (fastPath && impl->config.allowPatternRollback)
2286 impl->inlineBlockBefore(source, dest, before);
2287
2288 // Replace all uses of block arguments.
2289 for (auto it : llvm::zip(source->getArguments(), argValues))
2290 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2291
2292 if (fastPath) {
2293 // Move all ops at once.
2294 dest->getOperations().splice(before, source->getOperations());
2295 } else {
2296 // Move op by op.
2297 while (!source->empty())
2298 moveOpBefore(&source->front(), dest, before);
2299 }
2300
2301 // If the current insertion point is within the source block, adjust the
2302 // insertion point to the destination block.
2303 if (getInsertionBlock() == source)
2304 setInsertionPoint(dest, getInsertionPoint());
2305
2306 // Erase the source block.
2307 eraseBlock(source);
2308}
2309
/// Starts an in-place modification of `op`. In rollback mode, this appends a
/// `ModifyOperationRewrite`, which can later be rolled back, and in debug
/// builds it also tracks `op` in `pendingRootUpdates`.
/// NOTE(review): listing line 2313 is absent from this extract.
2310void ConversionPatternRewriter::startOpModification(Operation *op) {
2311 if (!impl->config.allowPatternRollback) {
2312 // Pattern rollback is not allowed: no extra bookkeeping is needed.
2314 return;
2315 }
2316 assert(!impl->wasOpReplaced(op) &&
2317 "attempting to modify a replaced/erased op");
2318#ifndef NDEBUG
2319 impl->pendingRootUpdates.insert(op);
2320#endif
2321 impl->appendRewrite<ModifyOperationRewrite>(op);
2322}
2323
/// Finishes an in-place modification of `op` and records it in
/// `patternModifiedOps`. In "no rollback" mode, the configured listener (if
/// any) is notified of the modification.
/// NOTE(review): the second assert below has a side effect
/// (`pendingRootUpdates.erase(op)`). That is consistent only because the
/// matching insertion in startOpModification is also guarded by
/// `#ifndef NDEBUG`.
/// NOTE(review): listing line 2327 is absent from this extract.
2324void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
2325 impl->patternModifiedOps.insert(op);
2326 if (!impl->config.allowPatternRollback) {
2328 if (getConfig().listener)
2329 getConfig().listener->notifyOperationModified(op);
2330 return;
2331 }
2332
2333 // There is nothing to do here, we only need to track the operation at the
2334 // start of the update.
2335#ifndef NDEBUG
2336 assert(!impl->wasOpReplaced(op) &&
2337 "attempting to modify a replaced/erased op");
2338 assert(impl->pendingRootUpdates.erase(op) &&
2339 "operation did not have a pending in-place update");
2340#endif
2341}
2342
/// Cancels the most recent in-place modification of `op`. In rollback mode,
/// the latest `ModifyOperationRewrite` for `op` is rolled back and removed
/// from `rewrites`.
/// NOTE(review): listing line 2345 is absent from this extract.
2343void ConversionPatternRewriter::cancelOpModification(Operation *op) {
2344 if (!impl->config.allowPatternRollback) {
2346 return;
2347 }
2348#ifndef NDEBUG
2349 assert(impl->pendingRootUpdates.erase(op) &&
2350 "operation did not have a pending in-place update");
2351#endif
2352 // Erase the last update for this operation.
2353 auto it = llvm::find_if(
2354 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
2355 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2356 return modifyRewrite && modifyRewrite->getOperation() == op;
2357 });
2358 assert(it != impl->rewrites.rend() && "no root update started on op");
2359 (*it)->rollback();
// Convert the reverse iterator `it` into a forward index into `rewrites`.
2360 int updateIdx = std::prev(impl->rewrites.rend()) - it;
2361 impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2362}
2363
2364detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
2365 return *impl;
2366}
2367
2368//===----------------------------------------------------------------------===//
2369// ConversionPattern
2370//===----------------------------------------------------------------------===//
2371
2372FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
2373 ArrayRef<ValueRange> operands) const {
2374 SmallVector<Value> oneToOneOperands;
2375 oneToOneOperands.reserve(operands.size());
2376 for (ValueRange operand : operands) {
2377 if (operand.size() != 1)
2378 return failure();
2379
2380 oneToOneOperands.push_back(operand.front());
2381 }
2382 return std::move(oneToOneOperands);
2383}
2384
2385LogicalResult
2386ConversionPattern::matchAndRewrite(Operation *op,
2387 PatternRewriter &rewriter) const {
2388 auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
2389 auto &rewriterImpl = dialectRewriter.getImpl();
2390
2391 // Track the current conversion pattern type converter in the rewriter.
2392 llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
2393 getTypeConverter());
2394
2395 // Remap the operands of the operation.
2396 SmallVector<ValueVector> remapped;
2397 if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
2398 op->getOperands(), remapped))) {
2399 return failure();
2400 }
2401 SmallVector<ValueRange> remappedAsRange =
2402 llvm::to_vector_of<ValueRange>(remapped);
2403 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2404}
2405
2406//===----------------------------------------------------------------------===//
2407// OperationLegalizer
2408//===----------------------------------------------------------------------===//
2409
2410namespace {
2411/// A set of rewrite patterns that can be used to legalize a given operation.
2412using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2413
2414/// This class defines a recursive operation legalizer.
// It holds references to the rewriter and the target, so both must outlive
// the legalizer.
2415class OperationLegalizer {
2416public:
2417 using LegalizationAction = ConversionTarget::LegalizationAction;
2418
2419 OperationLegalizer(ConversionPatternRewriter &rewriter,
2420 const ConversionTarget &targetInfo,
2421 const FrozenRewritePatternSet &patterns);
2422
2423 /// Returns true if the given operation is known to be illegal on the target.
2424 bool isIllegal(Operation *op) const;
2425
2426 /// Attempt to legalize the given operation. Returns success if the operation
2427 /// was legalized, failure otherwise.
2428 LogicalResult legalize(Operation *op);
2429
2430 /// Returns the conversion target in use by the legalizer.
// NOTE(review): this accessor does not mutate state and could be declared
// `const`.
2431 const ConversionTarget &getTarget() { return target; }
2432
2433private:
2434 /// Attempt to legalize the given operation by folding it.
2435 LogicalResult legalizeWithFold(Operation *op);
2436
2437 /// Attempt to legalize the given operation by applying a pattern. Returns
2438 /// success if the operation was legalized, failure otherwise.
2439 LogicalResult legalizeWithPattern(Operation *op);
2440
2441 /// Return true if the given pattern may be applied to the given operation,
2442 /// false otherwise.
2443 bool canApplyPattern(Operation *op, const Pattern &pattern);
2444
2445 /// Legalize the resultant IR after successfully applying the given pattern.
2446 LogicalResult
2447 legalizePatternResult(Operation *op, const Pattern &pattern,
2448 const RewriterState &curState,
2449 const SetVector<Operation *> &newOps,
2450 const SetVector<Operation *> &modifiedOps);
2451
// NOTE(review): the two declarations below are undocumented. By their names,
// they presumably legalize the ops created by a pattern (`newOps`) and the ops
// it modified in place (`modifiedOps`) -- confirm against their definitions
// (not in view).
2452 LogicalResult
2453 legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
2454 LogicalResult
2455 legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
2456
2457 //===--------------------------------------------------------------------===//
2458 // Cost Model
2459 //===--------------------------------------------------------------------===//
2460
// NOTE(review): listing lines 2467, 2479, 2485 and 2493 are absent from this
// extract. Each of the four cost-model declarations below is missing its
// trailing parameter line(s), so the class is not complete as shown.
2461 /// Build an optimistic legalization graph given the provided patterns. This
2462 /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
2463 /// patterns for operations that are not directly legal, but may be
2464 /// transitively legal for the current target given the provided patterns.
2465 void buildLegalizationGraph(
2466 LegalizationPatterns &anyOpLegalizerPatterns,
2468
2469 /// Compute the benefit of each node within the computed legalization graph.
2470 /// This orders the patterns within 'legalizerPatterns' based upon two
2471 /// criteria:
2472 /// 1) Prefer patterns that have the lowest legalization depth, i.e.
2473 /// represent the more direct mapping to the target.
2474 /// 2) When comparing patterns with the same legalization depth, prefer the
2475 /// pattern with the highest PatternBenefit. This allows for users to
2476 /// prefer specific legalizations over others.
2477 void computeLegalizationGraphBenefit(
2478 LegalizationPatterns &anyOpLegalizerPatterns,
2480
2481 /// Compute the legalization depth when legalizing an operation of the given
2482 /// type.
2483 unsigned computeOpLegalizationDepth(
2484 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2486
2487 /// Apply the conversion cost model to the given set of patterns, and return
2488 /// the smallest legalization depth of any of the patterns. See
2489 /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
2490 unsigned applyCostModelToPatterns(
2491 LegalizationPatterns &patterns,
2492 DenseMap<OperationName, unsigned> &minOpPatternDepth,
2494
2495 /// The current set of patterns that have been applied.
2496 SmallPtrSet<const Pattern *, 8> appliedPatterns;
2497
2498 /// The rewriter to use when converting operations.
2499 ConversionPatternRewriter &rewriter;
2500
2501 /// The legalization information provided by the target.
2502 const ConversionTarget &target;
2503
2504 /// The pattern applicator to use for conversions.
2505 PatternApplicator applicator;
2506};
2507} // namespace
2508
/// Builds the legalizer from `patterns`. It precomputes the legalization
/// graph (`buildLegalizationGraph`) and the ordering of the patterns within it
/// (`computeLegalizationGraphBenefit`).
/// NOTE(review): listing line 2515 is absent from this extract. The
/// `legalizerPatterns` used below is presumably declared there.
2509OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2510 const ConversionTarget &targetInfo,
2511 const FrozenRewritePatternSet &patterns)
2512 : rewriter(rewriter), target(targetInfo), applicator(patterns) {
2513 // The set of patterns that can be applied to illegal operations to transform
2514 // them into legal ones.
2516 LegalizationPatterns anyOpLegalizerPatterns;
2517
2518 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2519 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2520}
2521
2522bool OperationLegalizer::isIllegal(Operation *op) const {
2523 return target.isIllegal(op);
2524}
2525
2526LogicalResult OperationLegalizer::legalize(Operation *op) {
2527#ifndef NDEBUG
2528 const char *logLineComment =
2529 "//===-------------------------------------------===//\n";
2530
2531 auto &logger = rewriter.getImpl().logger;
2532#endif
2533
2534 // Check to see if the operation is ignored and doesn't need to be converted.
2535 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2536
2537 LLVM_DEBUG({
2538 logger.getOStream() << "\n";
2539 logger.startLine() << logLineComment;
2540 logger.startLine() << "Legalizing operation : ";
2541 // Do not print the operation name if the operation is ignored. Ignored ops
2542 // may have been erased and should not be accessed. The pointer can be
2543 // printed safely.
2544 if (!isIgnored)
2545 logger.getOStream() << "'" << op->getName() << "' ";
2546 logger.getOStream() << "(" << op << ") {\n";
2547 logger.indent();
2548
2549 // If the operation has no regions, just print it here.
2550 if (!isIgnored && op->getNumRegions() == 0) {
2551 logger.startLine() << OpWithFlags(op,
2552 OpPrintingFlags().printGenericOpForm())
2553 << "\n";
2554 }
2555 });
2556
2557 if (isIgnored) {
2558 LLVM_DEBUG({
2559 logSuccess(logger, "operation marked 'ignored' during conversion");
2560 logger.startLine() << logLineComment;
2561 });
2562 return success();
2563 }
2564
2565 // Check if this operation is legal on the target.
2566 if (auto legalityInfo = target.isLegal(op)) {
2567 LLVM_DEBUG({
2568 logSuccess(
2569 logger, "operation marked legal by the target{0}",
2570 legalityInfo->isRecursivelyLegal
2571 ? "; NOTE: operation is recursively legal; skipping internals"
2572 : "");
2573 logger.startLine() << logLineComment;
2574 });
2575
2576 // If this operation is recursively legal, mark its children as ignored so
2577 // that we don't consider them for legalization.
2578 if (legalityInfo->isRecursivelyLegal) {
2579 op->walk([&](Operation *nested) {
2580 if (op != nested)
2581 rewriter.getImpl().ignoredOps.insert(nested);
2582 });
2583 }
2584
2585 return success();
2586 }
2587
2588 // If the operation is not legal, try to fold it in-place if the folding mode
2589 // is 'BeforePatterns'. 'Never' will skip this.
2590 const ConversionConfig &config = rewriter.getConfig();
2591 if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2592 if (succeeded(legalizeWithFold(op))) {
2593 LLVM_DEBUG({
2594 logSuccess(logger, "operation was folded");
2595 logger.startLine() << logLineComment;
2596 });
2597 return success();
2598 }
2599 }
2600
2601 // Otherwise, we need to apply a legalization pattern to this operation.
2602 if (succeeded(legalizeWithPattern(op))) {
2603 LLVM_DEBUG({
2604 logSuccess(logger, "");
2605 logger.startLine() << logLineComment;
2606 });
2607 return success();
2608 }
2609
2610 // If the operation can't be legalized via patterns, try to fold it in-place
2611 // if the folding mode is 'AfterPatterns'.
2612 if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2613 if (succeeded(legalizeWithFold(op))) {
2614 LLVM_DEBUG({
2615 logSuccess(logger, "operation was folded");
2616 logger.startLine() << logLineComment;
2617 });
2618 return success();
2619 }
2620 }
2621
2622 LLVM_DEBUG({
2623 logFailure(logger, "no matched legalization pattern");
2624 logger.startLine() << logLineComment;
2625 });
2626 return failure();
2627}
2628
/// Moves the contents out of `obj` and returns them, leaving `obj` reassigned
/// to a freshly value-initialized `T` (i.e. a valid, empty state) rather than
/// an unspecified moved-from state.
template <typename T>
static T moveAndReset(T &obj) {
  return std::exchange(obj, T());
}
2637
2638LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2639 auto &rewriterImpl = rewriter.getImpl();
2640 LLVM_DEBUG({
2641 rewriterImpl.logger.startLine() << "* Fold {\n";
2642 rewriterImpl.logger.indent();
2643 });
2644
2645 // Clear pattern state, so that the next pattern application starts with a
2646 // clean slate. (The op/block sets are populated by listener notifications.)
2647 auto cleanup = llvm::make_scope_exit([&]() {
2648 rewriterImpl.patternNewOps.clear();
2649 rewriterImpl.patternModifiedOps.clear();
2650 });
2651
2652 // Upon failure, undo all changes made by the folder.
2653 RewriterState curState = rewriterImpl.getCurrentState();
2654
2655 // Try to fold the operation.
2656 StringRef opName = op->getName().getStringRef();
2657 SmallVector<Value, 2> replacementValues;
2658 SmallVector<Operation *, 2> newOps;
2659 rewriter.setInsertionPoint(op);
2660 rewriter.startOpModification(op);
2661 if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
2662 LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2663 rewriter.cancelOpModification(op);
2664 return failure();
2665 }
2666 rewriter.finalizeOpModification(op);
2667
2668 // An empty list of replacement values indicates that the fold was in-place.
2669 // As the operation changed, a new legalization needs to be attempted.
2670 if (replacementValues.empty())
2671 return legalize(op);
2672
2673 // Insert a replacement for 'op' with the folded replacement values.
2674 rewriter.replaceOp(op, replacementValues);
2675
2676 // Recursively legalize any new constant operations.
2677 for (Operation *newOp : newOps) {
2678 if (failed(legalize(newOp))) {
2679 LLVM_DEBUG(logFailure(rewriterImpl.logger,
2680 "failed to legalize generated constant '{0}'",
2681 newOp->getName()));
2682 if (!rewriter.getConfig().allowPatternRollback) {
2683 // Rolling back a folder is like rolling back a pattern.
2684 llvm::report_fatal_error(
2685 "op '" + opName +
2686 "' folder rollback of IR modifications requested");
2687 }
2688 rewriterImpl.resetState(
2689 curState, std::string(op->getName().getStringRef()) + " folder");
2690 return failure();
2691 }
2692 }
2693
2694 LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2695 return success();
2696}
2697
2698/// Report a fatal error indicating that newly produced or modified IR could
2699/// not be legalized.
2700static void
2702 const SetVector<Operation *> &newOps,
2703 const SetVector<Operation *> &modifiedOps) {
2704 auto newOpNames = llvm::map_range(
2705 newOps, [](Operation *op) { return op->getName().getStringRef(); });
2706 auto modifiedOpNames = llvm::map_range(
2707 modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
2708 llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
2709 "' produced IR that could not be legalized. " +
2710 "new ops: {" + llvm::join(newOpNames, ", ") + "}, " +
2711 "modified ops: {" +
2712 llvm::join(modifiedOpNames, ", ") + "}");
2713}
2714
2715LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2716 auto &rewriterImpl = rewriter.getImpl();
2717 const ConversionConfig &config = rewriter.getConfig();
2718
2719#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2720 Operation *checkOp;
2721 std::optional<OperationFingerPrint> topLevelFingerPrint;
2722 if (!rewriterImpl.config.allowPatternRollback) {
2723 // The op may be getting erased, so we have to check the parent op.
2724 // (In rare cases, a pattern may even erase the parent op, which will cause
2725 // a crash here. Expensive checks are "best effort".) Skip the check if the
2726 // op does not have a parent op.
2727 if ((checkOp = op->getParentOp())) {
2728 if (!op->getContext()->isMultithreadingEnabled()) {
2729 topLevelFingerPrint = OperationFingerPrint(checkOp);
2730 } else {
2731 // Another thread may be modifying a sibling operation. Therefore, the
2732 // fingerprinting mechanism of the parent op works only in
2733 // single-threaded mode.
2734 LLVM_DEBUG({
2735 rewriterImpl.logger.startLine()
2736 << "WARNING: Multi-threadeding is enabled. Some dialect "
2737 "conversion expensive checks are skipped in multithreading "
2738 "mode!\n";
2739 });
2740 }
2741 }
2742 }
2743#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2744
2745 // Functor that returns if the given pattern may be applied.
2746 auto canApply = [&](const Pattern &pattern) {
2747 bool canApply = canApplyPattern(op, pattern);
2748 if (canApply && config.listener)
2749 config.listener->notifyPatternBegin(pattern, op);
2750 return canApply;
2751 };
2752
2753 // Functor that cleans up the rewriter state after a pattern failed to match.
2754 RewriterState curState = rewriterImpl.getCurrentState();
2755 auto onFailure = [&](const Pattern &pattern) {
2756 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2757 if (!rewriterImpl.config.allowPatternRollback) {
2758 // Erase all unresolved materializations.
2759 for (auto op : rewriterImpl.patternMaterializations) {
2760 rewriterImpl.unresolvedMaterializations.erase(op);
2761 op.erase();
2762 }
2763 rewriterImpl.patternMaterializations.clear();
2764#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2765 // Expensive pattern check that can detect API violations.
2766 if (checkOp && topLevelFingerPrint) {
2767 OperationFingerPrint fingerPrintAfterPattern(checkOp);
2768 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2769 llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
2770 "' returned failure but IR did change");
2771 }
2772#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2773 }
2774 rewriterImpl.patternNewOps.clear();
2775 rewriterImpl.patternModifiedOps.clear();
2776 LLVM_DEBUG({
2777 logFailure(rewriterImpl.logger, "pattern failed to match");
2778 if (rewriterImpl.config.notifyCallback) {
2779 Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
2780 diag << "Failed to apply pattern \"" << pattern.getDebugName()
2781 << "\" on op:\n"
2782 << *op;
2783 rewriterImpl.config.notifyCallback(diag);
2784 }
2785 });
2786 if (config.listener)
2787 config.listener->notifyPatternEnd(pattern, failure());
2788 rewriterImpl.resetState(curState, pattern.getDebugName());
2789 appliedPatterns.erase(&pattern);
2790 };
2791
2792 // Functor that performs additional legalization when a pattern is
2793 // successfully applied.
2794 auto onSuccess = [&](const Pattern &pattern) {
2795 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2796 if (!rewriterImpl.config.allowPatternRollback) {
2797 // Eagerly erase unused materializations.
2798 for (auto op : rewriterImpl.patternMaterializations) {
2799 if (op->use_empty()) {
2800 rewriterImpl.unresolvedMaterializations.erase(op);
2801 op.erase();
2802 }
2803 }
2804 rewriterImpl.patternMaterializations.clear();
2805 }
2806 SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
2807 SetVector<Operation *> modifiedOps =
2808 moveAndReset(rewriterImpl.patternModifiedOps);
2809 auto result =
2810 legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
2811 appliedPatterns.erase(&pattern);
2812 if (failed(result)) {
2813 if (!rewriterImpl.config.allowPatternRollback)
2814 reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);
2815 rewriterImpl.resetState(curState, pattern.getDebugName());
2816 }
2817 if (config.listener)
2818 config.listener->notifyPatternEnd(pattern, result);
2819 return result;
2820 };
2821
2822 // Try to match and rewrite a pattern on this operation.
2823 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2824 onSuccess);
2825}
2826
2827bool OperationLegalizer::canApplyPattern(Operation *op,
2828 const Pattern &pattern) {
2829 LLVM_DEBUG({
2830 auto &os = rewriter.getImpl().logger;
2831 os.getOStream() << "\n";
2832 os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2833 llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2834 os.getOStream() << ")' {\n";
2835 os.indent();
2836 });
2837
2838 // Ensure that we don't cycle by not allowing the same pattern to be
2839 // applied twice in the same recursion stack if it is not known to be safe.
2840 if (!pattern.hasBoundedRewriteRecursion() &&
2841 !appliedPatterns.insert(&pattern).second) {
2842 LLVM_DEBUG(
2843 logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2844 return false;
2845 }
2846 return true;
2847}
2848
2849LogicalResult OperationLegalizer::legalizePatternResult(
2850 Operation *op, const Pattern &pattern, const RewriterState &curState,
2851 const SetVector<Operation *> &newOps,
2852 const SetVector<Operation *> &modifiedOps) {
2853 [[maybe_unused]] auto &impl = rewriter.getImpl();
2854 assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2855
2856#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2857 if (impl.config.allowPatternRollback) {
2858 // Check that the root was either replaced or updated in place.
2859 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2860 auto replacedRoot = [&] {
2861 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2862 };
2863 auto updatedRootInPlace = [&] {
2864 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2865 };
2866 if (!replacedRoot() && !updatedRootInPlace())
2867 llvm::report_fatal_error("expected pattern to replace the root operation "
2868 "or modify it in place");
2869 }
2870#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2871
2872 // Legalize each of the actions registered during application.
2873 if (failed(legalizePatternRootUpdates(modifiedOps)) ||
2874 failed(legalizePatternCreatedOperations(newOps))) {
2875 return failure();
2876 }
2877
2878 LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2879 return success();
2880}
2881
2882LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2883 const SetVector<Operation *> &newOps) {
2884 for (Operation *op : newOps) {
2885 if (failed(legalize(op))) {
2886 LLVM_DEBUG(logFailure(rewriter.getImpl().logger,
2887 "failed to legalize generated operation '{0}'({1})",
2888 op->getName(), op));
2889 return failure();
2890 }
2891 }
2892 return success();
2893}
2894
2895LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2896 const SetVector<Operation *> &modifiedOps) {
2897 for (Operation *op : modifiedOps) {
2898 if (failed(legalize(op))) {
2899 LLVM_DEBUG(
2900 logFailure(rewriter.getImpl().logger,
2901 "failed to legalize operation updated in-place '{0}'",
2902 op->getName()));
2903 return failure();
2904 }
2905 }
2906 return success();
2907}
2908
2909//===----------------------------------------------------------------------===//
2910// Cost Model
2911//===----------------------------------------------------------------------===//
2912
2913void OperationLegalizer::buildLegalizationGraph(
2914 LegalizationPatterns &anyOpLegalizerPatterns,
2916 // A mapping between an operation and a set of operations that can be used to
2917 // generate it.
2919 // A mapping between an operation and any currently invalid patterns it has.
2921 // A worklist of patterns to consider for legality.
2922 SetVector<const Pattern *> patternWorklist;
2923
2924 // Build the mapping from operations to the parent ops that may generate them.
2925 applicator.walkAllPatterns([&](const Pattern &pattern) {
2926 std::optional<OperationName> root = pattern.getRootKind();
2927
2928 // If the pattern has no specific root, we can't analyze the relationship
2929 // between the root op and generated operations. Given that, add all such
2930 // patterns to the legalization set.
2931 if (!root) {
2932 anyOpLegalizerPatterns.push_back(&pattern);
2933 return;
2934 }
2935
2936 // Skip operations that are always known to be legal.
2937 if (target.getOpAction(*root) == LegalizationAction::Legal)
2938 return;
2939
2940 // Add this pattern to the invalid set for the root op and record this root
2941 // as a parent for any generated operations.
2942 invalidPatterns[*root].insert(&pattern);
2943 for (auto op : pattern.getGeneratedOps())
2944 parentOps[op].insert(*root);
2945
2946 // Add this pattern to the worklist.
2947 patternWorklist.insert(&pattern);
2948 });
2949
2950 // If there are any patterns that don't have a specific root kind, we can't
2951 // make direct assumptions about what operations will never be legalized.
2952 // Note: Technically we could, but it would require an analysis that may
2953 // recurse into itself. It would be better to perform this kind of filtering
2954 // at a higher level than here anyways.
2955 if (!anyOpLegalizerPatterns.empty()) {
2956 for (const Pattern *pattern : patternWorklist)
2957 legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2958 return;
2959 }
2960
2961 while (!patternWorklist.empty()) {
2962 auto *pattern = patternWorklist.pop_back_val();
2963
2964 // Check to see if any of the generated operations are invalid.
2965 if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2966 std::optional<LegalizationAction> action = target.getOpAction(op);
2967 return !legalizerPatterns.count(op) &&
2968 (!action || action == LegalizationAction::Illegal);
2969 }))
2970 continue;
2971
2972 // Otherwise, if all of the generated operations are valid, this op is now
2973 // legal so add all of the child patterns to the worklist.
2974 legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2975 invalidPatterns[*pattern->getRootKind()].erase(pattern);
2976
2977 // Add any invalid patterns of the parent operations to see if they have now
2978 // become legal.
2979 for (auto op : parentOps[*pattern->getRootKind()])
2980 patternWorklist.set_union(invalidPatterns[op]);
2981 }
2982}
2983
2984void OperationLegalizer::computeLegalizationGraphBenefit(
2985 LegalizationPatterns &anyOpLegalizerPatterns,
2987 // The smallest pattern depth, when legalizing an operation.
2988 DenseMap<OperationName, unsigned> minOpPatternDepth;
2989
2990 // For each operation that is transitively legal, compute a cost for it.
2991 for (auto &opIt : legalizerPatterns)
2992 if (!minOpPatternDepth.count(opIt.first))
2993 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
2994 legalizerPatterns);
2995
2996 // Apply the cost model to the patterns that can match any operation. Those
2997 // with a specific operation type are already resolved when computing the op
2998 // legalization depth.
2999 if (!anyOpLegalizerPatterns.empty())
3000 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
3001 legalizerPatterns);
3002
3003 // Apply a cost model to the pattern applicator. We order patterns first by
3004 // depth then benefit. `legalizerPatterns` contains per-op patterns by
3005 // decreasing benefit.
3006 applicator.applyCostModel([&](const Pattern &pattern) {
3007 ArrayRef<const Pattern *> orderedPatternList;
3008 if (std::optional<OperationName> rootName = pattern.getRootKind())
3009 orderedPatternList = legalizerPatterns[*rootName];
3010 else
3011 orderedPatternList = anyOpLegalizerPatterns;
3012
3013 // If the pattern is not found, then it was removed and cannot be matched.
3014 auto *it = llvm::find(orderedPatternList, &pattern);
3015 if (it == orderedPatternList.end())
3017
3018 // Patterns found earlier in the list have higher benefit.
3019 return PatternBenefit(std::distance(it, orderedPatternList.end()));
3020 });
3021}
3022
3023unsigned OperationLegalizer::computeOpLegalizationDepth(
3024 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
3026 // Check for existing depth.
3027 auto depthIt = minOpPatternDepth.find(op);
3028 if (depthIt != minOpPatternDepth.end())
3029 return depthIt->second;
3030
3031 // If a mapping for this operation does not exist, then this operation
3032 // is always legal. Return 0 as the depth for a directly legal operation.
3033 auto opPatternsIt = legalizerPatterns.find(op);
3034 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3035 return 0u;
3036
3037 // Record this initial depth in case we encounter this op again when
3038 // recursively computing the depth.
3039 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
3040
3041 // Apply the cost model to the operation patterns, and update the minimum
3042 // depth.
3043 unsigned minDepth = applyCostModelToPatterns(
3044 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3045 minOpPatternDepth[op] = minDepth;
3046 return minDepth;
3047}
3048
3049unsigned OperationLegalizer::applyCostModelToPatterns(
3050 LegalizationPatterns &patterns,
3051 DenseMap<OperationName, unsigned> &minOpPatternDepth,
3053 unsigned minDepth = std::numeric_limits<unsigned>::max();
3054
3055 // Compute the depth for each pattern within the set.
3056 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3057 patternsByDepth.reserve(patterns.size());
3058 for (const Pattern *pattern : patterns) {
3059 unsigned depth = 1;
3060 for (auto generatedOp : pattern->getGeneratedOps()) {
3061 unsigned generatedOpDepth = computeOpLegalizationDepth(
3062 generatedOp, minOpPatternDepth, legalizerPatterns);
3063 depth = std::max(depth, generatedOpDepth + 1);
3064 }
3065 patternsByDepth.emplace_back(pattern, depth);
3066
3067 // Update the minimum depth of the pattern list.
3068 minDepth = std::min(minDepth, depth);
3069 }
3070
3071 // If the operation only has one legalization pattern, there is no need to
3072 // sort them.
3073 if (patternsByDepth.size() == 1)
3074 return minDepth;
3075
3076 // Sort the patterns by those likely to be the most beneficial.
3077 llvm::stable_sort(patternsByDepth,
3078 [](const std::pair<const Pattern *, unsigned> &lhs,
3079 const std::pair<const Pattern *, unsigned> &rhs) {
3080 // First sort by the smaller pattern legalization
3081 // depth.
3082 if (lhs.second != rhs.second)
3083 return lhs.second < rhs.second;
3084
3085 // Then sort by the larger pattern benefit.
3086 auto lhsBenefit = lhs.first->getBenefit();
3087 auto rhsBenefit = rhs.first->getBenefit();
3088 return lhsBenefit > rhsBenefit;
3089 });
3090
3091 // Update the legalization pattern to use the new sorted list.
3092 patterns.clear();
3093 for (auto &patternIt : patternsByDepth)
3094 patterns.push_back(patternIt.first);
3095 return minDepth;
3096}
3097
3098//===----------------------------------------------------------------------===//
3099// Reconcile Unrealized Casts
3100//===----------------------------------------------------------------------===//
3101
3102/// Try to reconcile all given UnrealizedConversionCastOps and store the
3103/// left-over ops in `remainingCastOps` (if provided). See documentation in
3104/// DialectConversion.h for more details.
3105/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the
3106/// algorithm may visit an operand (or user) which is a cast op, but will not
3107/// try to reconcile it if not in the filtered set.
3108template <typename RangeT>
3110 RangeT castOps,
3111 function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3113 // A worklist of cast ops to process.
3114 SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
3115
3116 // Helper function that returns the unrealized_conversion_cast op that
3117 // defines all inputs of the given op (in the same order). Return "nullptr"
3118 // if there is no such op.
3119 auto getInputCast =
3120 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3121 if (castOp.getInputs().empty())
3122 return {};
3123 auto inputCastOp =
3124 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3125 if (!inputCastOp)
3126 return {};
3127 if (inputCastOp.getOutputs() != castOp.getInputs())
3128 return {};
3129 return inputCastOp;
3130 };
3131
3132 // Process ops in the worklist bottom-to-top.
3133 while (!worklist.empty()) {
3134 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3135
3136 // Traverse the chain of input cast ops to see if an op with the same
3137 // input types can be found.
3138 UnrealizedConversionCastOp nextCast = castOp;
3139 while (nextCast) {
3140 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3141 if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
3142 return v.getDefiningOp() == castOp;
3143 })) {
3144 // Ran into a cycle.
3145 break;
3146 }
3147
3148 // Found a cast where the input types match the output types of the
3149 // matched op. We can directly use those inputs.
3150 castOp.replaceAllUsesWith(nextCast.getInputs());
3151 break;
3152 }
3153 nextCast = getInputCast(nextCast);
3154 }
3155 }
3156
3157 // A set of all alive cast ops. I.e., ops whose results are (transitively)
3158 // used by an op that is not a cast op.
3159 DenseSet<Operation *> liveOps;
3160
3161 // Helper function that marks the given op and transitively reachable input
3162 // cast ops as alive.
3163 auto markOpLive = [&](Operation *rootOp) {
3164 SmallVector<Operation *> worklist;
3165 worklist.push_back(rootOp);
3166 while (!worklist.empty()) {
3167 Operation *op = worklist.pop_back_val();
3168 if (liveOps.insert(op).second) {
3169 // Successfully inserted: process reachable input cast ops.
3170 for (Value v : op->getOperands())
3171 if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
3172 if (isCastOpOfInterestFn(castOp))
3173 worklist.push_back(castOp);
3174 }
3175 }
3176 };
3177
3178 // Find all alive cast ops.
3179 for (UnrealizedConversionCastOp op : castOps) {
3180 // The op may have been marked live already as being an operand of another
3181 // live cast op.
3182 if (liveOps.contains(op.getOperation()))
3183 continue;
3184 // If any of the users is not a cast op, mark the current op (and its
3185 // input ops) as live.
3186 if (llvm::any_of(op->getUsers(), [&](Operation *user) {
3187 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3188 return !castOp || !isCastOpOfInterestFn(castOp);
3189 }))
3190 markOpLive(op);
3191 }
3192
3193 // Erase all dead cast ops.
3194 for (UnrealizedConversionCastOp op : castOps) {
3195 if (liveOps.contains(op)) {
3196 // Op is alive and was not erased. Add it to the remaining cast ops.
3197 if (remainingCastOps)
3198 remainingCastOps->push_back(op);
3199 continue;
3200 }
3201
3202 // Op is dead. Erase it.
3203 op->dropAllUses();
3204 op->erase();
3205 }
3206}
3207
3209 ArrayRef<UnrealizedConversionCastOp> castOps,
3210 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3211 // Set of all cast ops for faster lookups.
3212 DenseSet<UnrealizedConversionCastOp> castOpSet;
3213 for (UnrealizedConversionCastOp op : castOps)
3214 castOpSet.insert(op);
3215 reconcileUnrealizedCasts(castOpSet, remainingCastOps);
3216}
3217
3219 const DenseSet<UnrealizedConversionCastOp> &castOps,
3220 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3222 llvm::make_range(castOps.begin(), castOps.end()),
3223 [&](UnrealizedConversionCastOp castOp) {
3224 return castOps.contains(castOp);
3225 },
3226 remainingCastOps);
3227}
3228
3229namespace mlir {
3232 &castOps,
3235 castOps.keys(),
3236 [&](UnrealizedConversionCastOp castOp) {
3237 return castOps.contains(castOp);
3238 },
3239 remainingCastOps);
3240}
3241} // namespace mlir
3242
3243//===----------------------------------------------------------------------===//
3244// OperationConverter
3245//===----------------------------------------------------------------------===//
3246
3247namespace mlir {
3248// This class converts operations to a given conversion target via a set of
3249// rewrite patterns. The conversion behaves differently depending on the
3250// conversion mode.
3254 const ConversionConfig &config,
3255 OpConversionMode mode)
3256 : rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
3257 mode(mode) {}
3258
3259 /// Applies the conversion to the given operations (and their nested
3260 /// operations).
3261 LogicalResult applyConversion(ArrayRef<Operation *> ops);
3262
3263 /// Legalizes the given operations (and their nested operations) to the
3264 /// conversion target.
3265 template <typename Fn>
3266 LogicalResult legalizeOperations(ArrayRef<Operation *> ops, Fn onFailure,
3267 bool isRecursiveLegalization = false);
3269 bool isRecursiveLegalization = false) {
3270 return legalizeOperations(
3271 ops, /*onFailure=*/[&]() {}, isRecursiveLegalization);
3272 }
3273
3274 /// Converts a single operation. If `isRecursiveLegalization` is "true", the
3275 /// conversion is a recursive legalization request, triggered from within a
3276 /// pattern. In that case, do not emit errors because there will be another
3277 /// attempt at legalizing the operation later (via the regular pre-order
3278 /// legalization mechanism).
3279 LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);
3280
3281 const ConversionTarget &getTarget() { return opLegalizer.getTarget(); }
3282
3283private:
3284 /// The rewriter to use when converting operations.
3285 ConversionPatternRewriter rewriter;
3286
3287 /// The legalizer to use when converting operations.
3288 OperationLegalizer opLegalizer;
3289
3290 /// The conversion mode to use when legalizing operations.
3291 OpConversionMode mode;
3292};
3293} // namespace mlir
3294
3296 bool isRecursiveLegalization) {
3297 const ConversionConfig &config = rewriter.getConfig();
3298
3299 // Legalize the given operation.
3300 if (failed(opLegalizer.legalize(op))) {
3301 // Handle the case of a failed conversion for each of the different modes.
3302 // Full conversions expect all operations to be converted.
3303 if (mode == OpConversionMode::Full) {
3304 if (!isRecursiveLegalization)
3305 op->emitError() << "failed to legalize operation '" << op->getName()
3306 << "'";
3307 return failure();
3308 }
3309 // Partial conversions allow conversions to fail iff the operation was not
3310 // explicitly marked as illegal. If the user provided a `unlegalizedOps`
3311 // set, non-legalizable ops are added to that set.
3312 if (mode == OpConversionMode::Partial) {
3313 if (opLegalizer.isIllegal(op)) {
3314 if (!isRecursiveLegalization)
3315 op->emitError() << "failed to legalize operation '" << op->getName()
3316 << "' that was explicitly marked illegal";
3317 return failure();
3318 }
3319 if (config.unlegalizedOps && !isRecursiveLegalization)
3320 config.unlegalizedOps->insert(op);
3321 }
3322 } else if (mode == OpConversionMode::Analysis) {
3323 // Analysis conversions don't fail if any operations fail to legalize,
3324 // they are only interested in the operations that were successfully
3325 // legalized.
3326 if (config.legalizableOps && !isRecursiveLegalization)
3327 config.legalizableOps->insert(op);
3328 }
3329 return success();
3330}
3331
3332static LogicalResult
3334 UnrealizedConversionCastOp op,
3335 const UnresolvedMaterializationInfo &info) {
3336 assert(!op.use_empty() &&
3337 "expected that dead materializations have already been DCE'd");
3338 Operation::operand_range inputOperands = op.getOperands();
3339
3340 // Try to materialize the conversion.
3341 if (const TypeConverter *converter = info.getConverter()) {
3342 rewriter.setInsertionPoint(op);
3343 SmallVector<Value> newMaterialization;
3344 switch (info.getMaterializationKind()) {
3345 case MaterializationKind::Target:
3346 newMaterialization = converter->materializeTargetConversion(
3347 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3348 info.getOriginalType());
3349 break;
3350 case MaterializationKind::Source:
3351 assert(op->getNumResults() == 1 && "expected single result");
3352 Value sourceMat = converter->materializeSourceConversion(
3353 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3354 if (sourceMat)
3355 newMaterialization.push_back(sourceMat);
3356 break;
3357 }
3358 if (!newMaterialization.empty()) {
3359#ifndef NDEBUG
3360 ValueRange newMaterializationRange(newMaterialization);
3361 assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
3362 "materialization callback produced value of incorrect type");
3363#endif // NDEBUG
3364 rewriter.replaceOp(op, newMaterialization);
3365 return success();
3366 }
3367 }
3368
3369 InFlightDiagnostic diag = op->emitError()
3370 << "failed to legalize unresolved materialization "
3371 "from ("
3372 << inputOperands.getTypes() << ") to ("
3373 << op.getResultTypes()
3374 << ") that remained live after conversion";
3375 diag.attachNote(op->getUsers().begin()->getLoc())
3376 << "see existing live user here: " << *op->getUsers().begin();
3377 return failure();
3378}
3379
/// Legalize the operations in `ops` together with the operations nested in
/// them.
///
/// Every visited operation is queued for conversion. The children of an
/// operation that the target reports as recursively legal are not visited
/// (the callback returns WalkResult::skip for it). The queued operations are
/// then converted one by one via `convert`. On the first failure, `onFailure`
/// is invoked and failure is returned; the remaining operations are not
/// attempted.
///
/// NOTE(review): original lines 3382 (the rest of the signature) and 3389 (the
/// walk call that receives the lambda below) are missing from this listing.
/// The use of WalkResult::skip suggests a pre-order walk; confirm both lines
/// against the upstream source before editing.
3380template <typename Fn>
3381LogicalResult
3383 bool isRecursiveLegalization) {
3384 const ConversionTarget &target = opLegalizer.getTarget();
3385
3386 // Compute the list of operations to convert.
3387 SmallVector<Operation *> toConvert;
3388 for (Operation *op : ops) {
3390 [&](Operation *op) {
// Queue the op itself before deciding whether to descend into it.
3391 toConvert.push_back(op);
3392 // Don't check this operation's children for conversion if the
3393 // operation is recursively legal.
3394 auto legalityInfo = target.isLegal(op);
3395 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3396 return WalkResult::skip();
3397 return WalkResult::advance();
3398 });
3399 }
3400 for (Operation *op : toConvert) {
3401 if (failed(convert(op, isRecursiveLegalization))) {
3402 // Failed to convert an operation: let the caller react (e.g. roll back
// or commit the rewrites so far) and stop.
3403 onFailure();
3404 return failure();
3405 }
3406 }
3407 return success();
3408}
3409
3410LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
3411 return impl->opConverter.legalizeOperations(op,
3412 /*isRecursiveLegalization=*/true);
3413}
3414
3415LogicalResult ConversionPatternRewriter::legalize(Region *r) {
3416 // Fast path: If the region is empty, there is nothing to legalize.
3417 if (r->empty())
3418 return success();
3419
3420 // Gather a list of all operations to legalize. This is done before
3421 // converting the entry block signature because unrealized_conversion_cast
3422 // ops should not be included.
3424 for (Block &b : *r)
3425 for (Operation &op : b)
3426 ops.push_back(&op);
3427
3428 // If the current pattern runs with a type converter, convert the entry block
3429 // signature.
3430 if (const TypeConverter *converter = impl->currentTypeConverter) {
3431 std::optional<TypeConverter::SignatureConversion> conversion =
3432 converter->convertBlockSignature(&r->front());
3433 if (!conversion)
3434 return failure();
3435 applySignatureConversion(&r->front(), *conversion, converter);
3436 }
3437
3438 // Legalize all operations in the region. This includes all nested
3439 // operations.
3440 return impl->opConverter.legalizeOperations(ops,
3441 /*isRecursiveLegalization=*/true);
3442}
3443
// Run the whole conversion of `ops`: legalize them, commit the recorded
// rewrites, reconcile the unrealized_conversion_cast ops inserted by the
// framework, and (if `buildMaterializations` is set) replace the surviving
// casts with materializations built by the type converter.
// NOTE(review): original lines 3444 (the function head), 3467 (the declared
// type of `materializations`) and 3469 (presumably the declaration of
// `remainingCastOps`) are missing from this listing.
3445 // Convert each operation. On failure, either undo all rewrites (when
// pattern rollback is allowed) or apply the rewrites performed so far.
3446 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
3447 LogicalResult status = legalizeOperations(ops, /*onFailure=*/[&]() {
3448 // Dialect conversion failed.
3449 if (rewriterImpl.config.allowPatternRollback) {
3450 // Rollback is allowed: restore the original IR.
3451 rewriterImpl.undoRewrites();
3452 } else {
3453 // Rollback is not allowed: apply all modifications that have been
3454 // performed so far.
3455 rewriterImpl.applyRewrites();
3456 }
3457 });
3458 if (failed(status))
3459 return failure();
3460
3461 // After a successful conversion, apply rewrites.
3462 rewriterImpl.applyRewrites();
3463
3464 // Reconcile all UnrealizedConversionCastOps that were inserted by the
3465 // dialect conversion framework. (Not the ones that were inserted by
3466 // patterns.) Casts that cannot be folded away end up in remainingCastOps.
3468 &materializations = rewriterImpl.unresolvedMaterializations;
3470 reconcileUnrealizedCasts(materializations, &remainingCastOps);
3471
3472 // Drop the pure-type-conversion markers from the casts that remain.
3473 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3474 castOp->removeAttr(kPureTypeConversionMarker);
3475
3476 // Try to legalize all unresolved materializations. The rewrites were
// already applied above, so a failure here does not restore the original IR.
3477 if (rewriter.getConfig().buildMaterializations) {
3478 // Use a new rewriter, so the modifications are not tracked for rollback
3479 // purposes etc.
3480 IRRewriter irRewriter(rewriterImpl.rewriter.getContext(),
3481 rewriter.getConfig().listener);
3482 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3483 auto it = materializations.find(castOp);
3484 assert(it != materializations.end() && "inconsistent state");
3485 if (failed(legalizeUnresolvedMaterialization(irRewriter, castOp,
3486 it->second)))
3487 return failure();
3488 }
3489 }
3490
3491 return success();
3492}
3493
3494//===----------------------------------------------------------------------===//
3495// Type Conversion
3496//===----------------------------------------------------------------------===//
3497
3498void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
3499 ArrayRef<Type> types) {
3500 assert(!types.empty() && "expected valid types");
3501 remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
3502 addInputs(types);
3503}
3504
3505void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
3506 assert(!types.empty() &&
3507 "1->0 type remappings don't need to be added explicitly");
3508 argTypes.append(types.begin(), types.end());
3509}
3510
3511void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
3512 unsigned newInputNo,
3513 unsigned newInputCount) {
3514 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3515 assert(newInputCount != 0 && "expected valid input count");
3516 remappedInputs[origInputNo] =
3517 InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
3518}
3519
3520void TypeConverter::SignatureConversion::remapInput(
3521 unsigned origInputNo, ArrayRef<Value> replacements) {
3522 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3523 remappedInputs[origInputNo] = InputMapping{
3524 origInputNo, /*size=*/0,
3525 SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3526}
3527
3528/// Internal implementation of the type conversion.
3529/// This is used with either a Type or a Value as the first argument.
3530/// - we can cache the context-free conversions until the last registered
3531/// context-aware conversion.
3532/// - we can't cache the result of type conversion happening after context-aware
3533/// conversions, because the type converter may return different results for the
3534/// same input type.
3535LogicalResult
3536TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3537 SmallVectorImpl<Type> &results) const {
3538 assert(typeOrValue && "expected non-null type");
3539 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3540 : cast<Type>(typeOrValue);
3541 {
3542 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3543 std::defer_lock);
3545 cacheReadLock.lock();
3546 auto existingIt = cachedDirectConversions.find(t);
3547 if (existingIt != cachedDirectConversions.end()) {
3548 if (existingIt->second)
3549 results.push_back(existingIt->second);
3550 return success(existingIt->second != nullptr);
3551 }
3552 auto multiIt = cachedMultiConversions.find(t);
3553 if (multiIt != cachedMultiConversions.end()) {
3554 results.append(multiIt->second.begin(), multiIt->second.end());
3555 return success();
3556 }
3557 }
3558 // Walk the added converters in reverse order to apply the most recently
3559 // registered first.
3560 size_t currentCount = results.size();
3561
3562 // We can cache the context-free conversions until the last registered
3563 // context-aware conversion. But only if we're processing a Value right now.
3564 auto isCacheable = [&](int index) {
3565 int numberOfConversionsUntilContextAware =
3566 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3567 return index < numberOfConversionsUntilContextAware;
3568 };
3569
3570 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3571 std::defer_lock);
3572
3573 for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3574 const ConversionCallbackFn &converter = indexedConverter.value();
3575 std::optional<LogicalResult> result = converter(typeOrValue, results);
3576 if (!result) {
3577 assert(results.size() == currentCount &&
3578 "failed type conversion should not change results");
3579 continue;
3580 }
3581 if (!isCacheable(indexedConverter.index()))
3582 return success();
3584 cacheWriteLock.lock();
3585 if (!succeeded(*result)) {
3586 assert(results.size() == currentCount &&
3587 "failed type conversion should not change results");
3588 cachedDirectConversions.try_emplace(t, nullptr);
3589 return failure();
3590 }
3591 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3592 if (newTypes.size() == 1)
3593 cachedDirectConversions.try_emplace(t, newTypes.front());
3594 else
3595 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3596 return success();
3597 }
3598 return failure();
3599}
3600
3601LogicalResult TypeConverter::convertType(Type t,
3602 SmallVectorImpl<Type> &results) const {
3603 return convertTypeImpl(t, results);
3604}
3605
3606LogicalResult TypeConverter::convertType(Value v,
3607 SmallVectorImpl<Type> &results) const {
3608 return convertTypeImpl(v, results);
3609}
3610
3611Type TypeConverter::convertType(Type t) const {
3612 // Use the multi-type result version to convert the type.
3613 SmallVector<Type, 1> results;
3614 if (failed(convertType(t, results)))
3615 return nullptr;
3616
3617 // Check to ensure that only one type was produced.
3618 return results.size() == 1 ? results.front() : nullptr;
3619}
3620
3621Type TypeConverter::convertType(Value v) const {
3622 // Use the multi-type result version to convert the type.
3623 SmallVector<Type, 1> results;
3624 if (failed(convertType(v, results)))
3625 return nullptr;
3626
3627 // Check to ensure that only one type was produced.
3628 return results.size() == 1 ? results.front() : nullptr;
3629}
3630
3631LogicalResult
3632TypeConverter::convertTypes(TypeRange types,
3633 SmallVectorImpl<Type> &results) const {
3634 for (Type type : types)
3635 if (failed(convertType(type, results)))
3636 return failure();
3637 return success();
3638}
3639
3640LogicalResult
3641TypeConverter::convertTypes(ValueRange values,
3642 SmallVectorImpl<Type> &results) const {
3643 for (Value value : values)
3644 if (failed(convertType(value, results)))
3645 return failure();
3646 return success();
3647}
3648
3649bool TypeConverter::isLegal(Type type) const {
3650 return convertType(type) == type;
3651}
3652
3653bool TypeConverter::isLegal(Value value) const {
3654 return convertType(value) == value.getType();
3655}
3656
3657bool TypeConverter::isLegal(Operation *op) const {
3658 return isLegal(op->getOperands()) && isLegal(op->getResults());
3659}
3660
3661bool TypeConverter::isLegal(Region *region) const {
3662 return llvm::all_of(
3663 *region, [this](Block &block) { return isLegal(block.getArguments()); });
3664}
3665
3666bool TypeConverter::isSignatureLegal(FunctionType ty) const {
3667 if (!isLegal(ty.getInputs()))
3668 return false;
3669 if (!isLegal(ty.getResults()))
3670 return false;
3671 return true;
3672}
3673
3674LogicalResult
3675TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
3676 SignatureConversion &result) const {
3677 // Try to convert the given input type.
3678 SmallVector<Type, 1> convertedTypes;
3679 if (failed(convertType(type, convertedTypes)))
3680 return failure();
3681
3682 // If this argument is being dropped, there is nothing left to do.
3683 if (convertedTypes.empty())
3684 return success();
3685
3686 // Otherwise, add the new inputs.
3687 result.addInputs(inputNo, convertedTypes);
3688 return success();
3689}
3690LogicalResult
3691TypeConverter::convertSignatureArgs(TypeRange types,
3692 SignatureConversion &result,
3693 unsigned origInputOffset) const {
3694 for (unsigned i = 0, e = types.size(); i != e; ++i)
3695 if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3696 return failure();
3697 return success();
3698}
3699LogicalResult
3700TypeConverter::convertSignatureArg(unsigned inputNo, Value value,
3701 SignatureConversion &result) const {
3702 // Try to convert the given input type.
3703 SmallVector<Type, 1> convertedTypes;
3704 if (failed(convertType(value, convertedTypes)))
3705 return failure();
3706
3707 // If this argument is being dropped, there is nothing left to do.
3708 if (convertedTypes.empty())
3709 return success();
3710
3711 // Otherwise, add the new inputs.
3712 result.addInputs(inputNo, convertedTypes);
3713 return success();
3714}
3715LogicalResult
3716TypeConverter::convertSignatureArgs(ValueRange values,
3717 SignatureConversion &result,
3718 unsigned origInputOffset) const {
3719 for (unsigned i = 0, e = values.size(); i != e; ++i)
3720 if (failed(convertSignatureArg(origInputOffset + i, values[i], result)))
3721 return failure();
3722 return success();
3723}
3724
3725Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
3726 Location loc, Type resultType,
3727 ValueRange inputs) const {
3728 for (const SourceMaterializationCallbackFn &fn :
3729 llvm::reverse(sourceMaterializations))
3730 if (Value result = fn(builder, resultType, inputs, loc))
3731 return result;
3732 return nullptr;
3733}
3734
3735Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
3736 Location loc, Type resultType,
3737 ValueRange inputs,
3738 Type originalType) const {
3739 SmallVector<Value> result = materializeTargetConversion(
3740 builder, loc, TypeRange(resultType), inputs, originalType);
3741 if (result.empty())
3742 return nullptr;
3743 assert(result.size() == 1 && "expected single result");
3744 return result.front();
3745}
3746
3747SmallVector<Value> TypeConverter::materializeTargetConversion(
3748 OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
3749 Type originalType) const {
3750 for (const TargetMaterializationCallbackFn &fn :
3751 llvm::reverse(targetMaterializations)) {
3752 SmallVector<Value> result =
3753 fn(builder, resultTypes, inputs, loc, originalType);
3754 if (result.empty())
3755 continue;
3756 assert(TypeRange(ValueRange(result)) == resultTypes &&
3757 "callback produced incorrect number of values or values with "
3758 "incorrect types");
3759 return result;
3760 }
3761 return {};
3762}
3763
3764std::optional<TypeConverter::SignatureConversion>
3765TypeConverter::convertBlockSignature(Block *block) const {
3766 SignatureConversion conversion(block->getNumArguments());
3767 if (failed(convertSignatureArgs(block->getArguments(), conversion)))
3768 return std::nullopt;
3769 return conversion;
3770}
3771
3772//===----------------------------------------------------------------------===//
3773// Type attribute conversion
3774//===----------------------------------------------------------------------===//
3775TypeConverter::AttributeConversionResult
3776TypeConverter::AttributeConversionResult::result(Attribute attr) {
3777 return AttributeConversionResult(attr, resultTag);
3778}
3779
3780TypeConverter::AttributeConversionResult
3781TypeConverter::AttributeConversionResult::na() {
3782 return AttributeConversionResult(nullptr, naTag);
3783}
3784
3785TypeConverter::AttributeConversionResult
3786TypeConverter::AttributeConversionResult::abort() {
3787 return AttributeConversionResult(nullptr, abortTag);
3788}
3789
3790bool TypeConverter::AttributeConversionResult::hasResult() const {
3791 return impl.getInt() == resultTag;
3792}
3793
3794bool TypeConverter::AttributeConversionResult::isNa() const {
3795 return impl.getInt() == naTag;
3796}
3797
3798bool TypeConverter::AttributeConversionResult::isAbort() const {
3799 return impl.getInt() == abortTag;
3800}
3801
3802Attribute TypeConverter::AttributeConversionResult::getResult() const {
3803 assert(hasResult() && "Cannot get result from N/A or abort");
3804 return impl.getPointer();
3805}
3806
3807std::optional<Attribute>
3808TypeConverter::convertTypeAttribute(Type type, Attribute attr) const {
3809 for (const TypeAttributeConversionCallbackFn &fn :
3810 llvm::reverse(typeAttributeConversions)) {
3811 AttributeConversionResult res = fn(type, attr);
3812 if (res.hasResult())
3813 return res.getResult();
3814 if (res.isAbort())
3815 return std::nullopt;
3816 }
3817 return std::nullopt;
3818}
3819
3820//===----------------------------------------------------------------------===//
3821// FunctionOpInterfaceSignatureConversion
3822//===----------------------------------------------------------------------===//
3823
3824static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3825 const TypeConverter &typeConverter,
3826 ConversionPatternRewriter &rewriter) {
3827 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3828 if (!type)
3829 return failure();
3830
3831 // Convert the original function types.
3832 TypeConverter::SignatureConversion result(type.getNumInputs());
3833 SmallVector<Type, 1> newResults;
3834 if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3835 failed(typeConverter.convertTypes(type.getResults(), newResults)))
3836 return failure();
3837 if (!funcOp.getFunctionBody().empty())
3838 rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result,
3839 &typeConverter);
3840
3841 // Update the function signature in-place.
3842 auto newType = FunctionType::get(rewriter.getContext(),
3843 result.getConvertedTypes(), newResults);
3844
3845 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3846
3847 return success();
3848}
3849
3850/// Create a default conversion pattern that rewrites the type signature of a
3851/// FunctionOpInterface op. This only supports ops which use FunctionType to
3852/// represent their type.
3853namespace {
3854struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3855 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3856 MLIRContext *ctx,
3857 const TypeConverter &converter,
3858 PatternBenefit benefit)
3859 : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3860
3861 LogicalResult
3862 matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3863 ConversionPatternRewriter &rewriter) const override {
3864 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3865 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3866 }
3867};
3868
3869struct AnyFunctionOpInterfaceSignatureConversion
3870 : public OpInterfaceConversionPattern<FunctionOpInterface> {
3871 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3872
3873 LogicalResult
3874 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3875 ConversionPatternRewriter &rewriter) const override {
3876 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3877 }
3878};
3879} // namespace
3880
3881FailureOr<Operation *>
3882mlir::convertOpResultTypes(Operation *op, ValueRange operands,
3883 const TypeConverter &converter,
3884 ConversionPatternRewriter &rewriter) {
3885 assert(op && "Invalid op");
3886 Location loc = op->getLoc();
3887 if (converter.isLegal(op))
3888 return rewriter.notifyMatchFailure(loc, "op already legal");
3889
3890 OperationState newOp(loc, op->getName());
3891 newOp.addOperands(operands);
3892
3893 SmallVector<Type> newResultTypes;
3894 if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
3895 return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3896
3897 newOp.addTypes(newResultTypes);
3898 newOp.addAttributes(op->getAttrs());
3899 return rewriter.create(newOp);
3900}
3901
3902void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3903 StringRef functionLikeOpName, RewritePatternSet &patterns,
3904 const TypeConverter &converter, PatternBenefit benefit) {
3905 patterns.add<FunctionOpInterfaceSignatureConversion>(
3906 functionLikeOpName, patterns.getContext(), converter, benefit);
3907}
3908
3909void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3910 RewritePatternSet &patterns, const TypeConverter &converter,
3911 PatternBenefit benefit) {
3912 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3913 converter, patterns.getContext(), benefit);
3914}
3915
3916//===----------------------------------------------------------------------===//
3917// ConversionTarget
3918//===----------------------------------------------------------------------===//
3919
3920void ConversionTarget::setOpAction(OperationName op,
3921 LegalizationAction action) {
3922 legalOperations[op].action = action;
3923}
3924
3925void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3926 LegalizationAction action) {
3927 for (StringRef dialect : dialectNames)
3928 legalDialects[dialect] = action;
3929}
3930
3931auto ConversionTarget::getOpAction(OperationName op) const
3932 -> std::optional<LegalizationAction> {
3933 std::optional<LegalizationInfo> info = getOpInfo(op);
3934 return info ? info->action : std::optional<LegalizationAction>();
3935}
3936
3937auto ConversionTarget::isLegal(Operation *op) const
3938 -> std::optional<LegalOpDetails> {
3939 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3940 if (!info)
3941 return std::nullopt;
3942
3943 // Returns true if this operation instance is known to be legal.
3944 auto isOpLegal = [&] {
3945 // Handle dynamic legality either with the provided legality function.
3946 if (info->action == LegalizationAction::Dynamic) {
3947 std::optional<bool> result = info->legalityFn(op);
3948 if (result)
3949 return *result;
3950 }
3951
3952 // Otherwise, the operation is only legal if it was marked 'Legal'.
3953 return info->action == LegalizationAction::Legal;
3954 };
3955 if (!isOpLegal())
3956 return std::nullopt;
3957
3958 // This operation is legal, compute any additional legality information.
3959 LegalOpDetails legalityDetails;
3960 if (info->isRecursivelyLegal) {
3961 auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3962 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3963 legalityDetails.isRecursivelyLegal =
3964 legalityFnIt->second(op).value_or(true);
3965 } else {
3966 legalityDetails.isRecursivelyLegal = true;
3967 }
3968 }
3969 return legalityDetails;
3970}
3971
3972bool ConversionTarget::isIllegal(Operation *op) const {
3973 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3974 if (!info)
3975 return false;
3976
3977 if (info->action == LegalizationAction::Dynamic) {
3978 std::optional<bool> result = info->legalityFn(op);
3979 if (!result)
3980 return false;
3981
3982 return !(*result);
3983 }
3984
3985 return info->action == LegalizationAction::Illegal;
3986}
3987
3988static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
3989 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
3990 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
3991 if (!oldCallback)
3992 return newCallback;
3993
3994 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3995 Operation *op) -> std::optional<bool> {
3996 if (std::optional<bool> result = newCl(op))
3997 return *result;
3998
3999 return oldCl(op);
4000 };
4001 return chain;
4002}
4003
4004void ConversionTarget::setLegalityCallback(
4005 OperationName name, const DynamicLegalityCallbackFn &callback) {
4006 assert(callback && "expected valid legality callback");
4007 auto *infoIt = legalOperations.find(name);
4008 assert(infoIt != legalOperations.end() &&
4009 infoIt->second.action == LegalizationAction::Dynamic &&
4010 "expected operation to already be marked as dynamically legal");
4011 infoIt->second.legalityFn =
4012 composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
4013}
4014
4015void ConversionTarget::markOpRecursivelyLegal(
4016 OperationName name, const DynamicLegalityCallbackFn &callback) {
4017 auto *infoIt = legalOperations.find(name);
4018 assert(infoIt != legalOperations.end() &&
4019 infoIt->second.action != LegalizationAction::Illegal &&
4020 "expected operation to already be marked as legal");
4021 infoIt->second.isRecursivelyLegal = true;
4022 if (callback)
4023 opRecursiveLegalityFns[name] = composeLegalityCallbacks(
4024 std::move(opRecursiveLegalityFns[name]), callback);
4025 else
4026 opRecursiveLegalityFns.erase(name);
4027}
4028
4029void ConversionTarget::setLegalityCallback(
4030 ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
4031 assert(callback && "expected valid legality callback");
4032 for (StringRef dialect : dialects)
4033 dialectLegalityFns[dialect] = composeLegalityCallbacks(
4034 std::move(dialectLegalityFns[dialect]), callback);
4035}
4036
4037void ConversionTarget::setLegalityCallback(
4038 const DynamicLegalityCallbackFn &callback) {
4039 assert(callback && "expected valid legality callback");
4040 unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
4041}
4042
4043auto ConversionTarget::getOpInfo(OperationName op) const
4044 -> std::optional<LegalizationInfo> {
4045 // Check for info for this specific operation.
4046 const auto *it = legalOperations.find(op);
4047 if (it != legalOperations.end())
4048 return it->second;
4049 // Check for info for the parent dialect.
4050 auto dialectIt = legalDialects.find(op.getDialectNamespace());
4051 if (dialectIt != legalDialects.end()) {
4052 DynamicLegalityCallbackFn callback;
4053 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4054 if (dialectFn != dialectLegalityFns.end())
4055 callback = dialectFn->second;
4056 return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
4057 callback};
4058 }
4059 // Otherwise, check if we mark unknown operations as dynamic.
4060 if (unknownLegalityFn)
4061 return LegalizationInfo{LegalizationAction::Dynamic,
4062 /*isRecursivelyLegal=*/false, unknownLegalityFn};
4063 return std::nullopt;
4064}
4065
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
//===----------------------------------------------------------------------===//
// PDL Configuration
//===----------------------------------------------------------------------===//

/// Make this config's type converter the active one for the PDL rewrite that
/// is about to run.
void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
  ConversionPatternRewriterImpl &convImpl =
      static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
  convImpl.currentTypeConverter = getTypeConverter();
}

/// Clear the active type converter once the PDL rewrite has finished.
void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
  ConversionPatternRewriterImpl &convImpl =
      static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
  convImpl.currentTypeConverter = nullptr;
}

/// Remap `values` through the conversion rewriter; fails if any value cannot
/// be remapped.
static FailureOr<SmallVector<Value>>
pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) {
  SmallVector<Value> mappedValues;
  if (succeeded(rewriter.getRemappedValues(values, mappedValues)))
    return std::move(mappedValues);
  return failure();
}

/// Return the type converter currently active on `rewriter` (which must be a
/// ConversionPatternRewriter), or null if there is none.
static const TypeConverter *
pdllGetCurrentTypeConverter(PatternRewriter &rewriter) {
  return static_cast<ConversionPatternRewriter &>(rewriter)
      .getImpl()
      .currentTypeConverter;
}

void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
  auto &pdlPatterns = patterns.getPDLPatterns();

  // Remap a single value.
  pdlPatterns.registerRewriteFunction(
      "convertValue",
      [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
        FailureOr<SmallVector<Value>> converted = pdllConvertValues(
            static_cast<ConversionPatternRewriter &>(rewriter), value);
        if (failed(converted))
          return failure();
        return converted->front();
      });

  // Remap a range of values.
  pdlPatterns.registerRewriteFunction(
      "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
        return pdllConvertValues(
            static_cast<ConversionPatternRewriter &>(rewriter), values);
      });

  // Convert one type with the active type converter (identity without one).
  pdlPatterns.registerRewriteFunction(
      "convertType",
      [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
        const TypeConverter *converter = pdllGetCurrentTypeConverter(rewriter);
        if (!converter)
          return type;
        if (Type newType = converter->convertType(type))
          return newType;
        return failure();
      });

  // Convert a range of types with the active type converter (identity
  // without one).
  pdlPatterns.registerRewriteFunction(
      "convertTypes",
      [](PatternRewriter &rewriter,
         TypeRange types) -> FailureOr<SmallVector<Type>> {
        const TypeConverter *converter = pdllGetCurrentTypeConverter(rewriter);
        if (!converter)
          return SmallVector<Type>(types);

        SmallVector<Type> remappedTypes;
        if (failed(converter->convertTypes(types, remappedTypes)))
          return failure();
        return std::move(remappedTypes);
      });
}
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
4138
4139//===----------------------------------------------------------------------===//
4140// Op Conversion Entry Points
4141//===----------------------------------------------------------------------===//
4142
/// NOTE(review): original lines 4144 (the class head) and 4147-4148 are
/// missing from this listing; confirm them against the upstream source.
4143/// The tracing Action that is dispatched when a dialect conversion is applied.
4145 : public tracing::ActionImpl<ApplyConversionAction> {
4146public:
/// Identifier of this action kind.
4149 static constexpr StringLiteral tag = "apply-conversion";
/// Human-readable description of this action kind.
4150 static constexpr StringLiteral desc =
4151 "Encapsulate the application of a dialect conversion";
4152
/// Print the action as its tag.
4153 void print(raw_ostream &os) const override { os << tag; }
4154};
4155
4157 const ConversionTarget &target,
4159 ConversionConfig config,
4160 OpConversionMode mode) {
// Shared driver behind the partial, full and analysis conversion entry points:
// runs an OperationConverter over `ops` in the given `mode`.
// NOTE(review): original lines 4156 (the function head), 4158 (presumably the
// `patterns` parameter) and 4166 (presumably the call that dispatches the
// lambda below as an ApplyConversionAction on `ctx`) are missing from this
// listing.
// An empty input is trivially converted.
4161 if (ops.empty())
4162 return success();
4163 MLIRContext *ctx = ops.front()->getContext();
4164 LogicalResult status = success();
// IR units attached to the action: the root operations being converted.
4165 SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
4167 [&] {
4168 OperationConverter opConverter(ops.front()->getContext(), target,
4169 patterns, config, mode);
4170 status = opConverter.applyConversion(ops);
4171 },
4172 irUnits);
4173 return status;
4174}
4175
4176//===----------------------------------------------------------------------===//
4177// Partial Conversion
4178//===----------------------------------------------------------------------===//
4179
4180LogicalResult mlir::applyPartialConversion(
4181 ArrayRef<Operation *> ops, const ConversionTarget &target,
4182 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4183 return applyConversion(ops, target, patterns, config,
4184 OpConversionMode::Partial);
4185}
4186LogicalResult
4187mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
4188 const FrozenRewritePatternSet &patterns,
4189 ConversionConfig config) {
4190 return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
4191}
4192
4193//===----------------------------------------------------------------------===//
4194// Full Conversion
4195//===----------------------------------------------------------------------===//
4196
4197LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
4198 const ConversionTarget &target,
4199 const FrozenRewritePatternSet &patterns,
4200 ConversionConfig config) {
4201 return applyConversion(ops, target, patterns, config, OpConversionMode::Full);
4202}
4203LogicalResult mlir::applyFullConversion(Operation *op,
4204 const ConversionTarget &target,
4205 const FrozenRewritePatternSet &patterns,
4206 ConversionConfig config) {
4207 return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
4208}
4209
4210//===----------------------------------------------------------------------===//
4211// Analysis Conversion
4212//===----------------------------------------------------------------------===//
4213
4214/// Find a common IsolatedFromAbove ancestor of the given ops. If one of the
4215/// ops is a top-level op (it has no parent op; it is expected to be isolated
4216/// from above and to contain all other ops), return that op.
///
/// Precondition: `ops` must not be empty (`ops.front()` is dereferenced).
/// The containment and existence checks below are asserts only; in release
/// builds, ops without a common isolated-from-above ancestor lead to a null
/// dereference.
/// NOTE(review): original lines 4217 (the function head) and 4239 (the
/// right-hand side of the `commonAncestor =` assignment) are missing from
/// this listing.
4218 // Check if there is a top-level operation within `ops`. If so, return that
4219 // op.
4220 for (Operation *op : ops) {
4221 if (!op->getParentOp()) {
4222#ifndef NDEBUG
4223 assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
4224 "expected top-level op to be isolated from above");
4225 for (Operation *other : ops)
4226 assert(op->isAncestor(other) &&
4227 "expected ops to have a common ancestor");
4228#endif // NDEBUG
4229 return op;
4230 }
4231 }
4232
4233 // No top-level op. Start from the closest isolated-from-above op enclosing
// the first op, and move it up (presumably to the next isolated-from-above
// ancestor) until it properly contains each remaining op.
4234 Operation *commonAncestor =
4235 ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
4236 for (Operation *op : ops.drop_front()) {
4237 while (!commonAncestor->isProperAncestor(op)) {
4238 commonAncestor =
4240 assert(commonAncestor &&
4241 "expected to find a common isolated from above ancestor");
4242 }
4243 }
4244
4245 return commonAncestor;
4246}
4247
4248LogicalResult mlir::applyAnalysisConversion(
4249 ArrayRef<Operation *> ops, ConversionTarget &target,
4250 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4251#ifndef NDEBUG
4252 if (config.legalizableOps)
4253 assert(config.legalizableOps->empty() && "expected empty set");
4254#endif // NDEBUG
4255
4256 // Clone closted common ancestor that is isolated from above.
4257 Operation *commonAncestor = findCommonAncestor(ops);
4258 IRMapping mapping;
4259 Operation *clonedAncestor = commonAncestor->clone(mapping);
4260 // Compute inverse IR mapping.
4261 DenseMap<Operation *, Operation *> inverseOperationMap;
4262 for (auto &it : mapping.getOperationMap())
4263 inverseOperationMap[it.second] = it.first;
4264
4265 // Convert the cloned operations. The original IR will remain unchanged.
4266 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4267 ops, [&](Operation *op) { return mapping.lookup(op); });
4268 LogicalResult status = applyConversion(opsToConvert, target, patterns, config,
4269 OpConversionMode::Analysis);
4270
4271 // Remap `legalizableOps`, so that they point to the original ops and not the
4272 // cloned ops.
4273 if (config.legalizableOps) {
4274 DenseSet<Operation *> originalLegalizableOps;
4275 for (Operation *op : *config.legalizableOps)
4276 originalLegalizableOps.insert(inverseOperationMap[op]);
4277 *config.legalizableOps = std::move(originalLegalizableOps);
4278 }
4279
4280 // Erase the cloned IR.
4281 clonedAncestor->erase();
4282 return status;
4283}
4284
4285LogicalResult
4286mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
4287 const FrozenRewritePatternSet &patterns,
4288 ConversionConfig config) {
4289 return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
4290}
return success()
static void setInsertionPointAfter(OpBuilder &b, Value value)
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl, function_ref< bool(OpOperand &)> functor=nullptr)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value > > &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
#define DEBUG_TYPE
b getContext())
static std::string diag(const llvm::Value &value)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition SCCP.cpp:67
This is the type of Action that is dispatched when a conversion is applied.
tracing::ActionImpl< ApplyConversionAction > Base
static constexpr StringLiteral desc
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
static constexpr StringLiteral tag
This class represents an argument of a Block.
Definition Value.h:309
Location getLoc() const
Return the location for this argument.
Definition Value.h:324
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:140
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:149
bool empty()
Definition Block.h:148
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:137
Operation & front()
Definition Block.h:153
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition Block.h:308
BlockArgListType getArguments()
Definition Block.h:87
iterator end()
Definition Block.h:144
iterator begin()
Definition Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
UnitAttr getUnitAttr()
Definition Builders.cpp:98
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
MLIRContext * context
Definition Builders.h:202
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
Definition Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition Builders.h:327
Block::iterator getPoint() const
Definition Builders.h:340
bool isSet() const
Returns true if this insert point is set.
Definition Builders.h:337
Block * getBlock() const
Definition Builders.h:339
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:320
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Definition Builders.cpp:473
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition Builders.cpp:421
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
This class provides the API for ops that are known to be isolated from above.
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
void destroyOpProperties(OpaqueProperties properties) const
This hooks destroy the op properties.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition Operation.h:226
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:852
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
void dropAllUses()
Drop all uses of results of this operation.
Definition Operation.h:834
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:560
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:248
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperandRange operand_range
Definition Operation.h:371
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:263
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
result_range getResults()
Definition Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition Operation.h:900
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
BlockListType::iterator iterator
Definition Region.h:52
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
CRTP Implementation of an action.
Definition Action.h:76
ArrayRef< IRUnit > irUnits
Set of IR units (operations, regions, blocks, values) that are associated with this action.
Definition Action.h:66
AttrTypeReplacer.
Kind
An enumeration of the kinds of predicates.
Definition Predicate.h:44
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
static void reconcileUnrealizedCasts(const DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
This iterator enumerates elements according to their dominance relationship.
Definition Iterators.h:48
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition Builders.h:308
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition Builders.h:298
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
const ConversionTarget & getTarget()
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, bool isRecursiveLegalization=false)
LogicalResult convert(Operation *op, bool isRecursiveLegalization=false)
Converts a single operation.
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, Fn onFailure, bool isRecursiveLegalization=false)
Legalizes the given operations (and their nested operations) to the conversion target.
LogicalResult applyConversion(ArrayRef< Operation * > ops)
Applies the conversion to the given operations (and their nested operations).
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
llvm::impl::raw_ldbg_ostream os
A raw output stream used to prefix the debug log.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
void replaceValueUses(Value from, ValueRange to, const TypeConverter *converter, function_ref< bool(OpOperand &)> functor=nullptr)
Replace the uses of the given value with the given values.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config, OperationConverter &opConverter)
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
void replaceOp(Operation *op, SmallVector< SmallVector< Value > > &&newValues)
Replace the results of the given operation with the given values and erase the operation.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
OperationConverter & opConverter
The operation converter to use for recursive legalization.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.