MLIR 22.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/Config/mlir-config.h"
11#include "mlir/IR/Block.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "mlir/IR/Dominance.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Iterators.h"
17#include "mlir/IR/Operation.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/FormatVariadic.h"
25#include "llvm/Support/SaveAndRestore.h"
26#include "llvm/Support/ScopedPrinter.h"
27#include <optional>
28#include <utility>
29
30using namespace mlir;
31using namespace mlir::detail;
32
33#define DEBUG_TYPE "dialect-conversion"
34
35/// A utility function to log a successful result for the given reason.
36template <typename... Args>
37static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
38 LLVM_DEBUG({
39 os.unindent();
40 os.startLine() << "} -> SUCCESS";
41 if (!fmt.empty())
42 os.getOStream() << " : "
43 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
44 os.getOStream() << "\n";
45 });
46}
47
48/// A utility function to log a failure result for the given reason.
49template <typename... Args>
50static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
51 LLVM_DEBUG({
52 os.unindent();
53 os.startLine() << "} -> FAILURE : "
54 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
55 << "\n";
56 });
57}
58
59/// Helper function that computes an insertion point where the given value is
60/// defined and can be used without a dominance violation.
62 Block *insertBlock = value.getParentBlock();
63 Block::iterator insertPt = insertBlock->begin();
64 if (OpResult inputRes = dyn_cast<OpResult>(value))
65 insertPt = ++inputRes.getOwner()->getIterator();
66 return OpBuilder::InsertPoint(insertBlock, insertPt);
67}
68
69/// Helper function that computes an insertion point where the given values are
70/// defined and can be used without a dominance violation.
72 assert(!vals.empty() && "expected at least one value");
73 DominanceInfo domInfo;
75 for (Value v : vals.drop_front()) {
76 // Choose the "later" insertion point.
78 if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
79 nextPt.getPoint())) {
80 // pt is before nextPt => choose nextPt.
81 pt = nextPt;
82 } else {
83#ifndef NDEBUG
84 // nextPt should be before pt => choose pt.
85 // If pt, nextPt are no dominance relationship, then there is no valid
86 // insertion point at which all given values are defined.
87 bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
88 pt.getBlock(), pt.getPoint());
89 assert(dom && "unable to find valid insertion point");
90#endif // NDEBUG
91 }
92 }
93 return pt;
94}
95
namespace {
/// Selects how the conversion treats operations that could not be legalized.
enum OpConversionMode {
  /// Failed conversions are tolerated: illegal operations may remain in the
  /// IR next to converted ones.
  Partial = 0,

  /// Every operation must be legal for the given target; otherwise the
  /// conversion fails.
  Full = 1,

  /// Operations are only analyzed for legality; on success, no actual
  /// rewrites are applied to them.
  Analysis = 2,
};
} // namespace
111
112//===----------------------------------------------------------------------===//
113// ConversionValueMapping
114//===----------------------------------------------------------------------===//
115
116/// A vector of SSA values, optimized for the most common case of a single
117/// value.
119
120namespace {
121
122/// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
123struct ValueVectorMapInfo {
124 static ValueVector getEmptyKey() { return ValueVector{Value()}; }
125 static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
126 static ::llvm::hash_code getHashValue(const ValueVector &val) {
127 return ::llvm::hash_combine_range(val);
128 }
129 static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
130 return LHS == RHS;
131 }
132};
133
134/// This class wraps a IRMapping to provide recursive lookup
135/// functionality, i.e. we will traverse if the mapped value also has a mapping.
136struct ConversionValueMapping {
137 /// Return "true" if an SSA value is mapped to the given value. May return
138 /// false positives.
139 bool isMappedTo(Value value) const { return mappedTo.contains(value); }
140
141 /// Lookup a value in the mapping.
142 ValueVector lookup(const ValueVector &from) const;
143
144 template <typename T>
145 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
146
147 /// Map a value vector to the one provided.
148 template <typename OldVal, typename NewVal>
149 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
150 map(OldVal &&oldVal, NewVal &&newVal) {
151 LLVM_DEBUG({
152 ValueVector next(newVal);
153 while (true) {
154 assert(next != oldVal && "inserting cyclic mapping");
155 auto it = mapping.find(next);
156 if (it == mapping.end())
157 break;
158 next = it->second;
159 }
160 });
161 mappedTo.insert_range(newVal);
162
163 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
164 }
165
166 /// Map a value vector or single value to the one provided.
167 template <typename OldVal, typename NewVal>
168 std::enable_if_t<!IsValueVector<OldVal>::value ||
169 !IsValueVector<NewVal>::value>
170 map(OldVal &&oldVal, NewVal &&newVal) {
171 if constexpr (IsValueVector<OldVal>{}) {
172 map(std::forward<OldVal>(oldVal), ValueVector{newVal});
173 } else if constexpr (IsValueVector<NewVal>{}) {
174 map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
175 } else {
176 map(ValueVector{oldVal}, ValueVector{newVal});
177 }
178 }
179
180 void map(Value oldVal, SmallVector<Value> &&newVal) {
181 map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
182 }
183
184 /// Drop the last mapping for the given values.
185 void erase(const ValueVector &value) { mapping.erase(value); }
186
187private:
188 /// Current value mappings.
190
191 /// All SSA values that are mapped to. May contain false positives.
192 DenseSet<Value> mappedTo;
193};
194} // namespace
195
196/// Marker attribute for pure type conversions. I.e., mappings whose only
197/// purpose is to resolve a type mismatch. (In contrast, mappings that point to
198/// the replacement values of a "replaceOp" call, etc., are not pure type
199/// conversions.)
200static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
201
202/// Return the operation that defines all values in the vector. Return nullptr
203/// if the values are not defined by the same operation.
205 assert(!values.empty() && "expected non-empty value vector");
206 Operation *op = values.front().getDefiningOp();
207 for (Value v : llvm::drop_begin(values)) {
208 if (v.getDefiningOp() != op)
209 return nullptr;
210 }
211 return op;
212}
213
214/// A vector of values is a pure type conversion if all values are defined by
215/// the same operation and the operation has the `kPureTypeConversionMarker`
216/// attribute.
217static bool isPureTypeConversion(const ValueVector &values) {
218 assert(!values.empty() && "expected non-empty value vector");
219 Operation *op = getCommonDefiningOp(values);
220 return op && op->hasAttr(kPureTypeConversionMarker);
221}
222
223ValueVector ConversionValueMapping::lookup(const ValueVector &from) const {
224 auto it = mapping.find(from);
225 if (it == mapping.end()) {
226 // No mapping found: The lookup stops here.
227 return {};
228 }
229 return it->second;
230}
231
232//===----------------------------------------------------------------------===//
233// Rewriter and Translation State
234//===----------------------------------------------------------------------===//
235namespace {
/// Snapshot of the conversion rewriter's bookkeeping counters. A snapshot can
/// be taken before a set of rewrites and later used to undo back to it.
struct RewriterState {
  RewriterState(unsigned rewrites, unsigned ignoredOps, unsigned replacedOps)
      : numRewrites(rewrites), numIgnoredOperations(ignoredOps),
        numReplacedOps(replacedOps) {}

  /// Number of rewrites performed at the time of the snapshot.
  unsigned numRewrites;

  /// Number of ignored operations at the time of the snapshot.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure at the time of the snapshot.
  unsigned numReplacedOps;
};
253
254//===----------------------------------------------------------------------===//
255// IR rewrites
256//===----------------------------------------------------------------------===//
257
258static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
259
260/// Notify the listener that the given block and its contents are being erased.
261static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
262 for (Operation &op : b)
263 notifyIRErased(listener, op);
264 listener->notifyBlockErased(&b);
265}
266
267/// Notify the listener that the given operation and its contents are being
268/// erased.
269static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
270 for (Region &r : op.getRegions()) {
271 for (Block &b : r) {
272 notifyIRErased(listener, b);
273 }
274 }
275 listener->notifyOperationErased(&op);
276}
277
278/// An IR rewrite that can be committed (upon success) or rolled back (upon
279/// failure).
280///
281/// The dialect conversion keeps track of IR modifications (requested by the
282/// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
283/// are directly applied to the IR as the rewriter API is used, some are applied
284/// partially, and some are delayed until the `IRRewrite` objects are committed.
285class IRRewrite {
286public:
287 /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
288 /// Enum values are ordered, so that they can be used in `classof`: first all
289 /// block rewrites, then all operation rewrites.
290 enum class Kind {
291 // Block rewrites
292 CreateBlock,
293 EraseBlock,
294 InlineBlock,
295 MoveBlock,
296 BlockTypeConversion,
297 // Operation rewrites
298 MoveOperation,
299 ModifyOperation,
300 ReplaceOperation,
301 CreateOperation,
302 UnresolvedMaterialization,
303 // Value rewrites
304 ReplaceValue
305 };
306
307 virtual ~IRRewrite() = default;
308
309 /// Roll back the rewrite. Operations may be erased during rollback.
310 virtual void rollback() = 0;
311
312 /// Commit the rewrite. At this point, it is certain that the dialect
313 /// conversion will succeed. All IR modifications, except for operation/block
314 /// erasure, must be performed through the given rewriter.
315 ///
316 /// Instead of erasing operations/blocks, they should merely be unlinked
317 /// commit phase and finally be erased during the cleanup phase. This is
318 /// because internal dialect conversion state (such as `mapping`) may still
319 /// be using them.
320 ///
321 /// Any IR modification that was already performed before the commit phase
322 /// (e.g., insertion of an op) must be communicated to the listener that may
323 /// be attached to the given rewriter.
324 virtual void commit(RewriterBase &rewriter) {}
325
326 /// Cleanup operations/blocks. Cleanup is called after commit.
327 virtual void cleanup(RewriterBase &rewriter) {}
328
329 Kind getKind() const { return kind; }
330
331 static bool classof(const IRRewrite *rewrite) { return true; }
332
333protected:
334 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
335 : kind(kind), rewriterImpl(rewriterImpl) {}
336
337 const ConversionConfig &getConfig() const;
338
339 const Kind kind;
340 ConversionPatternRewriterImpl &rewriterImpl;
341};
342
343/// A block rewrite.
344class BlockRewrite : public IRRewrite {
345public:
346 /// Return the block that this rewrite operates on.
347 Block *getBlock() const { return block; }
348
349 static bool classof(const IRRewrite *rewrite) {
350 return rewrite->getKind() >= Kind::CreateBlock &&
351 rewrite->getKind() <= Kind::BlockTypeConversion;
352 }
353
354protected:
355 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
356 Block *block)
357 : IRRewrite(kind, rewriterImpl), block(block) {}
358
359 // The block that this rewrite operates on.
360 Block *block;
361};
362
363/// A value rewrite.
364class ValueRewrite : public IRRewrite {
365public:
366 /// Return the value that this rewrite operates on.
367 Value getValue() const { return value; }
368
369 static bool classof(const IRRewrite *rewrite) {
370 return rewrite->getKind() == Kind::ReplaceValue;
371 }
372
373protected:
374 ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
375 Value value)
376 : IRRewrite(kind, rewriterImpl), value(value) {}
377
378 // The value that this rewrite operates on.
379 Value value;
380};
381
382/// Creation of a block. Block creations are immediately reflected in the IR.
383/// There is no extra work to commit the rewrite. During rollback, the newly
384/// created block is erased.
385class CreateBlockRewrite : public BlockRewrite {
386public:
387 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
388 : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
389
390 static bool classof(const IRRewrite *rewrite) {
391 return rewrite->getKind() == Kind::CreateBlock;
392 }
393
394 void commit(RewriterBase &rewriter) override {
395 // The block was already created and inserted. Just inform the listener.
396 if (auto *listener = rewriter.getListener())
397 listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
398 }
399
400 void rollback() override {
401 // Unlink all of the operations within this block, they will be deleted
402 // separately.
403 auto &blockOps = block->getOperations();
404 while (!blockOps.empty())
405 blockOps.remove(blockOps.begin());
406 block->dropAllUses();
407 if (block->getParent())
408 block->erase();
409 else
410 delete block;
411 }
412};
413
414/// Erasure of a block. Block erasures are partially reflected in the IR. Erased
415/// blocks are immediately unlinked, but only erased during cleanup. This makes
416/// it easier to rollback a block erasure: the block is simply inserted into its
417/// original location.
418class EraseBlockRewrite : public BlockRewrite {
419public:
420 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
421 : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
422 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
423
424 static bool classof(const IRRewrite *rewrite) {
425 return rewrite->getKind() == Kind::EraseBlock;
426 }
427
428 ~EraseBlockRewrite() override {
429 assert(!block &&
430 "rewrite was neither rolled back nor committed/cleaned up");
431 }
432
433 void rollback() override {
434 // The block (owned by this rewrite) was not actually erased yet. It was
435 // just unlinked. Put it back into its original position.
436 assert(block && "expected block");
437 auto &blockList = region->getBlocks();
438 Region::iterator before = insertBeforeBlock
439 ? Region::iterator(insertBeforeBlock)
440 : blockList.end();
441 blockList.insert(before, block);
442 block = nullptr;
443 }
444
445 void commit(RewriterBase &rewriter) override {
446 assert(block && "expected block");
447
448 // Notify the listener that the block and its contents are being erased.
449 if (auto *listener =
450 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
451 notifyIRErased(listener, *block);
452 }
453
454 void cleanup(RewriterBase &rewriter) override {
455 // Erase the contents of the block.
456 for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
457 rewriter.eraseOp(&op);
458 assert(block->empty() && "expected empty block");
459
460 // Erase the block.
461 block->dropAllDefinedValueUses();
462 delete block;
463 block = nullptr;
464 }
465
466private:
467 // The region in which this block was previously contained.
468 Region *region;
469
470 // The original successor of this block before it was unlinked. "nullptr" if
471 // this block was the only block in the region.
472 Block *insertBeforeBlock;
473};
474
475/// Inlining of a block. This rewrite is immediately reflected in the IR.
476/// Note: This rewrite represents only the inlining of the operations. The
477/// erasure of the inlined block is a separate rewrite.
478class InlineBlockRewrite : public BlockRewrite {
479public:
480 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
481 Block *sourceBlock, Block::iterator before)
482 : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
483 sourceBlock(sourceBlock),
484 firstInlinedInst(sourceBlock->empty() ? nullptr
485 : &sourceBlock->front()),
486 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
487 // If a listener is attached to the dialect conversion, ops must be moved
488 // one-by-one. When they are moved in bulk, notifications cannot be sent
489 // because the ops that used to be in the source block at the time of the
490 // inlining (before the "commit" phase) are unknown at the time when
491 // notifications are sent (which is during the "commit" phase).
492 assert(!getConfig().listener &&
493 "InlineBlockRewrite not supported if listener is attached");
494 }
495
496 static bool classof(const IRRewrite *rewrite) {
497 return rewrite->getKind() == Kind::InlineBlock;
498 }
499
500 void rollback() override {
501 // Put the operations from the destination block (owned by the rewrite)
502 // back into the source block.
503 if (firstInlinedInst) {
504 assert(lastInlinedInst && "expected operation");
505 sourceBlock->getOperations().splice(sourceBlock->begin(),
506 block->getOperations(),
507 Block::iterator(firstInlinedInst),
508 ++Block::iterator(lastInlinedInst));
509 }
510 }
511
512private:
513 // The block that originally contained the operations.
514 Block *sourceBlock;
515
516 // The first inlined operation.
517 Operation *firstInlinedInst;
518
519 // The last inlined operation.
520 Operation *lastInlinedInst;
521};
522
523/// Moving of a block. This rewrite is immediately reflected in the IR.
524class MoveBlockRewrite : public BlockRewrite {
525public:
526 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
527 Region *previousRegion, Region::iterator previousIt)
528 : BlockRewrite(Kind::MoveBlock, rewriterImpl, block),
529 region(previousRegion),
530 insertBeforeBlock(previousIt == previousRegion->end() ? nullptr
531 : &*previousIt) {}
532
533 static bool classof(const IRRewrite *rewrite) {
534 return rewrite->getKind() == Kind::MoveBlock;
535 }
536
537 void commit(RewriterBase &rewriter) override {
538 // The block was already moved. Just inform the listener.
539 if (auto *listener = rewriter.getListener()) {
540 // Note: `previousIt` cannot be passed because this is a delayed
541 // notification and iterators into past IR state cannot be represented.
542 listener->notifyBlockInserted(block, /*previous=*/region,
543 /*previousIt=*/{});
544 }
545 }
546
547 void rollback() override {
548 // Move the block back to its original position.
549 Region::iterator before =
550 insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
551 region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
552 }
553
554private:
555 // The region in which this block was previously contained.
556 Region *region;
557
558 // The original successor of this block before it was moved. "nullptr" if
559 // this block was the only block in the region.
560 Block *insertBeforeBlock;
561};
562
563/// Block type conversion. This rewrite is partially reflected in the IR.
564class BlockTypeConversionRewrite : public BlockRewrite {
565public:
566 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
567 Block *origBlock, Block *newBlock)
568 : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
569 newBlock(newBlock) {}
570
571 static bool classof(const IRRewrite *rewrite) {
572 return rewrite->getKind() == Kind::BlockTypeConversion;
573 }
574
575 Block *getOrigBlock() const { return block; }
576
577 Block *getNewBlock() const { return newBlock; }
578
579 void commit(RewriterBase &rewriter) override;
580
581 void rollback() override;
582
583private:
584 /// The new block that was created as part of this signature conversion.
585 Block *newBlock;
586};
587
588/// Replacing a value. This rewrite is not immediately reflected in the
589/// IR. An internal IR mapping is updated, but the actual replacement is delayed
590/// until the rewrite is committed.
591class ReplaceValueRewrite : public ValueRewrite {
592public:
593 ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
594 const TypeConverter *converter)
595 : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
596 converter(converter) {}
597
598 static bool classof(const IRRewrite *rewrite) {
599 return rewrite->getKind() == Kind::ReplaceValue;
600 }
601
602 void commit(RewriterBase &rewriter) override;
603
604 void rollback() override;
605
606private:
607 /// The current type converter when the value was replaced.
608 const TypeConverter *converter;
609};
610
611/// An operation rewrite.
612class OperationRewrite : public IRRewrite {
613public:
614 /// Return the operation that this rewrite operates on.
615 Operation *getOperation() const { return op; }
616
617 static bool classof(const IRRewrite *rewrite) {
618 return rewrite->getKind() >= Kind::MoveOperation &&
619 rewrite->getKind() <= Kind::UnresolvedMaterialization;
620 }
621
622protected:
623 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
624 Operation *op)
625 : IRRewrite(kind, rewriterImpl), op(op) {}
626
627 // The operation that this rewrite operates on.
628 Operation *op;
629};
630
631/// Moving of an operation. This rewrite is immediately reflected in the IR.
632class MoveOperationRewrite : public OperationRewrite {
633public:
634 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
635 Operation *op, OpBuilder::InsertPoint previous)
636 : OperationRewrite(Kind::MoveOperation, rewriterImpl, op),
637 block(previous.getBlock()),
638 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
639 ? nullptr
640 : &*previous.getPoint()) {}
641
642 static bool classof(const IRRewrite *rewrite) {
643 return rewrite->getKind() == Kind::MoveOperation;
644 }
645
646 void commit(RewriterBase &rewriter) override {
647 // The operation was already moved. Just inform the listener.
648 if (auto *listener = rewriter.getListener()) {
649 // Note: `previousIt` cannot be passed because this is a delayed
650 // notification and iterators into past IR state cannot be represented.
651 listener->notifyOperationInserted(
652 op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
653 /*insertPt=*/{}));
654 }
655 }
656
657 void rollback() override {
658 // Move the operation back to its original position.
659 Block::iterator before =
660 insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
661 block->getOperations().splice(before, op->getBlock()->getOperations(), op);
662 }
663
664private:
665 // The block in which this operation was previously contained.
666 Block *block;
667
668 // The original successor of this operation before it was moved. "nullptr"
669 // if this operation was the only operation in the region.
670 Operation *insertBeforeOp;
671};
672
673/// In-place modification of an op. This rewrite is immediately reflected in
674/// the IR. The previous state of the operation is stored in this object.
675class ModifyOperationRewrite : public OperationRewrite {
676public:
677 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
678 Operation *op)
679 : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
680 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
681 operands(op->operand_begin(), op->operand_end()),
682 successors(op->successor_begin(), op->successor_end()) {
683 if (OpaqueProperties prop = op->getPropertiesStorage()) {
684 // Make a copy of the properties.
685 propertiesStorage = operator new(op->getPropertiesStorageSize());
686 OpaqueProperties propCopy(propertiesStorage);
687 name.initOpProperties(propCopy, /*init=*/prop);
688 }
689 }
690
691 static bool classof(const IRRewrite *rewrite) {
692 return rewrite->getKind() == Kind::ModifyOperation;
693 }
694
695 ~ModifyOperationRewrite() override {
696 assert(!propertiesStorage &&
697 "rewrite was neither committed nor rolled back");
698 }
699
700 void commit(RewriterBase &rewriter) override {
701 // Notify the listener that the operation was modified in-place.
702 if (auto *listener =
703 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
704 listener->notifyOperationModified(op);
705
706 if (propertiesStorage) {
707 OpaqueProperties propCopy(propertiesStorage);
708 // Note: The operation may have been erased in the mean time, so
709 // OperationName must be stored in this object.
710 name.destroyOpProperties(propCopy);
711 operator delete(propertiesStorage);
712 propertiesStorage = nullptr;
713 }
714 }
715
716 void rollback() override {
717 op->setLoc(loc);
718 op->setAttrs(attrs);
719 op->setOperands(operands);
720 for (const auto &it : llvm::enumerate(successors))
721 op->setSuccessor(it.value(), it.index());
722 if (propertiesStorage) {
723 OpaqueProperties propCopy(propertiesStorage);
724 op->copyProperties(propCopy);
725 name.destroyOpProperties(propCopy);
726 operator delete(propertiesStorage);
727 propertiesStorage = nullptr;
728 }
729 }
730
731private:
732 OperationName name;
733 LocationAttr loc;
734 DictionaryAttr attrs;
735 SmallVector<Value, 8> operands;
736 SmallVector<Block *, 2> successors;
737 void *propertiesStorage = nullptr;
738};
739
740/// Replacing an operation. Erasing an operation is treated as a special case
741/// with "null" replacements. This rewrite is not immediately reflected in the
742/// IR. An internal IR mapping is updated, but values are not replaced and the
743/// original op is not erased until the rewrite is committed.
744class ReplaceOperationRewrite : public OperationRewrite {
745public:
746 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
747 Operation *op, const TypeConverter *converter)
748 : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
749 converter(converter) {}
750
751 static bool classof(const IRRewrite *rewrite) {
752 return rewrite->getKind() == Kind::ReplaceOperation;
753 }
754
755 void commit(RewriterBase &rewriter) override;
756
757 void rollback() override;
758
759 void cleanup(RewriterBase &rewriter) override;
760
761private:
762 /// An optional type converter that can be used to materialize conversions
763 /// between the new and old values if necessary.
764 const TypeConverter *converter;
765};
766
767class CreateOperationRewrite : public OperationRewrite {
768public:
769 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
770 Operation *op)
771 : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
772
773 static bool classof(const IRRewrite *rewrite) {
774 return rewrite->getKind() == Kind::CreateOperation;
775 }
776
777 void commit(RewriterBase &rewriter) override {
778 // The operation was already created and inserted. Just inform the listener.
779 if (auto *listener = rewriter.getListener())
780 listener->notifyOperationInserted(op, /*previous=*/{});
781 }
782
783 void rollback() override;
784};
785
/// Direction of a materialization.
enum MaterializationKind {
  /// Materializes a conversion from an illegal type to a legal one.
  Target = 0,

  /// Materializes a conversion from a legal type back to an illegal one.
  Source = 1
};
796
797/// Helper class that stores metadata about an unresolved materialization.
798class UnresolvedMaterializationInfo {
799public:
800 UnresolvedMaterializationInfo() = default;
801 UnresolvedMaterializationInfo(const TypeConverter *converter,
802 MaterializationKind kind, Type originalType)
803 : converterAndKind(converter, kind), originalType(originalType) {}
804
805 /// Return the type converter of this materialization (which may be null).
806 const TypeConverter *getConverter() const {
807 return converterAndKind.getPointer();
808 }
809
810 /// Return the kind of this materialization.
811 MaterializationKind getMaterializationKind() const {
812 return converterAndKind.getInt();
813 }
814
815 /// Return the original type of the SSA value.
816 Type getOriginalType() const { return originalType; }
817
818private:
819 /// The corresponding type converter to use when resolving this
820 /// materialization, and the kind of this materialization.
821 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
822 converterAndKind;
823
824 /// The original type of the SSA value. Only used for target
825 /// materializations.
826 Type originalType;
827};
828
829/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
830/// op. Unresolved materializations fold away or are replaced with
831/// source/target materializations at the end of the dialect conversion.
832class UnresolvedMaterializationRewrite : public OperationRewrite {
833public:
834 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
835 UnrealizedConversionCastOp op,
836 ValueVector mappedValues)
837 : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
838 mappedValues(std::move(mappedValues)) {}
839
840 static bool classof(const IRRewrite *rewrite) {
841 return rewrite->getKind() == Kind::UnresolvedMaterialization;
842 }
843
844 void rollback() override;
845
846 UnrealizedConversionCastOp getOperation() const {
847 return cast<UnrealizedConversionCastOp>(op);
848 }
849
850private:
851 /// The values in the conversion value mapping that are being replaced by the
852 /// results of this unresolved materialization.
853 ValueVector mappedValues;
854};
855} // namespace
856
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Return "true" if `rewrites` contains an operation rewrite of type
/// `RewriteTy` that operates on `op`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Operation *op) {
  for (auto &rewrite : rewrites) {
    auto *candidate = dyn_cast<RewriteTy>(rewrite.get());
    if (candidate && candidate->getOperation() == op)
      return true;
  }
  return false;
}

/// Return "true" if `rewrites` contains a block rewrite of type `RewriteTy`
/// that operates on `block`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Block *block) {
  for (auto &rewrite : rewrites) {
    auto *candidate = dyn_cast<RewriteTy>(rewrite.get());
    if (candidate && candidate->getBlock() == block)
      return true;
  }
  return false;
}
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
878
879//===----------------------------------------------------------------------===//
880// ConversionPatternRewriterImpl
881//===----------------------------------------------------------------------===//
882namespace mlir {
883namespace detail {
885 explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
886 const ConversionConfig &config,
890
891 //===--------------------------------------------------------------------===//
892 // State Management
893 //===--------------------------------------------------------------------===//
894
895 /// Return the current state of the rewriter.
896 RewriterState getCurrentState();
897
898 /// Apply all requested operation rewrites. This method is invoked when the
899 /// conversion process succeeds.
900 void applyRewrites();
901
902 /// Reset the state of the rewriter to a previously saved point. Optionally,
903 /// the name of the pattern that triggered the rollback can be specified for
904 /// debugging purposes.
905 void resetState(RewriterState state, StringRef patternName = "");
906
907 /// Append a rewrite. Rewrites are committed upon success and rolled back upon
908 /// failure.
909 template <typename RewriteTy, typename... Args>
910 void appendRewrite(Args &&...args) {
911 assert(config.allowPatternRollback && "appending rewrites is not allowed");
912 rewrites.push_back(
913 std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
914 }
915
916 /// Undo the rewrites one by one in reverse order until "numRewritesToKeep"
917 /// rewrites remain. Optionally, the name of the pattern that triggered the
918 /// rollback can be specified for debugging purposes.
919 void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
920
921 /// Remap the given values to those with potentially different types. Returns
922 /// success if the values could be remapped, failure otherwise. `valueDiagTag`
923 /// is the tag used when describing a value within a diagnostic, e.g.
924 /// "operand".
925 LogicalResult remapValues(StringRef valueDiagTag,
926 std::optional<Location> inputLoc, ValueRange values,
927 SmallVector<ValueVector> &remapped);
928
929 /// Return "true" if the given operation is ignored, and does not need to be
930 /// converted.
931 bool isOpIgnored(Operation *op) const;
932
933 /// Return "true" if the given operation was replaced or erased.
934 bool wasOpReplaced(Operation *op) const;
935
936 /// Lookup the most recently mapped values with the desired types in the
937 /// mapping, taking into account only replacements. Perform a best-effort
938 /// search for existing materializations with the desired types.
939 ///
940 /// If `skipPureTypeConversions` is "true", materializations that are pure
941 /// type conversions are not considered.
942 ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
943 bool skipPureTypeConversions = false) const;
944
945 /// Lookup the given value within the map, or return an empty vector if the
946 /// value is not mapped. If it is mapped, this follows the same behavior
947 /// as `lookupOrDefault`.
948 ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
949
950 //===--------------------------------------------------------------------===//
951 // IR Rewrites / Type Conversion
952 //===--------------------------------------------------------------------===//
953
954 /// Convert the types of block arguments within the given region.
955 FailureOr<Block *>
956 convertRegionTypes(Region *region, const TypeConverter &converter,
957 TypeConverter::SignatureConversion *entryConversion);
958
959 /// Apply the given signature conversion on the given block. The new block
960 /// containing the updated signature is returned. If no conversions were
961 /// necessary, e.g. if the block has no arguments, `block` is returned.
962 /// `converter` is used to generate any necessary cast operations that
963 /// translate between the origin argument types and those specified in the
964 /// signature conversion.
965 Block *applySignatureConversion(
966 Block *block, const TypeConverter *converter,
967 TypeConverter::SignatureConversion &signatureConversion);
968
969 /// Replace the results of the given operation with the given values and
970 /// erase the operation.
971 ///
972 /// There can be multiple replacement values for each result (1:N
973 /// replacement). If the replacement values are empty, the respective result
974 /// is dropped and a source materialization is built if the result still has
975 /// uses.
976 void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
977
978 /// Replace the uses of the given value with the given values. The specified
979 /// converter is used to build materializations (if necessary). If `functor`
980 /// is specified, only the uses that the functor returns "true" for are
981 /// replaced.
982 void replaceValueUses(Value from, ValueRange to,
983 const TypeConverter *converter,
984 function_ref<bool(OpOperand &)> functor = nullptr);
985
986 /// Erase the given block and its contents.
987 void eraseBlock(Block *block);
988
989 /// Inline the source block into the destination block before the given
990 /// iterator.
991 void inlineBlockBefore(Block *source, Block *dest, Block::iterator before);
992
993 //===--------------------------------------------------------------------===//
994 // Materializations
995 //===--------------------------------------------------------------------===//
996
997 /// Build an unresolved materialization operation given a range of output
998 /// types and a list of input operands. The input types must differ from the
999 /// output types; otherwise, no materialization is necessary (asserted).
1000 ///
1001 /// Returns the results of the newly created `unrealized_conversion_cast`
1002 /// operation.
1003 ///
1004 /// If `valuesToMap` is non-empty (and pattern rollback is enabled), those
1005 /// values are mapped to the results of the unresolved materialization in the
1006 /// conversion value mapping.
1007 ///
1008 /// If `isPureTypeConversion` is "true", the materialization is created only
1009 /// to resolve a type mismatch. That means it is not a regular value
1010 /// replacement issued by the user. (Replacement values that are created
1011 /// "out of thin air" appear like unresolved materializations because they are
1012 /// unrealized_conversion_cast ops. However, they must be treated like
1013 /// regular value replacements.)
1014 ValueRange buildUnresolvedMaterialization(
1015 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1016 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1017 Type originalType, const TypeConverter *converter,
1018 bool isPureTypeConversion = true);
1019
1020 /// Find a replacement value for the given SSA value in the conversion value
1021 /// mapping. The replacement value must have the same type as the given SSA
1022 /// value. If there is no replacement value with the correct type, find the
1023 /// latest replacement value (regardless of the type) and build a source
1024 /// materialization.
1025 Value findOrBuildReplacementValue(Value value,
1026 const TypeConverter *converter);
1027
1028 //===--------------------------------------------------------------------===//
1029 // Rewriter Notification Hooks
1030 //===--------------------------------------------------------------------===//
1031
1032 /// Notifies that an op was inserted.
1033 void notifyOperationInserted(Operation *op,
1034 OpBuilder::InsertPoint previous) override;
1035
1036 /// Notifies that a block was inserted.
1037 void notifyBlockInserted(Block *block, Region *previous,
1038 Region::iterator previousIt) override;
1039
1040 /// Notifies that a pattern match failed for the given reason.
1041 void
1042 notifyMatchFailure(Location loc,
1043 function_ref<void(Diagnostic &)> reasonCallback) override;
1044
1045 //===--------------------------------------------------------------------===//
1046 // IR Erasure
1047 //===--------------------------------------------------------------------===//
1048
1049 /// A rewriter that keeps track of erased ops and blocks. It ensures that no
1050 /// operation or block is erased multiple times. This rewriter assumes that
1051 /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
1053 public:
1056 std::function<void(Operation *)> opErasedCallback = nullptr)
1057 : RewriterBase(context, /*listener=*/this),
1058 opErasedCallback(std::move(opErasedCallback)) {}
1059
1060 /// Erase the given op (unless it was already erased).
1061 void eraseOp(Operation *op) override {
1062 if (wasErased(op))
1063 return;
1064 op->dropAllUses();
1066 }
1067
1068 /// Erase the given block (unless it was already erased).
1069 void eraseBlock(Block *block) override {
1070 if (wasErased(block))
1071 return;
1072 assert(block->empty() && "expected empty block");
1073 block->dropAllDefinedValueUses();
1075 }
1076
1077 bool wasErased(void *ptr) const { return erased.contains(ptr); }
1078
1080 erased.insert(op);
1081 if (opErasedCallback)
1082 opErasedCallback(op);
1083 }
1084
1085 void notifyBlockErased(Block *block) override { erased.insert(block); }
1086
1087 private:
1088 /// Pointers to all erased operations and blocks.
1089 DenseSet<void *> erased;
1090
1091 /// A callback that is invoked when an operation is erased.
1092 std::function<void(Operation *)> opErasedCallback;
1093 };
1094
1095 //===--------------------------------------------------------------------===//
1096 // State
1097 //===--------------------------------------------------------------------===//
1098
1099 /// The rewriter that is used to perform the conversion.
1100 ConversionPatternRewriter &rewriter;
1101
1102 /// Mapping between replaced values and their replacement values (which may
1103 /// differ in type), including cached materializations.
1104 ConversionValueMapping mapping;
1105
1106 /// Ordered list of IR rewrites (e.g., creations, motions, replacements).
1107 /// This vector is maintained only if `allowPatternRollback` is set to
1108 /// "true". Otherwise, all IR rewrites are materialized immediately and no
1109 /// bookkeeping is needed.
1111
1112 /// A set of operations that should no longer be considered for legalization.
1113 /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1114 /// tracked separately.
1116
1117 /// A set of operations that were replaced/erased. Such ops are not erased
1118 /// immediately but only when the dialect conversion succeeds. In the
1119 /// meantime, they should no longer be considered for legalization and any attempt
1120 /// to modify/access them is invalid rewriter API usage.
1122
1123 /// A set of operations that were created by the current pattern.
1125
1126 /// A set of operations that were modified by the current pattern.
1128
1129 /// A list of unresolved materializations that were created by the current
1130 /// pattern.
1132
1133 /// A mapping for looking up metadata of unresolved materializations.
1136
1137 /// The current type converter, or nullptr if no type converter is currently
1138 /// active.
1140
1141 /// A mapping of regions to type converters that should be used when
1142 /// converting the arguments of blocks within that region.
1144
1145 /// Dialect conversion configuration.
1146 const ConversionConfig &config;
1147
1148 /// The operation converter to use for recursive legalization.
1150
1151 /// A set of erased operations. This set is utilized only if
1152 /// `allowPatternRollback` is set to "false". Conceptually, this set is
1153 /// similar to `replacedOps` (which is maintained when the flag is set to
1154 /// "true"). However, erasing from a DenseSet is more efficient than erasing
1155 /// from a SetVector.
1157
1158 /// A set of erased blocks. This set is utilized only if
1159 /// `allowPatternRollback` is set to "false".
1161
1162 /// A rewriter that notifies the listener (if any) about all IR
1163 /// modifications. This rewriter is utilized only if `allowPatternRollback`
1164 /// is set to "false". If the flag is set to "true", the listener is notified
1165 /// with a separate mechanism (e.g., in `IRRewrite::commit`).
1167
1168#ifndef NDEBUG
1169 /// A set of replaced values. This set is for debugging purposes only and it
1170 /// is maintained only if `allowPatternRollback` is set to "true".
1172
1173 /// A set of operations that have pending updates. This tracking isn't
1174 /// strictly necessary, and is thus only active during debug builds for extra
1175 /// verification.
1177
1178 /// A raw output stream used to prefix the debug log.
1179 llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
1180 llvm::dbgs()};
1181
1182 /// A logger used to emit diagnostics during the conversion process.
1183 llvm::ScopedPrinter logger{os};
1184 std::string logPrefix;
1185#endif
1186};
1187} // namespace detail
1188} // namespace mlir
1189
1190const ConversionConfig &IRRewrite::getConfig() const {
1191 return rewriterImpl.config;
1192}
1193
1194void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1195 // Inform the listener about all IR modifications that have already taken
1196 // place: References to the original block have been replaced with the new
1197 // block.
1198 if (auto *listener =
1199 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1200 for (Operation *op : getNewBlock()->getUsers())
1201 listener->notifyOperationModified(op);
1202}
1203
1204void BlockTypeConversionRewrite::rollback() {
1205 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1206}
1207
1208/// Replace all uses of `from` with `repl`.
1209static void
1211 function_ref<bool(OpOperand &)> functor = nullptr) {
1212 if (isa<BlockArgument>(repl)) {
1213 // `repl` is a block argument. Directly replace all uses.
1214 if (functor) {
1215 rewriter.replaceUsesWithIf(from, repl, functor);
1216 } else {
1217 rewriter.replaceAllUsesWith(from, repl);
1218 }
1219 return;
1220 }
1221
1222 // If the replacement value is an operation, only replace those uses that:
1223 // - are in a different block than the replacement operation, or
1224 // - are in the same block but after the replacement operation.
1225 //
1226 // Example:
1227 // ^bb0(%arg0: i32):
1228 // %0 = "consumer"(%arg0) : (i32) -> (i32)
1229 // "another_consumer"(%arg0) : (i32) -> ()
1230 //
1231 // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
1232 // use in "another_consumer" but not the use in "consumer". When using the
1233 // normal RewriterBase API, this would typically be done with
1234 // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
1235 // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
1236 // it cannot be supported efficiently with `allowPatternRollback` set to
1237 // "true". Therefore, the conversion driver is trying to be smart and replaces
1238 // only those uses that do not lead to a dominance violation. E.g., the
1239 // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
1240 // behavior.
1241 //
1242 // TODO: As we move more and more towards `allowPatternRollback` set to
1243 // "false", we should remove this special handling, in order to align the
1244 // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
1245 Operation *replOp = repl.getDefiningOp();
1246 Block *replBlock = replOp->getBlock();
1247 rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
1248 Operation *user = operand.getOwner();
1249 bool result =
1250 user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1251 if (result && functor)
1252 result &= functor(operand);
1253 return result;
1254 });
1255}
1256
1257void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1258 Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
1259 if (!repl)
1260 return;
1261 performReplaceValue(rewriter, value, repl);
1262}
1263
1264void ReplaceValueRewrite::rollback() {
1265 rewriterImpl.mapping.erase({value});
1266#ifndef NDEBUG
1267 rewriterImpl.replacedValues.erase(value);
1268#endif // NDEBUG
1269}
1270
1271void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1272 auto *listener =
1273 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1274
1275 // Compute replacement values.
1276 SmallVector<Value> replacements =
1277 llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1278 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1279 });
1280
1281 // Notify the listener that the operation is about to be replaced.
1282 if (listener)
1283 listener->notifyOperationReplaced(op, replacements);
1284
1285 // Replace all uses with the new values.
1286 for (auto [result, newValue] :
1287 llvm::zip_equal(op->getResults(), replacements))
1288 if (newValue)
1289 rewriter.replaceAllUsesWith(result, newValue);
1290
1291 // The original op will be erased, so remove it from the set of unlegalized
1292 // ops.
1293 if (getConfig().unlegalizedOps)
1294 getConfig().unlegalizedOps->erase(op);
1295
1296 // Notify the listener that the operation and its contents are being erased.
1297 if (listener)
1298 notifyIRErased(listener, *op);
1299
1300 // Do not erase the operation yet. It may still be referenced in `mapping`.
1301 // Just unlink it for now and erase it during cleanup.
1302 op->getBlock()->getOperations().remove(op);
1303}
1304
1305void ReplaceOperationRewrite::rollback() {
1306 for (auto result : op->getResults())
1307 rewriterImpl.mapping.erase({result});
1308}
1309
1310void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1311 rewriter.eraseOp(op);
1312}
1313
1314void CreateOperationRewrite::rollback() {
1315 for (Region &region : op->getRegions()) {
1316 while (!region.getBlocks().empty())
1317 region.getBlocks().remove(region.getBlocks().begin());
1318 }
1319 op->dropAllUses();
1320 op->erase();
1321}
1322
1323void UnresolvedMaterializationRewrite::rollback() {
1324 if (!mappedValues.empty())
1325 rewriterImpl.mapping.erase(mappedValues);
1326 rewriterImpl.unresolvedMaterializations.erase(getOperation());
1327 op->erase();
1328}
1329
1331 // Commit all rewrites. Use a new rewriter, so the modifications are not
1332 // tracked for rollback purposes etc.
1333 IRRewriter irRewriter(rewriter.getContext(), config.listener);
1334 // Note: New rewrites may be added during the "commit" phase and the
1335 // `rewrites` vector may reallocate.
1336 for (size_t i = 0; i < rewrites.size(); ++i)
1337 rewrites[i]->commit(irRewriter);
1338
1339 // Clean up all rewrites.
1340 SingleEraseRewriter eraseRewriter(
1341 rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
1342 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1343 unresolvedMaterializations.erase(castOp);
1344 });
1345 for (auto &rewrite : rewrites)
1346 rewrite->cleanup(eraseRewriter);
1347}
1348
1349//===----------------------------------------------------------------------===//
1350// State Management
1351//===----------------------------------------------------------------------===//
1352
1354 Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
1355 // Helper function that looks up a single value.
1356 auto lookup = [&](const ValueVector &values) -> ValueVector {
1357 assert(!values.empty() && "expected non-empty value vector");
1358
1359 // If the pattern rollback is enabled, use the mapping to look up the
1360 // values.
1361 if (config.allowPatternRollback)
1362 return mapping.lookup(values);
1363
1364 // Otherwise, look up values by examining the IR. All replacements have
1365 // already been materialized in IR.
1366 Operation *op = getCommonDefiningOp(values);
1367 if (!op)
1368 return {};
1369 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1370 if (!castOp)
1371 return {};
1372 if (!this->unresolvedMaterializations.contains(castOp))
1373 return {};
1374 if (castOp.getOutputs() != values)
1375 return {};
1376 return castOp.getInputs();
1377 };
1378
1379 // Helper function that looks up each value in `values` individually and then
1380 // composes the results. If that fails, it tries to look up the entire vector
1381 // at once.
1382 auto composedLookup = [&](const ValueVector &values) -> ValueVector {
1383 // If possible, replace each value with (one or multiple) mapped values.
1384 ValueVector next;
1385 for (Value v : values) {
1386 ValueVector r = lookup({v});
1387 if (!r.empty()) {
1388 llvm::append_range(next, r);
1389 } else {
1390 next.push_back(v);
1391 }
1392 }
1393 if (next != values) {
1394 // At least one value was replaced.
1395 return next;
1396 }
1397
1398 // Otherwise: Check if there is a mapping for the entire vector. Such
1399 // mappings are materializations. (N:M mapping are not supported for value
1400 // replacements.)
1401 //
1402 // Note: From a correctness point of view, materializations do not have to
1403 // be stored (and looked up) in the mapping. But for performance reasons,
1404 // we choose to reuse existing IR (when possible) instead of creating it
1405 // multiple times.
1406 ValueVector r = lookup(values);
1407 if (r.empty()) {
1408 // No mapping found: The lookup stops here.
1409 return {};
1410 }
1411 return r;
1412 };
1413
1414 // Try to find the deepest values that have the desired types. If there is no
1415 // such mapping, simply return the deepest values.
1416 ValueVector desiredValue;
1417 ValueVector current{from};
1418 ValueVector lastNonMaterialization{from};
1419 do {
1420 // Store the current value if the types match.
1421 bool match = TypeRange(ValueRange(current)) == desiredTypes;
1422 if (skipPureTypeConversions) {
1423 // Skip pure type conversions, if requested.
1424 bool pureConversion = isPureTypeConversion(current);
1425 match &= !pureConversion;
1426 // Keep track of the last mapped value that was not a pure type
1427 // conversion.
1428 if (!pureConversion)
1429 lastNonMaterialization = current;
1430 }
1431 if (match)
1432 desiredValue = current;
1433
1434 // Lookup next value in the mapping.
1435 ValueVector next = composedLookup(current);
1436 if (next.empty())
1437 break;
1438 current = std::move(next);
1439 } while (true);
1440
1441 // If the desired values were found, use them; otherwise, default to the leaf
1442 // values. (Skip pure type conversions, if requested.)
1443 if (!desiredTypes.empty())
1444 return desiredValue;
1445 if (skipPureTypeConversions)
1446 return lastNonMaterialization;
1447 return current;
1448}
1449
1452 TypeRange desiredTypes) const {
1453 ValueVector result = lookupOrDefault(from, desiredTypes);
1454 if (result == ValueVector{from} ||
1455 (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
1456 return {};
1457 return result;
1458}
1459
1461 return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1462}
1463
1465 StringRef patternName) {
1466 // Undo any rewrites.
1467 undoRewrites(state.numRewrites, patternName);
1468
1469 // Pop all of the recorded ignored operations that are no longer valid.
1470 while (ignoredOps.size() != state.numIgnoredOperations)
1471 ignoredOps.pop_back();
1472
1473 while (replacedOps.size() != state.numReplacedOps)
1474 replacedOps.pop_back();
1475}
1476
1477void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1478 StringRef patternName) {
1479 for (auto &rewrite :
1480 llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1481 rewrite->rollback();
1482 rewrites.resize(numRewritesToKeep);
1483}
1484
1486 StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
1487 SmallVector<ValueVector> &remapped) {
1488 remapped.reserve(llvm::size(values));
1489
1490 for (const auto &it : llvm::enumerate(values)) {
1491 Value operand = it.value();
1492 Type origType = operand.getType();
1493 Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1494
1495 if (!currentTypeConverter) {
1496 // The current pattern does not have a type converter. Pass the most
1497 // recently mapped values, excluding materializations. Materializations
1498 // are intentionally excluded because their presence may depend on other
1499 // patterns. Including materializations would make the lookup fragile
1500 // and unpredictable.
1501 remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
1502 /*skipPureTypeConversions=*/true));
1503 continue;
1504 }
1505
1506 // If there is no legal conversion, fail to match this pattern.
1507 SmallVector<Type, 1> legalTypes;
1508 if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
1509 notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1510 diag << "unable to convert type for " << valueDiagTag << " #"
1511 << it.index() << ", type was " << origType;
1512 });
1513 return failure();
1514 }
1515 // If a type is converted to 0 types, there is nothing to do.
1516 if (legalTypes.empty()) {
1517 remapped.push_back({});
1518 continue;
1519 }
1520
1521 ValueVector repl = lookupOrDefault(operand, legalTypes);
1522 if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1523 // Mapped values have the correct type or there is an existing
1524 // materialization. Or the operand is not mapped at all and has the
1525 // correct type.
1526 remapped.push_back(std::move(repl));
1527 continue;
1528 }
1529
1530 // Create a materialization for the most recently mapped values.
1531 repl = lookupOrDefault(operand, /*desiredTypes=*/{},
1532 /*skipPureTypeConversions=*/true);
1534 MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1535 /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1536 /*originalType=*/origType, currentTypeConverter);
1537 remapped.push_back(castValues);
1538 }
1539 return success();
1540}
1541
1543 // Check to see if this operation is ignored or was replaced.
1544 return wasOpReplaced(op) || ignoredOps.count(op);
1545}
1546
1548 // Check to see if this operation was replaced.
1549 return replacedOps.count(op) || erasedOps.count(op);
1550}
1551
1552//===----------------------------------------------------------------------===//
1553// Type Conversion
1554//===----------------------------------------------------------------------===//
1555
1557 Region *region, const TypeConverter &converter,
1558 TypeConverter::SignatureConversion *entryConversion) {
1559 regionToConverter[region] = &converter;
1560 if (region->empty())
1561 return nullptr;
1562
1563 // Convert the arguments of each non-entry block within the region.
1564 for (Block &block :
1565 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1566 // Compute the signature for the block with the provided converter.
1567 std::optional<TypeConverter::SignatureConversion> conversion =
1568 converter.convertBlockSignature(&block);
1569 if (!conversion)
1570 return failure();
1571 // Convert the block with the computed signature.
1572 applySignatureConversion(&block, &converter, *conversion);
1573 }
1574
1575 // Convert the entry block. If an entry signature conversion was provided,
1576 // use that one. Otherwise, compute the signature with the type converter.
1577 if (entryConversion)
1578 return applySignatureConversion(&region->front(), &converter,
1579 *entryConversion);
1580 std::optional<TypeConverter::SignatureConversion> conversion =
1581 converter.convertBlockSignature(&region->front());
1582 if (!conversion)
1583 return failure();
1584 return applySignatureConversion(&region->front(), &converter, *conversion);
1585}
1586
1588 Block *block, const TypeConverter *converter,
1589 TypeConverter::SignatureConversion &signatureConversion) {
1590#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1591 // A block cannot be converted multiple times.
1592 if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1593 llvm::report_fatal_error("block was already converted");
1594#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1595
1597
1598 // If no arguments are being changed or added, there is nothing to do.
1599 unsigned origArgCount = block->getNumArguments();
1600 auto convertedTypes = signatureConversion.getConvertedTypes();
1601 if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1602 return block;
1603
1604 // Compute the locations of all block arguments in the new block.
1605 SmallVector<Location> newLocs(convertedTypes.size(),
1606 rewriter.getUnknownLoc());
1607 for (unsigned i = 0; i < origArgCount; ++i) {
1608 auto inputMap = signatureConversion.getInputMapping(i);
1609 if (!inputMap || inputMap->replacedWithValues())
1610 continue;
1611 Location origLoc = block->getArgument(i).getLoc();
1612 for (unsigned j = 0; j < inputMap->size; ++j)
1613 newLocs[inputMap->inputNo + j] = origLoc;
1614 }
1615
1616 // Insert a new block with the converted block argument types and move all ops
1617 // from the old block to the new block.
1618 Block *newBlock =
1619 rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1620 convertedTypes, newLocs);
1621
1622 // If a listener is attached to the dialect conversion, ops cannot be moved
1623 // to the destination block in bulk ("fast path"). This is because at the time
1624 // the notifications are sent, it is unknown which ops were moved. Instead,
1625 // ops should be moved one-by-one ("slow path"), so that a separate
1626 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1627 // a bit more efficient, so we try to do that when possible.
1628 bool fastPath = !config.listener;
1629 if (fastPath) {
1630 if (config.allowPatternRollback)
1631 appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1632 newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1633 } else {
1634 while (!block->empty())
1635 rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1636 }
1637
1638 // Replace all uses of the old block with the new block.
1639 block->replaceAllUsesWith(newBlock);
1640
1641 for (unsigned i = 0; i != origArgCount; ++i) {
1642 BlockArgument origArg = block->getArgument(i);
1643 Type origArgType = origArg.getType();
1644
1645 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1646 signatureConversion.getInputMapping(i);
1647 if (!inputMap) {
1648 // This block argument was dropped and no replacement value was provided.
1649 // Materialize a replacement value "out of thin air".
1650 // Note: Materialization must be built here because we cannot find a
1651 // valid insertion point in the new block. (Will point to the old block.)
1652 Value mat =
1654 MaterializationKind::Source,
1655 OpBuilder::InsertPoint(newBlock, newBlock->begin()),
1656 origArg.getLoc(),
1657 /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1658 /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
1659 /*isPureTypeConversion=*/false)
1660 .front();
1661 replaceValueUses(origArg, mat, converter);
1662 continue;
1663 }
1664
1665 if (inputMap->replacedWithValues()) {
1666 // This block argument was dropped and replacement values were provided.
1667 assert(inputMap->size == 0 &&
1668 "invalid to provide a replacement value when the argument isn't "
1669 "dropped");
1670 replaceValueUses(origArg, inputMap->replacementValues, converter);
1671 continue;
1672 }
1673
1674 // This is a 1->1+ mapping.
1675 auto replArgs =
1676 newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1677 replaceValueUses(origArg, replArgs, converter);
1678 }
1679
1680 if (config.allowPatternRollback)
1681 appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1682
1683 // Erase the old block. (It is just unlinked for now and will be erased during
1684 // cleanup.)
1685 rewriter.eraseBlock(block);
1686
1687 return newBlock;
1688}
1689
1690//===----------------------------------------------------------------------===//
1691// Materializations
1692//===----------------------------------------------------------------------===//
1693
1694 /// Build an unresolved materialization operation given a range of output types
1695 /// and a list of input operands.
1697 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1698 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1699 Type originalType, const TypeConverter *converter,
1700 bool isPureTypeConversion) {
1701 assert((!originalType || kind == MaterializationKind::Target) &&
1702 "original type is valid only for target materializations");
1703 assert(TypeRange(inputs) != outputTypes &&
1704 "materialization is not necessary");
1705
1706 // Create an unresolved materialization. We use a new OpBuilder to avoid
1707 // tracking the materialization like we do for other operations.
1708 OpBuilder builder(outputTypes.front().getContext());
1709 builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1710 UnrealizedConversionCastOp convertOp =
1711 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1712 if (config.attachDebugMaterializationKind) {
1713 StringRef kindStr =
1714 kind == MaterializationKind::Source ? "source" : "target";
1715 convertOp->setAttr("__kind__", builder.getStringAttr(kindStr));
1716 }
1718 convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
1719
1720 // Register the materialization.
1721 unresolvedMaterializations[convertOp] =
1722 UnresolvedMaterializationInfo(converter, kind, originalType);
1723 if (config.allowPatternRollback) {
1724 if (!valuesToMap.empty())
1725 mapping.map(valuesToMap, convertOp.getResults());
1727 std::move(valuesToMap));
1728 } else {
1729 patternMaterializations.insert(convertOp);
1730 }
1731 return convertOp.getResults();
1732}
1733
// NOTE(review): the signature line (doxygen line 1734) is missing from this
// listing; the body returns a `Value`.
//
// Returns a value of `value`'s type that can stand in for `value`. This path
// is only valid in rollback mode. Steps:
//  1. reuse an already-mapped value of the same type (this includes cached
//     materializations);
//  2. return a null Value if every user is a replaced op and `value` is not
//     the target of any mapping (an approximate deadness check);
//  3. otherwise build a source materialization from the latest mapped values
//     (or from no inputs at all) back to `value`'s type.
1735 Value value, const TypeConverter *converter) {
1736 assert(config.allowPatternRollback &&
1737 "this code path is valid only in rollback mode");
1738
1739 // Try to find a replacement value with the same type in the conversion value
1740 // mapping. This includes cached materializations. We try to reuse those
1741 // instead of generating duplicate IR.
1742 ValueVector repl = lookupOrNull(value, value.getType());
1743 if (!repl.empty())
1744 return repl.front();
1745
1746 // Check if the value is dead. No replacement value is needed in that case.
1747 // This is an approximate check that may have false negatives but does not
1748 // require computing and traversing an inverse mapping. (We may end up
1749 // building source materializations that are never used and that fold away.)
1750 if (llvm::all_of(value.getUsers(),
1751 [&](Operation *op) { return replacedOps.contains(op); }) &&
1752 !mapping.isMappedTo(value))
1753 return Value();
1754
1755 // No replacement value was found. Get the latest replacement value
1756 // (regardless of the type) and build a source materialization to the
1757 // original type.
1758 repl = lookupOrNull(value);
1759
1760 // Compute the insertion point of the materialization.
// NOTE(review): doxygen line 1761 is missing; it presumably declares `ip`.
1762 if (repl.empty()) {
1763 // The source materialization has no inputs. Insert it right before the
1764 // value that it is replacing.
1765 ip = computeInsertPoint(value);
1766 } else {
1767 // Compute the "earliest" insertion point at which all values in `repl` are
1768 // defined. It is important to emit the materialization at that location
1769 // because the same materialization may be reused in a different context.
1770 // (That's because materializations are cached in the conversion value
1771 // mapping.) The insertion point of the materialization must be valid for
1772 // all future users that may be created later in the conversion process.
1773 ip = computeInsertPoint(repl);
1774 }
// NOTE(review): doxygen line 1775 is missing; it presumably begins the call
// whose first result initializes `castValue` (a source materialization).
// The materialization is marked a pure type conversion only when it has
// inputs (`!repl.empty()`).
1776 MaterializationKind::Source, ip, value.getLoc(),
1777 /*valuesToMap=*/repl, /*inputs=*/repl,
1778 /*outputTypes=*/value.getType(),
1779 /*originalType=*/Type(), converter,
1780 /*isPureTypeConversion=*/!repl.empty())
1781 .front();
1782 return castValue;
1783}
1784
1785//===----------------------------------------------------------------------===//
1786// Rewriter Notification Hooks
1787//===----------------------------------------------------------------------===//
1788
// NOTE(review): the signature line (doxygen line 1789) is missing from this
// listing.
//
// Listener hook called when `op` is inserted. An unset `previous` insertion
// point means the op was detached, i.e. most likely newly created. Newly
// inserted ops are added to `patternNewOps` so they get legalized. In
// "no rollback" mode the user listener is notified immediately and the op is
// removed from `erasedOps`.
1790 Operation *op, OpBuilder::InsertPoint previous) {
1791 // If no previous insertion point is provided, the op used to be detached.
1792 bool wasDetached = !previous.isSet();
1793 LLVM_DEBUG({
1794 logger.startLine() << "** Insert : '" << op->getName() << "' (" << op
1795 << ")";
1796 if (wasDetached)
1797 logger.getOStream() << " (was detached)";
1798 logger.getOStream() << "\n";
1799 });
1800
1801 // In rollback mode, it is easier to misuse the API, so perform extra error
1802 // checking.
1803 assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
1804 "attempting to insert into a block within a replaced/erased op");
1805
1806 // In "no rollback" mode, the listener is always notified immediately.
1807 if (!config.allowPatternRollback && config.listener)
1808 config.listener->notifyOperationInserted(op, previous);
1809
1810 if (wasDetached) {
1811 // If the op was detached, it is most likely a newly created op. Add it the
1812 // set of newly created ops, so that it will be legalized. If this op is
1813 // not a newly created op, it will be legalized a second time, which is
1814 // inefficient but harmless.
1815 patternNewOps.insert(op);
1816
1817 if (config.allowPatternRollback) {
1818 // TODO: If the same op is inserted multiple times from a detached
1819 // state, the rollback mechanism may erase the same op multiple times.
1820 // This is a bug in the rollback-based dialect conversion driver.
// NOTE(review): doxygen line 1821 is missing; presumably it records a
// rewrite so the creation can be rolled back — confirm upstream.
1822 } else {
1823 // In "no rollback" mode, there is an extra data structure for tracking
1824 // erased operations that must be kept up to date.
1825 erasedOps.erase(op);
1826 }
1827 return;
1828 }
1829
1830 // The op was moved from one place to another.
1831 if (config.allowPatternRollback)
// NOTE(review): doxygen line 1832 (the statement guarded by the `if` above,
// presumably recording the move for rollback) is missing from this listing.
1833}
1834
1835/// Given that `fromRange` is about to be replaced with `toRange`, compute
1836/// replacement values with the types of `fromRange`.
///
/// Only valid in "no rollback" mode. For each (from, to) pair the result is:
///  - a null Value if `from` has no uses;
///  - a source materialization with no inputs if `to` is empty (dropped);
///  - `to[0]` if `to` already has exactly `from`'s type;
///  - otherwise a source materialization from `to` to `from`'s type.
1837static SmallVector<Value>
// NOTE(review): doxygen line 1838 is missing from this listing; it presumably
// declares the `impl` and `fromRange` parameters used below.
1839 const SmallVector<SmallVector<Value>> &toRange,
1840 const TypeConverter *converter) {
1841 assert(!impl.config.allowPatternRollback &&
1842 "this code path is valid only in 'no rollback' mode");
1843 SmallVector<Value> repls;
1844 for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1845 if (from.use_empty()) {
1846 // The replaced value is dead. No replacement value is needed.
1847 repls.push_back(Value());
1848 continue;
1849 }
1850
1851 if (to.empty()) {
1852 // The replaced value is dropped. Materialize a replacement value "out of
1853 // thin air".
1854 Value srcMat = impl.buildUnresolvedMaterialization(
1855 MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
1856 /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1857 /*outputTypes=*/from.getType(), /*originalType=*/Type(),
1858 converter)[0];
1859 repls.push_back(srcMat);
1860 continue;
1861 }
1862
1863 if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
1864 // The replacement value already has the correct type. Use it directly.
1865 repls.push_back(to[0]);
1866 continue;
1867 }
1868
1869 // The replacement value has the wrong type. Build a source materialization
1870 // to the original type.
1871 // TODO: This is a bit inefficient. We should try to reuse existing
1872 // materializations if possible. This would require an extension of the
1873 // `lookupOrDefault` API.
1874 Value srcMat = impl.buildUnresolvedMaterialization(
1875 MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
1876 /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
1877 /*originalType=*/Type(), converter)[0];
1878 repls.push_back(srcMat);
1879 }
1880
1881 return repls;
1882}
1883
// NOTE(review): the signature line (doxygen line 1884) is missing from this
// listing.
//
// Replaces every result of `op` with the corresponding vector in `newValues`
// (an empty vector means "no replacement").
//  - "no rollback" mode: compute type-correct replacements, purge `op` and
//    everything nested in it from the driver's bookkeeping sets, then replace
//    and erase it immediately through `notifyingRewriter`.
//  - rollback mode: only record the value mappings and mark `op` and its
//    nested ops as replaced.
1885 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1886 assert(newValues.size() == op->getNumResults() &&
1887 "incorrect number of replacement values");
1888 LLVM_DEBUG({
1889 logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
1890 << ")\n";
// NOTE(review): doxygen line 1891 is missing; it presumably opens the
// (brace-delimited) guard that makes `currentTypeConverter` safe to use below.
1892 // If the user-provided replacement types are different from the
1893 // legalized types, as per the current type converter, print a note.
1894 // In most cases, the replacement types are expected to match the types
1895 // produced by the type converter, so this could indicate a bug in the
1896 // user code.
1897 for (auto [result, repls] :
1898 llvm::zip_equal(op->getResults(), newValues)) {
1899 Type resultType = result.getType();
1900 auto logProlog = [&, repls = repls]() {
1901 logger.startLine() << " Note: Replacing op result of type "
1902 << resultType << " with value(s) of type (";
1903 llvm::interleaveComma(repls, logger.getOStream(), [&](Value v) {
1904 logger.getOStream() << v.getType();
1905 });
1906 logger.getOStream() << ")";
1907 };
1908 SmallVector<Type> convertedTypes;
1909 if (failed(currentTypeConverter->convertTypes(resultType,
1910 convertedTypes))) {
1911 logProlog();
1912 logger.getOStream() << ", but the type converter failed to legalize "
1913 "the original type.\n";
1914 continue;
1915 }
1916 if (TypeRange(convertedTypes) != TypeRange(ValueRange(repls))) {
1917 logProlog();
1918 logger.getOStream() << ", but the legalized type(s) is/are (";
1919 llvm::interleaveComma(convertedTypes, logger.getOStream(),
1920 [&](Type t) { logger.getOStream() << t; });
1921 logger.getOStream() << ")\n";
1922 }
1923 }
1924 }
1925 });
1926
1927 if (!config.allowPatternRollback) {
1928 // Pattern rollback is not allowed: materialize all IR changes immediately.
// NOTE(review): doxygen line 1929 is missing; it presumably declares `repls`
// from the `getReplacementValues(...)` call continued on the next line.
1930 *this, op->getResults(), newValues, currentTypeConverter);
1931 // Update internal data structures, so that there are no dangling pointers
1932 // to erased IR.
1933 op->walk([&](Operation *op) {
1934 erasedOps.insert(op);
1935 ignoredOps.remove(op);
1936 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1937 unresolvedMaterializations.erase(castOp);
1938 patternMaterializations.erase(castOp);
1939 }
1940 // The original op will be erased, so remove it from the set of
1941 // unlegalized ops.
1942 if (config.unlegalizedOps)
1943 config.unlegalizedOps->erase(op);
1944 });
1945 op->walk([&](Block *block) { erasedBlocks.insert(block); });
1946 // Replace the op with the replacement values and notify the listener.
1947 notifyingRewriter.replaceOp(op, repls);
1948 return;
1949 }
1950
1951 assert(!ignoredOps.contains(op) && "operation was already replaced");
1952#ifndef NDEBUG
1953 for (Value v : op->getResults())
1954 assert(!replacedValues.contains(v) &&
1955 "attempting to replace a value that was already replaced");
1956#endif // NDEBUG
1957
1958 // Check if replaced op is an unresolved materialization, i.e., an
1959 // unrealized_conversion_cast op that was created by the conversion driver.
1960 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1961 // Make sure that the user does not mess with unresolved materializations
1962 // that were inserted by the conversion driver. We keep track of these
1963 // ops in internal data structures.
1964 assert(!unresolvedMaterializations.contains(castOp) &&
1965 "attempting to replace/erase an unresolved materialization");
1966 }
1967
1968 // Create mappings for each of the new result values.
1969 for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults()))
1970 mapping.map(static_cast<Value>(result), std::move(repl));
1971
// NOTE(review): doxygen line 1972 is missing; presumably it records the
// replacement so it can be rolled back — confirm upstream.
1973 // Mark this operation and all nested ops as replaced.
1974 op->walk([&](Operation *op) { replacedOps.insert(op); });
1975}
1976
// NOTE(review): the signature line (doxygen line 1977) is missing from this
// listing.
//
// Replaces uses of `from` with `to`, optionally only the uses accepted by
// `functor`.
//  - "no rollback" mode: compute a type-correct replacement immediately and
//    rewrite the uses. A null replacement means `from` is dead, so nothing is
//    done.
//  - rollback mode: record a mapping `from -> to`. Conditional (`functor`)
//    replacement is a fatal error in this mode.
1978 Value from, ValueRange to, const TypeConverter *converter,
1979 function_ref<bool(OpOperand &)> functor) {
// NOTE(review): the log line started below is never terminated with "\n",
// and the "(in region of '...' (ptr" text never closes its outer
// parenthesis. Subsequent debug output is therefore appended to the same
// line. Consider emitting ")\n" / "\n" at the end of this block.
1980 LLVM_DEBUG({
1981 logger.startLine() << "** Replace Value : '" << from << "'";
1982 if (auto blockArg = dyn_cast<BlockArgument>(from)) {
1983 if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
1984 logger.getOStream() << " (in region of '" << parentOp->getName()
1985 << "' (" << parentOp << ")";
1986 } else {
1987 logger.getOStream() << " (unlinked block)";
1988 }
1989 }
1990 if (functor) {
1991 logger.getOStream() << ", conditional replacement";
1992 }
1993 });
1994
1995 if (!config.allowPatternRollback) {
1996 SmallVector<Value> toConv = llvm::to_vector(to);
1997 SmallVector<Value> repls =
1998 getReplacementValues(*this, from, {toConv}, converter);
1999 IRRewriter r(from.getContext());
2000 Value repl = repls.front();
// A null replacement means `from` has no uses (see getReplacementValues).
2001 if (!repl)
2002 return;
2003
2004 performReplaceValue(r, from, repl, functor);
2005 return;
2006 }
2007
2008#ifndef NDEBUG
2009 // Make sure that a value is not replaced multiple times. In rollback mode,
2010 // `replaceAllUsesWith` replaces not only all current uses of the given value,
2011 // but also all future uses that may be introduced by future pattern
2012 // applications. Therefore, it does not make sense to call
2013 // `replaceAllUsesWith` multiple times with the same value. Doing so would
2014 // overwrite the mapping and mess with the internal state of the dialect
2015 // conversion driver.
2016 assert(!replacedValues.contains(from) &&
2017 "attempting to replace a value that was already replaced");
2018 assert(!wasOpReplaced(from.getDefiningOp()) &&
2019 "attempting to replace a op result that was already replaced");
2020 replacedValues.insert(from);
2021#endif // NDEBUG
2022
2023 if (functor)
2024 llvm::report_fatal_error(
2025 "conditional value replacement is not supported in rollback mode");
2026 mapping.map(from, to);
2027 appendRewrite<ReplaceValueRewrite>(from, converter);
2028}
2029
// NOTE(review): the signature line (doxygen line 2030) is missing from this
// listing.
//
// Erases `block`.
//  - "no rollback" mode: purge the block's ops and nested blocks from the
//    driver's bookkeeping sets, then erase it immediately via
//    `notifyingRewriter`.
//  - rollback mode: only unlink the block from its region and mark its
//    nested ops as replaced, so the removal can still be undone.
2031 if (!config.allowPatternRollback) {
2032 // Pattern rollback is not allowed: materialize all IR changes immediately.
2033 // Update internal data structures, so that there are no dangling pointers
2034 // to erased IR.
2035 block->walk([&](Operation *op) {
2036 erasedOps.insert(op);
2037 ignoredOps.remove(op);
2038 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
2039 unresolvedMaterializations.erase(castOp);
2040 patternMaterializations.erase(castOp);
2041 }
2042 // The original op will be erased, so remove it from the set of
2043 // unlegalized ops.
2044 if (config.unlegalizedOps)
2045 config.unlegalizedOps->erase(op);
2046 });
2047 block->walk([&](Block *block) { erasedBlocks.insert(block); });
2048 // Erase the block and notify the listener.
2049 notifyingRewriter.eraseBlock(block);
2050 return;
2051 }
2052
2053 assert(!wasOpReplaced(block->getParentOp()) &&
2054 "attempting to erase a block within a replaced/erased op");
// NOTE(review): doxygen line 2055 is missing; presumably it records the
// erasure so it can be rolled back — confirm upstream.
2056
2057 // Unlink the block from its parent region. The block is kept in the rewrite
2058 // object and will be actually destroyed when rewrites are applied. This
2059 // allows us to keep the operations in the block live and undo the removal by
2060 // re-inserting the block.
2061 block->getParent()->getBlocks().remove(block);
2062
2063 // Mark all nested ops as erased.
2064 block->walk([&](Operation *op) { replacedOps.insert(op); });
2065}
2066
// NOTE(review): the signature line (doxygen line 2067) is missing from this
// listing.
//
// Listener hook called when `block` is inserted into a region. A null
// `previous` region means the block was detached (most likely newly
// created). In "no rollback" mode the user listener is notified immediately
// and the block is removed from `erasedBlocks`. In rollback mode a move is
// recorded as a `MoveBlockRewrite`.
2068 Block *block, Region *previous, Region::iterator previousIt) {
2069 // If no previous insertion point is provided, the block used to be detached.
2070 bool wasDetached = !previous;
2071 Operation *newParentOp = block->getParentOp();
2072 LLVM_DEBUG(
2073 {
2074 Operation *parent = newParentOp;
2075 if (parent) {
2076 logger.startLine() << "** Insert Block into : '" << parent->getName()
2077 << "' (" << parent << ")";
2078 } else {
2079 logger.startLine()
2080 << "** Insert Block into detached Region (nullptr parent op)";
2081 }
2082 if (wasDetached)
2083 logger.getOStream() << " (was detached)";
2084 logger.getOStream() << "\n";
2085 });
2086
2087 // In rollback mode, it is easier to misuse the API, so perform extra error
2088 // checking.
2089 assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) &&
2090 "attempting to insert into a region within a replaced/erased op");
// Silences the unused-variable warning when asserts and debug output are
// compiled out.
2091 (void)newParentOp;
2092
2093 // In "no rollback" mode, the listener is always notified immediately.
2094 if (!config.allowPatternRollback && config.listener)
2095 config.listener->notifyBlockInserted(block, previous, previousIt);
2096
2097 if (wasDetached) {
2098 // If the block was detached, it is most likely a newly created block.
2099 if (config.allowPatternRollback) {
2100 // TODO: If the same block is inserted multiple times from a detached
2101 // state, the rollback mechanism may erase the same block multiple times.
2102 // This is a bug in the rollback-based dialect conversion driver.
// NOTE(review): doxygen line 2103 is missing; presumably it records the
// block creation for rollback — confirm upstream.
2104 } else {
2105 // In "no rollback" mode, there is an extra data structure for tracking
2106 // erased blocks that must be kept up to date.
2107 erasedBlocks.erase(block);
2108 }
2109 return;
2110 }
2111
2112 // The block was moved from one place to another.
2113 if (config.allowPatternRollback)
2114 appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
2115}
2116
// NOTE(review): the signature line (doxygen line 2117, which declares
// `source`) is missing from this listing.
//
// Records an `InlineBlockRewrite` for inlining `source` into `dest` before
// `before`. Only the bookkeeping happens here; the caller moves the ops.
2118 Block *dest,
2119 Block::iterator before) {
2120 appendRewrite<InlineBlockRewrite>(dest, source, before);
2121}
2122
// NOTE(review): the signature line (doxygen line 2123) and the declaration of
// `diag` (doxygen line 2126) are missing from this listing.
//
// Reports a pattern match failure at `loc`. The whole body sits inside
// LLVM_DEBUG, so the reason is only built and logged when debug output for
// this DEBUG_TYPE is enabled. The same holds for the user's
// `config.notifyCallback`: it is not invoked in other builds or runs.
2124 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
2125 LLVM_DEBUG({
2127 reasonCallback(diag);
2128 logger.startLine() << "** Failure : " << diag.str() << "\n";
2129 if (config.notifyCallback)
2130 config.notifyCallback(diag);
2131 });
2132}
2133
2134//===----------------------------------------------------------------------===//
2135// ConversionPatternRewriter
2136//===----------------------------------------------------------------------===//
2137
// Constructs the conversion rewriter and installs its implementation object
// as the rewriter's listener, so that IR-change notifications reach the
// conversion driver's bookkeeping.
2138ConversionPatternRewriter::ConversionPatternRewriter(
2139 MLIRContext *ctx, const ConversionConfig &config,
2140 OperationConverter &opConverter)
// NOTE(review): doxygen line 2141 (the start of the member-initializer list
// that creates `impl`) is missing from this listing.
2142 *this, config, opConverter)) {
2143 setListener(impl.get());
2144}
2145
2146ConversionPatternRewriter::~ConversionPatternRewriter() = default;
2147
2148const ConversionConfig &ConversionPatternRewriter::getConfig() const {
2149 return impl->config;
2150}
2151
2152void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
2153 assert(op && newOp && "expected non-null op");
2154 replaceOp(op, newOp->getResults());
2155}
2156
// Replaces each result of `op` with the corresponding value in `newValues`.
// A null value is turned into an empty replacement list, i.e. that result is
// replaced with nothing. Each value is wrapped into a singleton vector and the
// call is forwarded to the 1:N implementation.
2157void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
2158 assert(op->getNumResults() == newValues.size() &&
2159 "incorrect # of replacement values");
2160
2161 // If the current insertion point is before the erased operation, we adjust
2162 // the insertion point to be after the operation.
2163 if (getInsertionPoint() == op->getIterator())
// NOTE(review): doxygen line 2164 (the statement guarded by the `if` above,
// presumably moving the insertion point after `op`) is missing here.
2165
2166 SmallVector<SmallVector<Value>> newVals =
2167 llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2168 return v ? SmallVector<Value>{v} : SmallVector<Value>();
2169 });
2170 impl->replaceOp(op, std::move(newVals));
2171}
2172
// Replaces each result of `op` with a (possibly empty) list of values, i.e. a
// 1:N replacement. `newValues` must have one entry per result of `op`.
2173void ConversionPatternRewriter::replaceOpWithMultiple(
2174 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2175 assert(op->getNumResults() == newValues.size() &&
2176 "incorrect # of replacement values");
2177
2178 // If the current insertion point is before the erased operation, we adjust
2179 // the insertion point to be after the operation.
2180 if (getInsertionPoint() == op->getIterator())
// NOTE(review): doxygen line 2181 (the statement guarded by the `if` above)
// is missing from this listing.
2182
2183 impl->replaceOp(op, std::move(newValues));
2184}
2185
// Erases `op` by "replacing" every result with an empty value list.
2186void ConversionPatternRewriter::eraseOp(Operation *op) {
2187 LLVM_DEBUG({
2188 impl->logger.startLine()
2189 << "** Erase : '" << op->getName() << "'(" << op << ")\n";
2190 });
2191
2192 // If the current insertion point is before the erased operation, we adjust
2193 // the insertion point to be after the operation.
2194 if (getInsertionPoint() == op->getIterator())
// NOTE(review): doxygen line 2195 (the statement guarded by the `if` above)
// is missing from this listing.
2196
2197 SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
2198 impl->replaceOp(op, std::move(nullRepls));
2199}
2200
2201void ConversionPatternRewriter::eraseBlock(Block *block) {
2202 impl->eraseBlock(block);
2203}
2204
2205Block *ConversionPatternRewriter::applySignatureConversion(
2206 Block *block, TypeConverter::SignatureConversion &conversion,
2207 const TypeConverter *converter) {
2208 assert(!impl->wasOpReplaced(block->getParentOp()) &&
2209 "attempting to apply a signature conversion to a block within a "
2210 "replaced/erased op");
2211 return impl->applySignatureConversion(block, converter, conversion);
2212}
2213
2214FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
2215 Region *region, const TypeConverter &converter,
2216 TypeConverter::SignatureConversion *entryConversion) {
2217 assert(!impl->wasOpReplaced(region->getParentOp()) &&
2218 "attempting to apply a signature conversion to a block within a "
2219 "replaced/erased op");
2220 return impl->convertRegionTypes(region, converter, entryConversion);
2221}
2222
2223void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
2224 impl->replaceValueUses(from, to, impl->currentTypeConverter);
2225}
2226
2227void ConversionPatternRewriter::replaceUsesWithIf(
2228 Value from, ValueRange to, function_ref<bool(OpOperand &)> functor,
2229 bool *allUsesReplaced) {
2230 assert(!allUsesReplaced &&
2231 "allUsesReplaced is not supported in a dialect conversion");
2232 impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
2233}
2234
2235Value ConversionPatternRewriter::getRemappedValue(Value key) {
2236 SmallVector<ValueVector> remappedValues;
2237 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
2238 remappedValues)))
2239 return nullptr;
2240 assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
2241 return remappedValues.front().front();
2242}
2243
2244LogicalResult
2245ConversionPatternRewriter::getRemappedValues(ValueRange keys,
2246 SmallVectorImpl<Value> &results) {
2247 if (keys.empty())
2248 return success();
2249 SmallVector<ValueVector> remapped;
2250 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
2251 remapped)))
2252 return failure();
2253 for (const auto &values : remapped) {
2254 assert(values.size() == 1 && "1:N conversion not supported");
2255 results.push_back(values.front());
2256 }
2257 return success();
2258}
2259
2260LogicalResult ConversionPatternRewriter::legalize(Region *r) {
2261 // Fast path: If the region is empty, there is nothing to legalize.
2262 if (r->empty())
2263 return success();
2264
2265 // Gather a list of all operations to legalize. This is done before
2266 // converting the entry block signature because unrealized_conversion_cast
2267 // ops should not be included.
2268 SmallVector<Operation *> ops;
2269 for (Block &b : *r)
2270 for (Operation &op : b)
2271 ops.push_back(&op);
2272
2273 // If the current pattern runs with a type converter, convert the entry block
2274 // signature.
2275 if (const TypeConverter *converter = impl->currentTypeConverter) {
2276 std::optional<TypeConverter::SignatureConversion> conversion =
2277 converter->convertBlockSignature(&r->front());
2278 if (!conversion)
2279 return failure();
2280 applySignatureConversion(&r->front(), *conversion, converter);
2281 }
2282
2283 // Legalize all operations in the region.
2284 for (Operation *op : ops)
2285 if (failed(legalize(op)))
2286 return failure();
2287
2288 return success();
2289}
2290
2291void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
2292 Block::iterator before,
2293 ValueRange argValues) {
2294#ifndef NDEBUG
2295 assert(argValues.size() == source->getNumArguments() &&
2296 "incorrect # of argument replacement values");
2297 assert(!impl->wasOpReplaced(source->getParentOp()) &&
2298 "attempting to inline a block from a replaced/erased op");
2299 assert(!impl->wasOpReplaced(dest->getParentOp()) &&
2300 "attempting to inline a block into a replaced/erased op");
2301 auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
2302 // The source block will be deleted, so it should not have any users (i.e.,
2303 // there should be no predecessors).
2304 assert(llvm::all_of(source->getUsers(), opIgnored) &&
2305 "expected 'source' to have no predecessors");
2306#endif // NDEBUG
2307
2308 // If a listener is attached to the dialect conversion, ops cannot be moved
2309 // to the destination block in bulk ("fast path"). This is because at the time
2310 // the notifications are sent, it is unknown which ops were moved. Instead,
2311 // ops should be moved one-by-one ("slow path"), so that a separate
2312 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
2313 // a bit more efficient, so we try to do that when possible.
2314 bool fastPath = !getConfig().listener;
2315
2316 if (fastPath && impl->config.allowPatternRollback)
2317 impl->inlineBlockBefore(source, dest, before);
2318
2319 // Replace all uses of block arguments.
2320 for (auto it : llvm::zip(source->getArguments(), argValues))
2321 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2322
2323 if (fastPath) {
2324 // Move all ops at once.
2325 dest->getOperations().splice(before, source->getOperations());
2326 } else {
2327 // Move op by op.
2328 while (!source->empty())
2329 moveOpBefore(&source->front(), dest, before);
2330 }
2331
2332 // If the current insertion point is within the source block, adjust the
2333 // insertion point to the destination block.
2334 if (getInsertionBlock() == source)
2335 setInsertionPoint(dest, getInsertionPoint());
2336
2337 // Erase the source block.
2338 eraseBlock(source);
2339}
2340
// Starts an in-place modification of `op`. In rollback mode a
// `ModifyOperationRewrite` is recorded so the change can be undone, and (in
// assert builds) `op` is tracked as having a pending update.
2341void ConversionPatternRewriter::startOpModification(Operation *op) {
2342 if (!impl->config.allowPatternRollback) {
2343 // Pattern rollback is not allowed: no extra bookkeeping is needed.
// NOTE(review): doxygen line 2344 is missing from this listing; presumably it
// forwards to the base-class implementation — confirm upstream.
2345 return;
2346 }
2347 assert(!impl->wasOpReplaced(op) &&
2348 "attempting to modify a replaced/erased op");
2349#ifndef NDEBUG
2350 impl->pendingRootUpdates.insert(op);
2351#endif
2352 impl->appendRewrite<ModifyOperationRewrite>(op);
2353}
2354
// Finishes an in-place modification of `op`. The op is always added to
// `patternModifiedOps`. In "no rollback" mode the user listener (if any) is
// notified immediately. In rollback mode only debug-build checks remain.
2355void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
2356 impl->patternModifiedOps.insert(op);
2357 if (!impl->config.allowPatternRollback) {
// NOTE(review): doxygen line 2358 is missing from this listing.
2359 if (getConfig().listener)
2360 getConfig().listener->notifyOperationModified(op);
2361 return;
2362 }
2363
2364 // There is nothing to do here, we only need to track the operation at the
2365 // start of the update.
2366#ifndef NDEBUG
2367 assert(!impl->wasOpReplaced(op) &&
2368 "attempting to modify a replaced/erased op");
2369 assert(impl->pendingRootUpdates.erase(op) &&
2370 "operation did not have a pending in-place update");
2371#endif
2372}
2373
// Cancels the most recent in-place modification of `op`. In rollback mode it
// finds the latest `ModifyOperationRewrite` for `op`, rolls it back, and
// removes it from the rewrite log.
2374void ConversionPatternRewriter::cancelOpModification(Operation *op) {
2375 if (!impl->config.allowPatternRollback) {
// NOTE(review): doxygen line 2376 is missing from this listing.
2377 return;
2378 }
2379#ifndef NDEBUG
2380 assert(impl->pendingRootUpdates.erase(op) &&
2381 "operation did not have a pending in-place update");
2382#endif
2383 // Erase the last update for this operation.
2384 auto it = llvm::find_if(
2385 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
2386 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2387 return modifyRewrite && modifyRewrite->getOperation() == op;
2388 });
2389 assert(it != impl->rewrites.rend() && "no root update started on op");
2390 (*it)->rollback();
// Convert the reverse iterator into a forward index:
// (rend() - 1) corresponds to forward index 0, so the distance from it to
// `it` is the forward position of the found rewrite.
2391 int updateIdx = std::prev(impl->rewrites.rend()) - it;
2392 impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2393}
2394
2395detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
2396 return *impl;
2397}
2398
2399//===----------------------------------------------------------------------===//
2400// ConversionPattern
2401//===----------------------------------------------------------------------===//
2402
2403FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
2404 ArrayRef<ValueRange> operands) const {
2405 SmallVector<Value> oneToOneOperands;
2406 oneToOneOperands.reserve(operands.size());
2407 for (ValueRange operand : operands) {
2408 if (operand.size() != 1)
2409 return failure();
2410
2411 oneToOneOperands.push_back(operand.front());
2412 }
2413 return std::move(oneToOneOperands);
2414}
2415
2416LogicalResult
2417ConversionPattern::matchAndRewrite(Operation *op,
2418 PatternRewriter &rewriter) const {
2419 auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
2420 auto &rewriterImpl = dialectRewriter.getImpl();
2421
2422 // Track the current conversion pattern type converter in the rewriter.
2423 llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
2424 getTypeConverter());
2425
2426 // Remap the operands of the operation.
2427 SmallVector<ValueVector> remapped;
2428 if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
2429 op->getOperands(), remapped))) {
2430 return failure();
2431 }
2432 SmallVector<ValueRange> remappedAsRange =
2433 llvm::to_vector_of<ValueRange>(remapped);
2434 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2435}
2436
2437//===----------------------------------------------------------------------===//
2438// OperationLegalizer
2439//===----------------------------------------------------------------------===//
2440
2441namespace {
2442/// A set of rewrite patterns that can be used to legalize a given operation.
2443using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2444
2445/// This class defines a recursive operation legalizer.
///
/// It holds the conversion rewriter, the target's legality information and a
/// PatternApplicator over the frozen pattern set. Legalization is attempted by
/// folding or by applying patterns; the results of a successful pattern are
/// then legalized in turn.
2446class OperationLegalizer {
2447public:
2448 using LegalizationAction = ConversionTarget::LegalizationAction;
2449
2450 OperationLegalizer(ConversionPatternRewriter &rewriter,
2451 const ConversionTarget &targetInfo,
2452 const FrozenRewritePatternSet &patterns);
2453
2454 /// Returns true if the given operation is known to be illegal on the target.
2455 bool isIllegal(Operation *op) const;
2456
2457 /// Attempt to legalize the given operation. Returns success if the operation
2458 /// was legalized, failure otherwise.
2459 LogicalResult legalize(Operation *op);
2460
2461 /// Returns the conversion target in use by the legalizer.
2462 const ConversionTarget &getTarget() { return target; }
2463
2464private:
2465 /// Attempt to legalize the given operation by folding it.
2466 LogicalResult legalizeWithFold(Operation *op);
2467
2468 /// Attempt to legalize the given operation by applying a pattern. Returns
2469 /// success if the operation was legalized, failure otherwise.
2470 LogicalResult legalizeWithPattern(Operation *op);
2471
2472 /// Return true if the given pattern may be applied to the given operation,
2473 /// false otherwise.
2474 bool canApplyPattern(Operation *op, const Pattern &pattern);
2475
2476 /// Legalize the resultant IR after successfully applying the given pattern.
2477 LogicalResult
2478 legalizePatternResult(Operation *op, const Pattern &pattern,
2479 const RewriterState &curState,
2480 const SetVector<Operation *> &newOps,
2481 const SetVector<Operation *> &modifiedOps);
2482
/// Legalize the ops created by a successfully applied pattern.
2483 LogicalResult
2484 legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
/// Legalize the ops modified in place by a successfully applied pattern.
2485 LogicalResult
2486 legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
2487
2488 //===--------------------------------------------------------------------===//
2489 // Cost Model
2490 //===--------------------------------------------------------------------===//
2491
2492 /// Build an optimistic legalization graph given the provided patterns. This
2493 /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
2494 /// patterns for operations that are not directly legal, but may be
2495 /// transitively legal for the current target given the provided patterns.
2496 void buildLegalizationGraph(
2497 LegalizationPatterns &anyOpLegalizerPatterns,
// NOTE(review): doxygen line 2498 (the remaining parameter and closing `);`)
// is missing from this listing.
2499
2500 /// Compute the benefit of each node within the computed legalization graph.
2501 /// This orders the patterns within 'legalizerPatterns' based upon two
2502 /// criteria:
2503 /// 1) Prefer patterns that have the lowest legalization depth, i.e.
2504 /// represent the more direct mapping to the target.
2505 /// 2) When comparing patterns with the same legalization depth, prefer the
2506 /// pattern with the highest PatternBenefit. This allows for users to
2507 /// prefer specific legalizations over others.
2508 void computeLegalizationGraphBenefit(
2509 LegalizationPatterns &anyOpLegalizerPatterns,
// NOTE(review): doxygen line 2510 (the remaining parameter and closing `);`)
// is missing from this listing.
2511
2512 /// Compute the legalization depth when legalizing an operation of the given
2513 /// type.
2514 unsigned computeOpLegalizationDepth(
2515 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
// NOTE(review): doxygen line 2516 (the remaining parameter and closing `);`)
// is missing from this listing.
2517
2518 /// Apply the conversion cost model to the given set of patterns, and return
2519 /// the smallest legalization depth of any of the patterns. See
2520 /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
2521 unsigned applyCostModelToPatterns(
2522 LegalizationPatterns &patterns,
2523 DenseMap<OperationName, unsigned> &minOpPatternDepth,
// NOTE(review): doxygen line 2524 (the remaining parameter and closing `);`)
// is missing from this listing.
2525
2526 /// The current set of patterns that have been applied.
2527 SmallPtrSet<const Pattern *, 8> appliedPatterns;
2528
2529 /// The rewriter to use when converting operations.
2530 ConversionPatternRewriter &rewriter;
2531
2532 /// The legalization information provided by the target.
2533 const ConversionTarget &target;
2534
2535 /// The pattern applicator to use for conversions.
2536 PatternApplicator applicator;
2537};
2538} // namespace
2539
2540OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2541 const ConversionTarget &targetInfo,
2542 const FrozenRewritePatternSet &patterns)
2543 : rewriter(rewriter), target(targetInfo), applicator(patterns) {
2544 // The set of patterns that can be applied to illegal operations to transform
2545 // them into legal ones.
2547 LegalizationPatterns anyOpLegalizerPatterns;
2548
2549 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2550 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2551}
2552
2553bool OperationLegalizer::isIllegal(Operation *op) const {
2554 return target.isIllegal(op);
2555}
2556
2557LogicalResult OperationLegalizer::legalize(Operation *op) {
2558#ifndef NDEBUG
2559 const char *logLineComment =
2560 "//===-------------------------------------------===//\n";
2561
2562 auto &logger = rewriter.getImpl().logger;
2563#endif
2564
2565 // Check to see if the operation is ignored and doesn't need to be converted.
2566 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2567
2568 LLVM_DEBUG({
2569 logger.getOStream() << "\n";
2570 logger.startLine() << logLineComment;
2571 logger.startLine() << "Legalizing operation : ";
2572 // Do not print the operation name if the operation is ignored. Ignored ops
2573 // may have been erased and should not be accessed. The pointer can be
2574 // printed safely.
2575 if (!isIgnored)
2576 logger.getOStream() << "'" << op->getName() << "' ";
2577 logger.getOStream() << "(" << op << ") {\n";
2578 logger.indent();
2579
2580 // If the operation has no regions, just print it here.
2581 if (!isIgnored && op->getNumRegions() == 0) {
2582 logger.startLine() << OpWithFlags(op,
2583 OpPrintingFlags().printGenericOpForm())
2584 << "\n";
2585 }
2586 });
2587
2588 if (isIgnored) {
2589 LLVM_DEBUG({
2590 logSuccess(logger, "operation marked 'ignored' during conversion");
2591 logger.startLine() << logLineComment;
2592 });
2593 return success();
2594 }
2595
2596 // Check if this operation is legal on the target.
2597 if (auto legalityInfo = target.isLegal(op)) {
2598 LLVM_DEBUG({
2599 logSuccess(
2600 logger, "operation marked legal by the target{0}",
2601 legalityInfo->isRecursivelyLegal
2602 ? "; NOTE: operation is recursively legal; skipping internals"
2603 : "");
2604 logger.startLine() << logLineComment;
2605 });
2606
2607 // If this operation is recursively legal, mark its children as ignored so
2608 // that we don't consider them for legalization.
2609 if (legalityInfo->isRecursivelyLegal) {
2610 op->walk([&](Operation *nested) {
2611 if (op != nested)
2612 rewriter.getImpl().ignoredOps.insert(nested);
2613 });
2614 }
2615
2616 return success();
2617 }
2618
2619 // If the operation is not legal, try to fold it in-place if the folding mode
2620 // is 'BeforePatterns'. 'Never' will skip this.
2621 const ConversionConfig &config = rewriter.getConfig();
2622 if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2623 if (succeeded(legalizeWithFold(op))) {
2624 LLVM_DEBUG({
2625 logSuccess(logger, "operation was folded");
2626 logger.startLine() << logLineComment;
2627 });
2628 return success();
2629 }
2630 }
2631
2632 // Otherwise, we need to apply a legalization pattern to this operation.
2633 if (succeeded(legalizeWithPattern(op))) {
2634 LLVM_DEBUG({
2635 logSuccess(logger, "");
2636 logger.startLine() << logLineComment;
2637 });
2638 return success();
2639 }
2640
2641 // If the operation can't be legalized via patterns, try to fold it in-place
2642 // if the folding mode is 'AfterPatterns'.
2643 if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2644 if (succeeded(legalizeWithFold(op))) {
2645 LLVM_DEBUG({
2646 logSuccess(logger, "operation was folded");
2647 logger.startLine() << logLineComment;
2648 });
2649 return success();
2650 }
2651 }
2652
2653 LLVM_DEBUG({
2654 logFailure(logger, "no matched legalization pattern");
2655 logger.startLine() << logLineComment;
2656 });
2657 return failure();
2658}
2659
/// Moves the contents out of `obj` and returns them. `obj` is reassigned a
/// value-initialized `T()`, so it is left in a valid, empty state that can be
/// reused.
template <typename T>
static T moveAndReset(T &obj) {
  return std::exchange(obj, T());
}
2668
2669LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2670 auto &rewriterImpl = rewriter.getImpl();
2671 LLVM_DEBUG({
2672 rewriterImpl.logger.startLine() << "* Fold {\n";
2673 rewriterImpl.logger.indent();
2674 });
2675
2676 // Clear pattern state, so that the next pattern application starts with a
2677 // clean slate. (The op/block sets are populated by listener notifications.)
2678 auto cleanup = llvm::make_scope_exit([&]() {
2679 rewriterImpl.patternNewOps.clear();
2680 rewriterImpl.patternModifiedOps.clear();
2681 });
2682
2683 // Upon failure, undo all changes made by the folder.
2684 RewriterState curState = rewriterImpl.getCurrentState();
2685
2686 // Try to fold the operation.
2687 StringRef opName = op->getName().getStringRef();
2688 SmallVector<Value, 2> replacementValues;
2689 SmallVector<Operation *, 2> newOps;
2690 rewriter.setInsertionPoint(op);
2691 rewriter.startOpModification(op);
2692 if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
2693 LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2694 rewriter.cancelOpModification(op);
2695 return failure();
2696 }
2697 rewriter.finalizeOpModification(op);
2698
2699 // An empty list of replacement values indicates that the fold was in-place.
2700 // As the operation changed, a new legalization needs to be attempted.
2701 if (replacementValues.empty())
2702 return legalize(op);
2703
2704 // Insert a replacement for 'op' with the folded replacement values.
2705 rewriter.replaceOp(op, replacementValues);
2706
2707 // Recursively legalize any new constant operations.
2708 for (Operation *newOp : newOps) {
2709 if (failed(legalize(newOp))) {
2710 LLVM_DEBUG(logFailure(rewriterImpl.logger,
2711 "failed to legalize generated constant '{0}'",
2712 newOp->getName()));
2713 if (!rewriter.getConfig().allowPatternRollback) {
2714 // Rolling back a folder is like rolling back a pattern.
2715 llvm::report_fatal_error(
2716 "op '" + opName +
2717 "' folder rollback of IR modifications requested");
2718 }
2719 rewriterImpl.resetState(
2720 curState, std::string(op->getName().getStringRef()) + " folder");
2721 return failure();
2722 }
2723 }
2724
2725 LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2726 return success();
2727}
2728
2729/// Report a fatal error indicating that newly produced or modified IR could
2730/// not be legalized.
2731static void
2733 const SetVector<Operation *> &newOps,
2734 const SetVector<Operation *> &modifiedOps) {
2735 auto newOpNames = llvm::map_range(
2736 newOps, [](Operation *op) { return op->getName().getStringRef(); });
2737 auto modifiedOpNames = llvm::map_range(
2738 modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
2739 llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
2740 "' produced IR that could not be legalized. " +
2741 "new ops: {" + llvm::join(newOpNames, ", ") + "}, " +
2742 "modified ops: {" +
2743 llvm::join(modifiedOpNames, ", ") + "}");
2744}
2745
2746LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2747 auto &rewriterImpl = rewriter.getImpl();
2748 const ConversionConfig &config = rewriter.getConfig();
2749
2750#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2751 Operation *checkOp;
2752 std::optional<OperationFingerPrint> topLevelFingerPrint;
2753 if (!rewriterImpl.config.allowPatternRollback) {
2754 // The op may be getting erased, so we have to check the parent op.
2755 // (In rare cases, a pattern may even erase the parent op, which will cause
2756 // a crash here. Expensive checks are "best effort".) Skip the check if the
2757 // op does not have a parent op.
2758 if ((checkOp = op->getParentOp())) {
2759 if (!op->getContext()->isMultithreadingEnabled()) {
2760 topLevelFingerPrint = OperationFingerPrint(checkOp);
2761 } else {
2762 // Another thread may be modifying a sibling operation. Therefore, the
2763 // fingerprinting mechanism of the parent op works only in
2764 // single-threaded mode.
2765 LLVM_DEBUG({
2766 rewriterImpl.logger.startLine()
2767 << "WARNING: Multi-threadeding is enabled. Some dialect "
2768 "conversion expensive checks are skipped in multithreading "
2769 "mode!\n";
2770 });
2771 }
2772 }
2773 }
2774#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2775
2776 // Functor that returns if the given pattern may be applied.
2777 auto canApply = [&](const Pattern &pattern) {
2778 bool canApply = canApplyPattern(op, pattern);
2779 if (canApply && config.listener)
2780 config.listener->notifyPatternBegin(pattern, op);
2781 return canApply;
2782 };
2783
2784 // Functor that cleans up the rewriter state after a pattern failed to match.
2785 RewriterState curState = rewriterImpl.getCurrentState();
2786 auto onFailure = [&](const Pattern &pattern) {
2787 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2788 if (!rewriterImpl.config.allowPatternRollback) {
2789 // Erase all unresolved materializations.
2790 for (auto op : rewriterImpl.patternMaterializations) {
2791 rewriterImpl.unresolvedMaterializations.erase(op);
2792 op.erase();
2793 }
2794 rewriterImpl.patternMaterializations.clear();
2795#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2796 // Expensive pattern check that can detect API violations.
2797 if (checkOp && topLevelFingerPrint) {
2798 OperationFingerPrint fingerPrintAfterPattern(checkOp);
2799 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2800 llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
2801 "' returned failure but IR did change");
2802 }
2803#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2804 }
2805 rewriterImpl.patternNewOps.clear();
2806 rewriterImpl.patternModifiedOps.clear();
2807 LLVM_DEBUG({
2808 logFailure(rewriterImpl.logger, "pattern failed to match");
2809 if (rewriterImpl.config.notifyCallback) {
2810 Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
2811 diag << "Failed to apply pattern \"" << pattern.getDebugName()
2812 << "\" on op:\n"
2813 << *op;
2814 rewriterImpl.config.notifyCallback(diag);
2815 }
2816 });
2817 if (config.listener)
2818 config.listener->notifyPatternEnd(pattern, failure());
2819 rewriterImpl.resetState(curState, pattern.getDebugName());
2820 appliedPatterns.erase(&pattern);
2821 };
2822
2823 // Functor that performs additional legalization when a pattern is
2824 // successfully applied.
2825 auto onSuccess = [&](const Pattern &pattern) {
2826 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2827 if (!rewriterImpl.config.allowPatternRollback) {
2828 // Eagerly erase unused materializations.
2829 for (auto op : rewriterImpl.patternMaterializations) {
2830 if (op->use_empty()) {
2831 rewriterImpl.unresolvedMaterializations.erase(op);
2832 op.erase();
2833 }
2834 }
2835 rewriterImpl.patternMaterializations.clear();
2836 }
2837 SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
2838 SetVector<Operation *> modifiedOps =
2839 moveAndReset(rewriterImpl.patternModifiedOps);
2840 auto result =
2841 legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
2842 appliedPatterns.erase(&pattern);
2843 if (failed(result)) {
2844 if (!rewriterImpl.config.allowPatternRollback)
2845 reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);
2846 rewriterImpl.resetState(curState, pattern.getDebugName());
2847 }
2848 if (config.listener)
2849 config.listener->notifyPatternEnd(pattern, result);
2850 return result;
2851 };
2852
2853 // Try to match and rewrite a pattern on this operation.
2854 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2855 onSuccess);
2856}
2857
2858bool OperationLegalizer::canApplyPattern(Operation *op,
2859 const Pattern &pattern) {
2860 LLVM_DEBUG({
2861 auto &os = rewriter.getImpl().logger;
2862 os.getOStream() << "\n";
2863 os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2864 llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2865 os.getOStream() << ")' {\n";
2866 os.indent();
2867 });
2868
2869 // Ensure that we don't cycle by not allowing the same pattern to be
2870 // applied twice in the same recursion stack if it is not known to be safe.
2871 if (!pattern.hasBoundedRewriteRecursion() &&
2872 !appliedPatterns.insert(&pattern).second) {
2873 LLVM_DEBUG(
2874 logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2875 return false;
2876 }
2877 return true;
2878}
2879
2880LogicalResult OperationLegalizer::legalizePatternResult(
2881 Operation *op, const Pattern &pattern, const RewriterState &curState,
2882 const SetVector<Operation *> &newOps,
2883 const SetVector<Operation *> &modifiedOps) {
2884 [[maybe_unused]] auto &impl = rewriter.getImpl();
2885 assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2886
2887#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2888 if (impl.config.allowPatternRollback) {
2889 // Check that the root was either replaced or updated in place.
2890 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2891 auto replacedRoot = [&] {
2892 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2893 };
2894 auto updatedRootInPlace = [&] {
2895 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2896 };
2897 if (!replacedRoot() && !updatedRootInPlace())
2898 llvm::report_fatal_error("expected pattern to replace the root operation "
2899 "or modify it in place");
2900 }
2901#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2902
2903 // Legalize each of the actions registered during application.
2904 if (failed(legalizePatternRootUpdates(modifiedOps)) ||
2905 failed(legalizePatternCreatedOperations(newOps))) {
2906 return failure();
2907 }
2908
2909 LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2910 return success();
2911}
2912
2913LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2914 const SetVector<Operation *> &newOps) {
2915 for (Operation *op : newOps) {
2916 if (failed(legalize(op))) {
2917 LLVM_DEBUG(logFailure(rewriter.getImpl().logger,
2918 "failed to legalize generated operation '{0}'({1})",
2919 op->getName(), op));
2920 return failure();
2921 }
2922 }
2923 return success();
2924}
2925
2926LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2927 const SetVector<Operation *> &modifiedOps) {
2928 for (Operation *op : modifiedOps) {
2929 if (failed(legalize(op))) {
2930 LLVM_DEBUG(
2931 logFailure(rewriter.getImpl().logger,
2932 "failed to legalize operation updated in-place '{0}'",
2933 op->getName()));
2934 return failure();
2935 }
2936 }
2937 return success();
2938}
2939
2940//===----------------------------------------------------------------------===//
2941// Cost Model
2942//===----------------------------------------------------------------------===//
2943
2944void OperationLegalizer::buildLegalizationGraph(
2945 LegalizationPatterns &anyOpLegalizerPatterns,
2947 // A mapping between an operation and a set of operations that can be used to
2948 // generate it.
2950 // A mapping between an operation and any currently invalid patterns it has.
2952 // A worklist of patterns to consider for legality.
2953 SetVector<const Pattern *> patternWorklist;
2954
2955 // Build the mapping from operations to the parent ops that may generate them.
2956 applicator.walkAllPatterns([&](const Pattern &pattern) {
2957 std::optional<OperationName> root = pattern.getRootKind();
2958
2959 // If the pattern has no specific root, we can't analyze the relationship
2960 // between the root op and generated operations. Given that, add all such
2961 // patterns to the legalization set.
2962 if (!root) {
2963 anyOpLegalizerPatterns.push_back(&pattern);
2964 return;
2965 }
2966
2967 // Skip operations that are always known to be legal.
2968 if (target.getOpAction(*root) == LegalizationAction::Legal)
2969 return;
2970
2971 // Add this pattern to the invalid set for the root op and record this root
2972 // as a parent for any generated operations.
2973 invalidPatterns[*root].insert(&pattern);
2974 for (auto op : pattern.getGeneratedOps())
2975 parentOps[op].insert(*root);
2976
2977 // Add this pattern to the worklist.
2978 patternWorklist.insert(&pattern);
2979 });
2980
2981 // If there are any patterns that don't have a specific root kind, we can't
2982 // make direct assumptions about what operations will never be legalized.
2983 // Note: Technically we could, but it would require an analysis that may
2984 // recurse into itself. It would be better to perform this kind of filtering
2985 // at a higher level than here anyways.
2986 if (!anyOpLegalizerPatterns.empty()) {
2987 for (const Pattern *pattern : patternWorklist)
2988 legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2989 return;
2990 }
2991
2992 while (!patternWorklist.empty()) {
2993 auto *pattern = patternWorklist.pop_back_val();
2994
2995 // Check to see if any of the generated operations are invalid.
2996 if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2997 std::optional<LegalizationAction> action = target.getOpAction(op);
2998 return !legalizerPatterns.count(op) &&
2999 (!action || action == LegalizationAction::Illegal);
3000 }))
3001 continue;
3002
3003 // Otherwise, if all of the generated operations are valid, this op is now
3004 // legal so add all of the child patterns to the worklist.
3005 legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
3006 invalidPatterns[*pattern->getRootKind()].erase(pattern);
3007
3008 // Add any invalid patterns of the parent operations to see if they have now
3009 // become legal.
3010 for (auto op : parentOps[*pattern->getRootKind()])
3011 patternWorklist.set_union(invalidPatterns[op]);
3012 }
3013}
3014
3015void OperationLegalizer::computeLegalizationGraphBenefit(
3016 LegalizationPatterns &anyOpLegalizerPatterns,
3018 // The smallest pattern depth, when legalizing an operation.
3019 DenseMap<OperationName, unsigned> minOpPatternDepth;
3020
3021 // For each operation that is transitively legal, compute a cost for it.
3022 for (auto &opIt : legalizerPatterns)
3023 if (!minOpPatternDepth.count(opIt.first))
3024 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
3025 legalizerPatterns);
3026
3027 // Apply the cost model to the patterns that can match any operation. Those
3028 // with a specific operation type are already resolved when computing the op
3029 // legalization depth.
3030 if (!anyOpLegalizerPatterns.empty())
3031 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
3032 legalizerPatterns);
3033
3034 // Apply a cost model to the pattern applicator. We order patterns first by
3035 // depth then benefit. `legalizerPatterns` contains per-op patterns by
3036 // decreasing benefit.
3037 applicator.applyCostModel([&](const Pattern &pattern) {
3038 ArrayRef<const Pattern *> orderedPatternList;
3039 if (std::optional<OperationName> rootName = pattern.getRootKind())
3040 orderedPatternList = legalizerPatterns[*rootName];
3041 else
3042 orderedPatternList = anyOpLegalizerPatterns;
3043
3044 // If the pattern is not found, then it was removed and cannot be matched.
3045 auto *it = llvm::find(orderedPatternList, &pattern);
3046 if (it == orderedPatternList.end())
3048
3049 // Patterns found earlier in the list have higher benefit.
3050 return PatternBenefit(std::distance(it, orderedPatternList.end()));
3051 });
3052}
3053
3054unsigned OperationLegalizer::computeOpLegalizationDepth(
3055 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
3057 // Check for existing depth.
3058 auto depthIt = minOpPatternDepth.find(op);
3059 if (depthIt != minOpPatternDepth.end())
3060 return depthIt->second;
3061
3062 // If a mapping for this operation does not exist, then this operation
3063 // is always legal. Return 0 as the depth for a directly legal operation.
3064 auto opPatternsIt = legalizerPatterns.find(op);
3065 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3066 return 0u;
3067
3068 // Record this initial depth in case we encounter this op again when
3069 // recursively computing the depth.
3070 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
3071
3072 // Apply the cost model to the operation patterns, and update the minimum
3073 // depth.
3074 unsigned minDepth = applyCostModelToPatterns(
3075 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3076 minOpPatternDepth[op] = minDepth;
3077 return minDepth;
3078}
3079
3080unsigned OperationLegalizer::applyCostModelToPatterns(
3081 LegalizationPatterns &patterns,
3082 DenseMap<OperationName, unsigned> &minOpPatternDepth,
3084 unsigned minDepth = std::numeric_limits<unsigned>::max();
3085
3086 // Compute the depth for each pattern within the set.
3087 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3088 patternsByDepth.reserve(patterns.size());
3089 for (const Pattern *pattern : patterns) {
3090 unsigned depth = 1;
3091 for (auto generatedOp : pattern->getGeneratedOps()) {
3092 unsigned generatedOpDepth = computeOpLegalizationDepth(
3093 generatedOp, minOpPatternDepth, legalizerPatterns);
3094 depth = std::max(depth, generatedOpDepth + 1);
3095 }
3096 patternsByDepth.emplace_back(pattern, depth);
3097
3098 // Update the minimum depth of the pattern list.
3099 minDepth = std::min(minDepth, depth);
3100 }
3101
3102 // If the operation only has one legalization pattern, there is no need to
3103 // sort them.
3104 if (patternsByDepth.size() == 1)
3105 return minDepth;
3106
3107 // Sort the patterns by those likely to be the most beneficial.
3108 llvm::stable_sort(patternsByDepth,
3109 [](const std::pair<const Pattern *, unsigned> &lhs,
3110 const std::pair<const Pattern *, unsigned> &rhs) {
3111 // First sort by the smaller pattern legalization
3112 // depth.
3113 if (lhs.second != rhs.second)
3114 return lhs.second < rhs.second;
3115
3116 // Then sort by the larger pattern benefit.
3117 auto lhsBenefit = lhs.first->getBenefit();
3118 auto rhsBenefit = rhs.first->getBenefit();
3119 return lhsBenefit > rhsBenefit;
3120 });
3121
3122 // Update the legalization pattern to use the new sorted list.
3123 patterns.clear();
3124 for (auto &patternIt : patternsByDepth)
3125 patterns.push_back(patternIt.first);
3126 return minDepth;
3127}
3128
3129//===----------------------------------------------------------------------===//
3130// Reconcile Unrealized Casts
3131//===----------------------------------------------------------------------===//
3132
3133/// Try to reconcile all given UnrealizedConversionCastOps and store the
3134/// left-over ops in `remainingCastOps` (if provided). See documentation in
3135/// DialectConversion.h for more details.
3136/// The `isCastOpOfInterestFn` is used to filter the cast ops to process: the
3137/// algorithm may visit an operand (or user) which is a cast op, but will not
3138/// try to reconcile it if not in the filtered set.
3139template <typename RangeT>
3141 RangeT castOps,
3142 function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3144 // A worklist of cast ops to process.
3145 SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
3146
3147 // Helper function that returns the unrealized_conversion_cast op that
3148 // defines all inputs of the given op (in the same order). Return "nullptr"
3149 // if there is no such op.
3150 auto getInputCast =
3151 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3152 if (castOp.getInputs().empty())
3153 return {};
3154 auto inputCastOp =
3155 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3156 if (!inputCastOp)
3157 return {};
3158 if (inputCastOp.getOutputs() != castOp.getInputs())
3159 return {};
3160 return inputCastOp;
3161 };
3162
3163 // Process ops in the worklist bottom-to-top.
3164 while (!worklist.empty()) {
3165 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3166
3167 // Traverse the chain of input cast ops to see if an op with the same
3168 // input types can be found.
3169 UnrealizedConversionCastOp nextCast = castOp;
3170 while (nextCast) {
3171 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3172 if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
3173 return v.getDefiningOp() == castOp;
3174 })) {
3175 // Ran into a cycle.
3176 break;
3177 }
3178
3179 // Found a cast where the input types match the output types of the
3180 // matched op. We can directly use those inputs.
3181 castOp.replaceAllUsesWith(nextCast.getInputs());
3182 break;
3183 }
3184 nextCast = getInputCast(nextCast);
3185 }
3186 }
3187
3188 // A set of all alive cast ops. I.e., ops whose results are (transitively)
3189 // used by an op that is not a cast op.
3190 DenseSet<Operation *> liveOps;
3191
3192 // Helper function that marks the given op and transitively reachable input
3193 // cast ops as alive.
3194 auto markOpLive = [&](Operation *rootOp) {
3195 SmallVector<Operation *> worklist;
3196 worklist.push_back(rootOp);
3197 while (!worklist.empty()) {
3198 Operation *op = worklist.pop_back_val();
3199 if (liveOps.insert(op).second) {
3200 // Successfully inserted: process reachable input cast ops.
3201 for (Value v : op->getOperands())
3202 if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
3203 if (isCastOpOfInterestFn(castOp))
3204 worklist.push_back(castOp);
3205 }
3206 }
3207 };
3208
3209 // Find all alive cast ops.
3210 for (UnrealizedConversionCastOp op : castOps) {
3211 // The op may have been marked live already as being an operand of another
3212 // live cast op.
3213 if (liveOps.contains(op.getOperation()))
3214 continue;
3215 // If any of the users is not a cast op, mark the current op (and its
3216 // input ops) as live.
3217 if (llvm::any_of(op->getUsers(), [&](Operation *user) {
3218 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3219 return !castOp || !isCastOpOfInterestFn(castOp);
3220 }))
3221 markOpLive(op);
3222 }
3223
3224 // Erase all dead cast ops.
3225 for (UnrealizedConversionCastOp op : castOps) {
3226 if (liveOps.contains(op)) {
3227 // Op is alive and was not erased. Add it to the remaining cast ops.
3228 if (remainingCastOps)
3229 remainingCastOps->push_back(op);
3230 continue;
3231 }
3232
3233 // Op is dead. Erase it.
3234 op->dropAllUses();
3235 op->erase();
3236 }
3237}
3238
3240 ArrayRef<UnrealizedConversionCastOp> castOps,
3241 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3242 // Set of all cast ops for faster lookups.
3243 DenseSet<UnrealizedConversionCastOp> castOpSet;
3244 for (UnrealizedConversionCastOp op : castOps)
3245 castOpSet.insert(op);
3246 reconcileUnrealizedCasts(castOpSet, remainingCastOps);
3247}
3248
3250 const DenseSet<UnrealizedConversionCastOp> &castOps,
3251 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3253 llvm::make_range(castOps.begin(), castOps.end()),
3254 [&](UnrealizedConversionCastOp castOp) {
3255 return castOps.contains(castOp);
3256 },
3257 remainingCastOps);
3258}
3259
3260namespace mlir {
3263 &castOps,
3266 castOps.keys(),
3267 [&](UnrealizedConversionCastOp castOp) {
3268 return castOps.contains(castOp);
3269 },
3270 remainingCastOps);
3271}
3272} // namespace mlir
3273
3274//===----------------------------------------------------------------------===//
3275// OperationConverter
3276//===----------------------------------------------------------------------===//
3277
3278namespace mlir {
3279// This class converts operations to a given conversion target via a set of
3280// rewrite patterns. The conversion behaves differently depending on the
3281// conversion mode.
3285 const ConversionConfig &config,
3286 OpConversionMode mode)
3287 : rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
3288 mode(mode) {}
3289
3290 /// Converts the given operations to the conversion target.
3291 LogicalResult convertOperations(ArrayRef<Operation *> ops);
3292
3293 /// Converts a single operation. If `isRecursiveLegalization` is "true", the
3294 /// conversion is a recursive legalization request, triggered from within a
3295 /// pattern. In that case, do not emit errors because there will be another
3296 /// attempt at legalizing the operation later (via the regular pre-order
3297 /// legalization mechanism).
3298 LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);
3299
3300private:
3301 /// The rewriter to use when converting operations.
3302 ConversionPatternRewriter rewriter;
3303
3304 /// The legalizer to use when converting operations.
3305 OperationLegalizer opLegalizer;
3306
3307 /// The conversion mode to use when legalizing operations.
3308 OpConversionMode mode;
3309};
3310} // namespace mlir
3311
3312LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
3313 return impl->opConverter.convert(op, /*isRecursiveLegalization=*/true);
3314}
3315
3317 bool isRecursiveLegalization) {
3318 const ConversionConfig &config = rewriter.getConfig();
3319
3320 // Legalize the given operation.
3321 if (failed(opLegalizer.legalize(op))) {
3322 // Handle the case of a failed conversion for each of the different modes.
3323 // Full conversions expect all operations to be converted.
3324 if (mode == OpConversionMode::Full) {
3325 if (!isRecursiveLegalization)
3326 op->emitError() << "failed to legalize operation '" << op->getName()
3327 << "'";
3328 return failure();
3329 }
3330 // Partial conversions allow conversions to fail iff the operation was not
3331 // explicitly marked as illegal. If the user provided an `unlegalizedOps`
3332 // set, non-legalizable ops are added to that set.
3333 if (mode == OpConversionMode::Partial) {
3334 if (opLegalizer.isIllegal(op)) {
3335 if (!isRecursiveLegalization)
3336 op->emitError() << "failed to legalize operation '" << op->getName()
3337 << "' that was explicitly marked illegal";
3338 return failure();
3339 }
3340 if (config.unlegalizedOps && !isRecursiveLegalization)
3341 config.unlegalizedOps->insert(op);
3342 }
3343 } else if (mode == OpConversionMode::Analysis) {
3344 // Analysis conversions don't fail if any operations fail to legalize,
3345 // they are only interested in the operations that were successfully
3346 // legalized.
3347 if (config.legalizableOps && !isRecursiveLegalization)
3348 config.legalizableOps->insert(op);
3349 }
3350 return success();
3351}
3352
3353static LogicalResult
3355 UnrealizedConversionCastOp op,
3356 const UnresolvedMaterializationInfo &info) {
3357 assert(!op.use_empty() &&
3358 "expected that dead materializations have already been DCE'd");
3359 Operation::operand_range inputOperands = op.getOperands();
3360
3361 // Try to materialize the conversion.
3362 if (const TypeConverter *converter = info.getConverter()) {
3363 rewriter.setInsertionPoint(op);
3364 SmallVector<Value> newMaterialization;
3365 switch (info.getMaterializationKind()) {
3366 case MaterializationKind::Target:
3367 newMaterialization = converter->materializeTargetConversion(
3368 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3369 info.getOriginalType());
3370 break;
3371 case MaterializationKind::Source:
3372 assert(op->getNumResults() == 1 && "expected single result");
3373 Value sourceMat = converter->materializeSourceConversion(
3374 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3375 if (sourceMat)
3376 newMaterialization.push_back(sourceMat);
3377 break;
3378 }
3379 if (!newMaterialization.empty()) {
3380#ifndef NDEBUG
3381 ValueRange newMaterializationRange(newMaterialization);
3382 assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
3383 "materialization callback produced value of incorrect type");
3384#endif // NDEBUG
3385 rewriter.replaceOp(op, newMaterialization);
3386 return success();
3387 }
3388 }
3389
3390 InFlightDiagnostic diag = op->emitError()
3391 << "failed to legalize unresolved materialization "
3392 "from ("
3393 << inputOperands.getTypes() << ") to ("
3394 << op.getResultTypes()
3395 << ") that remained live after conversion";
3396 diag.attachNote(op->getUsers().begin()->getLoc())
3397 << "see existing live user here: " << *op->getUsers().begin();
3398 return failure();
3399}
3400
3402 const ConversionTarget &target = opLegalizer.getTarget();
3403
3404 // Compute the set of operations and blocks to convert.
3405 SmallVector<Operation *> toConvert;
3406 for (auto *op : ops) {
3408 [&](Operation *op) {
3409 toConvert.push_back(op);
3410 // Don't check this operation's children for conversion if the
3411 // operation is recursively legal.
3412 auto legalityInfo = target.isLegal(op);
3413 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3414 return WalkResult::skip();
3415 return WalkResult::advance();
3416 });
3417 }
3418
3419 // Convert each operation and discard rewrites on failure.
3420 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
3421
3422 for (auto *op : toConvert) {
3423 if (failed(convert(op))) {
3424 // Dialect conversion failed.
3425 if (rewriterImpl.config.allowPatternRollback) {
3426 // Rollback is allowed: restore the original IR.
3427 rewriterImpl.undoRewrites();
3428 } else {
3429 // Rollback is not allowed: apply all modifications that have been
3430 // performed so far.
3431 rewriterImpl.applyRewrites();
3432 }
3433 return failure();
3434 }
3435 }
3436
3437 // After a successful conversion, apply rewrites.
3438 rewriterImpl.applyRewrites();
3439
3440 // Reconcile all UnrealizedConversionCastOps that were inserted by the
3441 // dialect conversion frameworks. (Not the ones that were inserted by
3442 // patterns.)
3444 &materializations = rewriterImpl.unresolvedMaterializations;
3446 reconcileUnrealizedCasts(materializations, &remainingCastOps);
3447
3448 // Drop markers.
3449 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3450 castOp->removeAttr(kPureTypeConversionMarker);
3451
3452 // Try to legalize all unresolved materializations.
3453 if (rewriter.getConfig().buildMaterializations) {
3454 // Use a new rewriter, so the modifications are not tracked for rollback
3455 // purposes etc.
3456 IRRewriter irRewriter(rewriterImpl.rewriter.getContext(),
3457 rewriter.getConfig().listener);
3458 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3459 auto it = materializations.find(castOp);
3460 assert(it != materializations.end() && "inconsistent state");
3461 if (failed(legalizeUnresolvedMaterialization(irRewriter, castOp,
3462 it->second)))
3463 return failure();
3464 }
3465 }
3466
3467 return success();
3468}
3469
3470//===----------------------------------------------------------------------===//
3471// Type Conversion
3472//===----------------------------------------------------------------------===//
3473
3474void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
3475 ArrayRef<Type> types) {
3476 assert(!types.empty() && "expected valid types");
3477 remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
3478 addInputs(types);
3479}
3480
3481void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
3482 assert(!types.empty() &&
3483 "1->0 type remappings don't need to be added explicitly");
3484 argTypes.append(types.begin(), types.end());
3485}
3486
3487void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
3488 unsigned newInputNo,
3489 unsigned newInputCount) {
3490 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3491 assert(newInputCount != 0 && "expected valid input count");
3492 remappedInputs[origInputNo] =
3493 InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
3494}
3495
3496void TypeConverter::SignatureConversion::remapInput(
3497 unsigned origInputNo, ArrayRef<Value> replacements) {
3498 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3499 remappedInputs[origInputNo] = InputMapping{
3500 origInputNo, /*size=*/0,
3501 SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3502}
3503
3504/// Internal implementation of the type conversion.
3505/// This is used with either a Type or a Value as the first argument.
3506/// - we can cache the context-free conversions until the last registered
3507/// context-aware conversion.
3508/// - we can't cache the result of type conversion happening after context-aware
3509/// conversions, because the type converter may return different results for the
3510/// same input type.
3511LogicalResult
3512TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3513 SmallVectorImpl<Type> &results) const {
3514 assert(typeOrValue && "expected non-null type");
3515 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3516 : cast<Type>(typeOrValue);
3517 {
3518 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3519 std::defer_lock);
3521 cacheReadLock.lock();
3522 auto existingIt = cachedDirectConversions.find(t);
3523 if (existingIt != cachedDirectConversions.end()) {
3524 if (existingIt->second)
3525 results.push_back(existingIt->second);
3526 return success(existingIt->second != nullptr);
3527 }
3528 auto multiIt = cachedMultiConversions.find(t);
3529 if (multiIt != cachedMultiConversions.end()) {
3530 results.append(multiIt->second.begin(), multiIt->second.end());
3531 return success();
3532 }
3533 }
3534 // Walk the added converters in reverse order to apply the most recently
3535 // registered first.
3536 size_t currentCount = results.size();
3537
3538 // We can cache the context-free conversions until the last registered
3539 // context-aware conversion. But only if we're processing a Value right now.
3540 auto isCacheable = [&](int index) {
3541 int numberOfConversionsUntilContextAware =
3542 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3543 return index < numberOfConversionsUntilContextAware;
3544 };
3545
3546 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3547 std::defer_lock);
3548
3549 for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3550 const ConversionCallbackFn &converter = indexedConverter.value();
3551 std::optional<LogicalResult> result = converter(typeOrValue, results);
3552 if (!result) {
3553 assert(results.size() == currentCount &&
3554 "failed type conversion should not change results");
3555 continue;
3556 }
3557 if (!isCacheable(indexedConverter.index()))
3558 return success();
3560 cacheWriteLock.lock();
3561 if (!succeeded(*result)) {
3562 assert(results.size() == currentCount &&
3563 "failed type conversion should not change results");
3564 cachedDirectConversions.try_emplace(t, nullptr);
3565 return failure();
3566 }
3567 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3568 if (newTypes.size() == 1)
3569 cachedDirectConversions.try_emplace(t, newTypes.front());
3570 else
3571 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3572 return success();
3573 }
3574 return failure();
3575}
3576
3577LogicalResult TypeConverter::convertType(Type t,
3578 SmallVectorImpl<Type> &results) const {
3579 return convertTypeImpl(t, results);
3580}
3581
3582LogicalResult TypeConverter::convertType(Value v,
3583 SmallVectorImpl<Type> &results) const {
3584 return convertTypeImpl(v, results);
3585}
3586
3587Type TypeConverter::convertType(Type t) const {
3588 // Use the multi-type result version to convert the type.
3589 SmallVector<Type, 1> results;
3590 if (failed(convertType(t, results)))
3591 return nullptr;
3592
3593 // Check to ensure that only one type was produced.
3594 return results.size() == 1 ? results.front() : nullptr;
3595}
3596
3597Type TypeConverter::convertType(Value v) const {
3598 // Use the multi-type result version to convert the type.
3599 SmallVector<Type, 1> results;
3600 if (failed(convertType(v, results)))
3601 return nullptr;
3602
3603 // Check to ensure that only one type was produced.
3604 return results.size() == 1 ? results.front() : nullptr;
3605}
3606
3607LogicalResult
3608TypeConverter::convertTypes(TypeRange types,
3609 SmallVectorImpl<Type> &results) const {
3610 for (Type type : types)
3611 if (failed(convertType(type, results)))
3612 return failure();
3613 return success();
3614}
3615
3616LogicalResult
3617TypeConverter::convertTypes(ValueRange values,
3618 SmallVectorImpl<Type> &results) const {
3619 for (Value value : values)
3620 if (failed(convertType(value, results)))
3621 return failure();
3622 return success();
3623}
3624
3625bool TypeConverter::isLegal(Type type) const {
3626 return convertType(type) == type;
3627}
3628
3629bool TypeConverter::isLegal(Value value) const {
3630 return convertType(value) == value.getType();
3631}
3632
3633bool TypeConverter::isLegal(Operation *op) const {
3634 return isLegal(op->getOperands()) && isLegal(op->getResults());
3635}
3636
3637bool TypeConverter::isLegal(Region *region) const {
3638 return llvm::all_of(
3639 *region, [this](Block &block) { return isLegal(block.getArguments()); });
3640}
3641
3642bool TypeConverter::isSignatureLegal(FunctionType ty) const {
3643 if (!isLegal(ty.getInputs()))
3644 return false;
3645 if (!isLegal(ty.getResults()))
3646 return false;
3647 return true;
3648}
3649
3650LogicalResult
3651TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
3652 SignatureConversion &result) const {
3653 // Try to convert the given input type.
3654 SmallVector<Type, 1> convertedTypes;
3655 if (failed(convertType(type, convertedTypes)))
3656 return failure();
3657
3658 // If this argument is being dropped, there is nothing left to do.
3659 if (convertedTypes.empty())
3660 return success();
3661
3662 // Otherwise, add the new inputs.
3663 result.addInputs(inputNo, convertedTypes);
3664 return success();
3665}
3666LogicalResult
3667TypeConverter::convertSignatureArgs(TypeRange types,
3668 SignatureConversion &result,
3669 unsigned origInputOffset) const {
3670 for (unsigned i = 0, e = types.size(); i != e; ++i)
3671 if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3672 return failure();
3673 return success();
3674}
3675LogicalResult
3676TypeConverter::convertSignatureArg(unsigned inputNo, Value value,
3677 SignatureConversion &result) const {
3678 // Try to convert the given input type.
3679 SmallVector<Type, 1> convertedTypes;
3680 if (failed(convertType(value, convertedTypes)))
3681 return failure();
3682
3683 // If this argument is being dropped, there is nothing left to do.
3684 if (convertedTypes.empty())
3685 return success();
3686
3687 // Otherwise, add the new inputs.
3688 result.addInputs(inputNo, convertedTypes);
3689 return success();
3690}
3691LogicalResult
3692TypeConverter::convertSignatureArgs(ValueRange values,
3693 SignatureConversion &result,
3694 unsigned origInputOffset) const {
3695 for (unsigned i = 0, e = values.size(); i != e; ++i)
3696 if (failed(convertSignatureArg(origInputOffset + i, values[i], result)))
3697 return failure();
3698 return success();
3699}
3700
3701Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
3702 Location loc, Type resultType,
3703 ValueRange inputs) const {
3704 for (const SourceMaterializationCallbackFn &fn :
3705 llvm::reverse(sourceMaterializations))
3706 if (Value result = fn(builder, resultType, inputs, loc))
3707 return result;
3708 return nullptr;
3709}
3710
3711Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
3712 Location loc, Type resultType,
3713 ValueRange inputs,
3714 Type originalType) const {
3715 SmallVector<Value> result = materializeTargetConversion(
3716 builder, loc, TypeRange(resultType), inputs, originalType);
3717 if (result.empty())
3718 return nullptr;
3719 assert(result.size() == 1 && "expected single result");
3720 return result.front();
3721}
3722
3723SmallVector<Value> TypeConverter::materializeTargetConversion(
3724 OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
3725 Type originalType) const {
3726 for (const TargetMaterializationCallbackFn &fn :
3727 llvm::reverse(targetMaterializations)) {
3728 SmallVector<Value> result =
3729 fn(builder, resultTypes, inputs, loc, originalType);
3730 if (result.empty())
3731 continue;
3732 assert(TypeRange(ValueRange(result)) == resultTypes &&
3733 "callback produced incorrect number of values or values with "
3734 "incorrect types");
3735 return result;
3736 }
3737 return {};
3738}
3739
3740std::optional<TypeConverter::SignatureConversion>
3741TypeConverter::convertBlockSignature(Block *block) const {
3742 SignatureConversion conversion(block->getNumArguments());
3743 if (failed(convertSignatureArgs(block->getArguments(), conversion)))
3744 return std::nullopt;
3745 return conversion;
3746}
3747
3748//===----------------------------------------------------------------------===//
3749// Type attribute conversion
3750//===----------------------------------------------------------------------===//
3751TypeConverter::AttributeConversionResult
3752TypeConverter::AttributeConversionResult::result(Attribute attr) {
3753 return AttributeConversionResult(attr, resultTag);
3754}
3755
3756TypeConverter::AttributeConversionResult
3757TypeConverter::AttributeConversionResult::na() {
3758 return AttributeConversionResult(nullptr, naTag);
3759}
3760
3761TypeConverter::AttributeConversionResult
3762TypeConverter::AttributeConversionResult::abort() {
3763 return AttributeConversionResult(nullptr, abortTag);
3764}
3765
3766bool TypeConverter::AttributeConversionResult::hasResult() const {
3767 return impl.getInt() == resultTag;
3768}
3769
3770bool TypeConverter::AttributeConversionResult::isNa() const {
3771 return impl.getInt() == naTag;
3772}
3773
3774bool TypeConverter::AttributeConversionResult::isAbort() const {
3775 return impl.getInt() == abortTag;
3776}
3777
3778Attribute TypeConverter::AttributeConversionResult::getResult() const {
3779 assert(hasResult() && "Cannot get result from N/A or abort");
3780 return impl.getPointer();
3781}
3782
3783std::optional<Attribute>
3784TypeConverter::convertTypeAttribute(Type type, Attribute attr) const {
3785 for (const TypeAttributeConversionCallbackFn &fn :
3786 llvm::reverse(typeAttributeConversions)) {
3787 AttributeConversionResult res = fn(type, attr);
3788 if (res.hasResult())
3789 return res.getResult();
3790 if (res.isAbort())
3791 return std::nullopt;
3792 }
3793 return std::nullopt;
3794}
3795
3796//===----------------------------------------------------------------------===//
3797// FunctionOpInterfaceSignatureConversion
3798//===----------------------------------------------------------------------===//
3799
3800static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3801 const TypeConverter &typeConverter,
3802 ConversionPatternRewriter &rewriter) {
3803 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3804 if (!type)
3805 return failure();
3806
3807 // Convert the original function types.
3808 TypeConverter::SignatureConversion result(type.getNumInputs());
3809 SmallVector<Type, 1> newResults;
3810 if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3811 failed(typeConverter.convertTypes(type.getResults(), newResults)))
3812 return failure();
3813 if (!funcOp.getFunctionBody().empty())
3814 rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result,
3815 &typeConverter);
3816
3817 // Update the function signature in-place.
3818 auto newType = FunctionType::get(rewriter.getContext(),
3819 result.getConvertedTypes(), newResults);
3820
3821 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3822
3823 return success();
3824}
3825
3826/// Create a default conversion pattern that rewrites the type signature of a
3827/// FunctionOpInterface op. This only supports ops which use FunctionType to
3828/// represent their type.
3829namespace {
3830struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3831 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3832 MLIRContext *ctx,
3833 const TypeConverter &converter,
3834 PatternBenefit benefit)
3835 : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3836
3837 LogicalResult
3838 matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3839 ConversionPatternRewriter &rewriter) const override {
3840 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3841 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3842 }
3843};
3844
3845struct AnyFunctionOpInterfaceSignatureConversion
3846 : public OpInterfaceConversionPattern<FunctionOpInterface> {
3847 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3848
3849 LogicalResult
3850 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3851 ConversionPatternRewriter &rewriter) const override {
3852 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3853 }
3854};
3855} // namespace
3856
3857FailureOr<Operation *>
3858mlir::convertOpResultTypes(Operation *op, ValueRange operands,
3859 const TypeConverter &converter,
3860 ConversionPatternRewriter &rewriter) {
3861 assert(op && "Invalid op");
3862 Location loc = op->getLoc();
3863 if (converter.isLegal(op))
3864 return rewriter.notifyMatchFailure(loc, "op already legal");
3865
3866 OperationState newOp(loc, op->getName());
3867 newOp.addOperands(operands);
3868
3869 SmallVector<Type> newResultTypes;
3870 if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
3871 return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3872
3873 newOp.addTypes(newResultTypes);
3874 newOp.addAttributes(op->getAttrs());
3875 return rewriter.create(newOp);
3876}
3877
3878void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3879 StringRef functionLikeOpName, RewritePatternSet &patterns,
3880 const TypeConverter &converter, PatternBenefit benefit) {
3881 patterns.add<FunctionOpInterfaceSignatureConversion>(
3882 functionLikeOpName, patterns.getContext(), converter, benefit);
3883}
3884
3885void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3886 RewritePatternSet &patterns, const TypeConverter &converter,
3887 PatternBenefit benefit) {
3888 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3889 converter, patterns.getContext(), benefit);
3890}
3891
3892//===----------------------------------------------------------------------===//
3893// ConversionTarget
3894//===----------------------------------------------------------------------===//
3895
3896void ConversionTarget::setOpAction(OperationName op,
3897 LegalizationAction action) {
3898 legalOperations[op].action = action;
3899}
3900
3901void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3902 LegalizationAction action) {
3903 for (StringRef dialect : dialectNames)
3904 legalDialects[dialect] = action;
3905}
3906
3907auto ConversionTarget::getOpAction(OperationName op) const
3908 -> std::optional<LegalizationAction> {
3909 std::optional<LegalizationInfo> info = getOpInfo(op);
3910 return info ? info->action : std::optional<LegalizationAction>();
3911}
3912
3913auto ConversionTarget::isLegal(Operation *op) const
3914 -> std::optional<LegalOpDetails> {
3915 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3916 if (!info)
3917 return std::nullopt;
3918
3919 // Returns true if this operation instance is known to be legal.
3920 auto isOpLegal = [&] {
3921 // Handle dynamic legality either with the provided legality function.
3922 if (info->action == LegalizationAction::Dynamic) {
3923 std::optional<bool> result = info->legalityFn(op);
3924 if (result)
3925 return *result;
3926 }
3927
3928 // Otherwise, the operation is only legal if it was marked 'Legal'.
3929 return info->action == LegalizationAction::Legal;
3930 };
3931 if (!isOpLegal())
3932 return std::nullopt;
3933
3934 // This operation is legal, compute any additional legality information.
3935 LegalOpDetails legalityDetails;
3936 if (info->isRecursivelyLegal) {
3937 auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3938 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3939 legalityDetails.isRecursivelyLegal =
3940 legalityFnIt->second(op).value_or(true);
3941 } else {
3942 legalityDetails.isRecursivelyLegal = true;
3943 }
3944 }
3945 return legalityDetails;
3946}
3947
3948bool ConversionTarget::isIllegal(Operation *op) const {
3949 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3950 if (!info)
3951 return false;
3952
3953 if (info->action == LegalizationAction::Dynamic) {
3954 std::optional<bool> result = info->legalityFn(op);
3955 if (!result)
3956 return false;
3957
3958 return !(*result);
3959 }
3960
3961 return info->action == LegalizationAction::Illegal;
3962}
3963
3964static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
3965 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
3966 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
3967 if (!oldCallback)
3968 return newCallback;
3969
3970 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3971 Operation *op) -> std::optional<bool> {
3972 if (std::optional<bool> result = newCl(op))
3973 return *result;
3974
3975 return oldCl(op);
3976 };
3977 return chain;
3978}
3979
3980void ConversionTarget::setLegalityCallback(
3981 OperationName name, const DynamicLegalityCallbackFn &callback) {
3982 assert(callback && "expected valid legality callback");
3983 auto *infoIt = legalOperations.find(name);
3984 assert(infoIt != legalOperations.end() &&
3985 infoIt->second.action == LegalizationAction::Dynamic &&
3986 "expected operation to already be marked as dynamically legal");
3987 infoIt->second.legalityFn =
3988 composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3989}
3990
3991void ConversionTarget::markOpRecursivelyLegal(
3992 OperationName name, const DynamicLegalityCallbackFn &callback) {
3993 auto *infoIt = legalOperations.find(name);
3994 assert(infoIt != legalOperations.end() &&
3995 infoIt->second.action != LegalizationAction::Illegal &&
3996 "expected operation to already be marked as legal");
3997 infoIt->second.isRecursivelyLegal = true;
3998 if (callback)
3999 opRecursiveLegalityFns[name] = composeLegalityCallbacks(
4000 std::move(opRecursiveLegalityFns[name]), callback);
4001 else
4002 opRecursiveLegalityFns.erase(name);
4003}
4004
4005void ConversionTarget::setLegalityCallback(
4006 ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
4007 assert(callback && "expected valid legality callback");
4008 for (StringRef dialect : dialects)
4009 dialectLegalityFns[dialect] = composeLegalityCallbacks(
4010 std::move(dialectLegalityFns[dialect]), callback);
4011}
4012
4013void ConversionTarget::setLegalityCallback(
4014 const DynamicLegalityCallbackFn &callback) {
4015 assert(callback && "expected valid legality callback");
4016 unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
4017}
4018
4019auto ConversionTarget::getOpInfo(OperationName op) const
4020 -> std::optional<LegalizationInfo> {
4021 // Check for info for this specific operation.
4022 const auto *it = legalOperations.find(op);
4023 if (it != legalOperations.end())
4024 return it->second;
4025 // Check for info for the parent dialect.
4026 auto dialectIt = legalDialects.find(op.getDialectNamespace());
4027 if (dialectIt != legalDialects.end()) {
4028 DynamicLegalityCallbackFn callback;
4029 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4030 if (dialectFn != dialectLegalityFns.end())
4031 callback = dialectFn->second;
4032 return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
4033 callback};
4034 }
4035 // Otherwise, check if we mark unknown operations as dynamic.
4036 if (unknownLegalityFn)
4037 return LegalizationInfo{LegalizationAction::Dynamic,
4038 /*isRecursivelyLegal=*/false, unknownLegalityFn};
4039 return std::nullopt;
4040}
4041
4042#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4043//===----------------------------------------------------------------------===//
4044// PDL Configuration
4045//===----------------------------------------------------------------------===//
4046
4047void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
4048 auto &rewriterImpl =
4049 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4050 rewriterImpl.currentTypeConverter = getTypeConverter();
4051}
4052
4053void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
4054 auto &rewriterImpl =
4055 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4056 rewriterImpl.currentTypeConverter = nullptr;
4057}
4058
4059/// Remap the given value using the rewriter and the type converter in the
4060/// provided config.
4061static FailureOr<SmallVector<Value>>
4062pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) {
4063 SmallVector<Value> mappedValues;
4064 if (failed(rewriter.getRemappedValues(values, mappedValues)))
4065 return failure();
4066 return std::move(mappedValues);
4067}
4068
4069void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
4070 patterns.getPDLPatterns().registerRewriteFunction(
4071 "convertValue",
4072 [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
4073 auto results = pdllConvertValues(
4074 static_cast<ConversionPatternRewriter &>(rewriter), value);
4075 if (failed(results))
4076 return failure();
4077 return results->front();
4078 });
4079 patterns.getPDLPatterns().registerRewriteFunction(
4080 "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
4081 return pdllConvertValues(
4082 static_cast<ConversionPatternRewriter &>(rewriter), values);
4083 });
4084 patterns.getPDLPatterns().registerRewriteFunction(
4085 "convertType",
4086 [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
4087 auto &rewriterImpl =
4088 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4089 if (const TypeConverter *converter =
4090 rewriterImpl.currentTypeConverter) {
4091 if (Type newType = converter->convertType(type))
4092 return newType;
4093 return failure();
4094 }
4095 return type;
4096 });
4097 patterns.getPDLPatterns().registerRewriteFunction(
4098 "convertTypes",
4099 [](PatternRewriter &rewriter,
4100 TypeRange types) -> FailureOr<SmallVector<Type>> {
4101 auto &rewriterImpl =
4102 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4103 const TypeConverter *converter = rewriterImpl.currentTypeConverter;
4104 if (!converter)
4105 return SmallVector<Type>(types);
4106
4107 SmallVector<Type> remappedTypes;
4108 if (failed(converter->convertTypes(types, remappedTypes)))
4109 return failure();
4110 return std::move(remappedTypes);
4111 });
4112}
4113#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
4114
4115//===----------------------------------------------------------------------===//
4116// Op Conversion Entry Points
4117//===----------------------------------------------------------------------===//
4118
4119/// This is the type of Action that is dispatched when a conversion is applied.
4121 : public tracing::ActionImpl<ApplyConversionAction> {
4122public:
4125 static constexpr StringLiteral tag = "apply-conversion";
4126 static constexpr StringLiteral desc =
4127 "Encapsulate the application of a dialect conversion";
4128
4129 void print(raw_ostream &os) const override { os << tag; }
4130};
4131
4133 const ConversionTarget &target,
4135 ConversionConfig config,
4136 OpConversionMode mode) {
4137 if (ops.empty())
4138 return success();
4139 MLIRContext *ctx = ops.front()->getContext();
4140 LogicalResult status = success();
4141 SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
4143 [&] {
4144 OperationConverter opConverter(ops.front()->getContext(), target,
4145 patterns, config, mode);
4146 status = opConverter.convertOperations(ops);
4147 },
4148 irUnits);
4149 return status;
4150}
4151
4152//===----------------------------------------------------------------------===//
4153// Partial Conversion
4154//===----------------------------------------------------------------------===//
4155
4156LogicalResult mlir::applyPartialConversion(
4157 ArrayRef<Operation *> ops, const ConversionTarget &target,
4158 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4159 return applyConversion(ops, target, patterns, config,
4160 OpConversionMode::Partial);
4161}
4162LogicalResult
4163mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
4164 const FrozenRewritePatternSet &patterns,
4165 ConversionConfig config) {
4166 return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
4167}
4168
4169//===----------------------------------------------------------------------===//
4170// Full Conversion
4171//===----------------------------------------------------------------------===//
4172
4173LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
4174 const ConversionTarget &target,
4175 const FrozenRewritePatternSet &patterns,
4176 ConversionConfig config) {
4177 return applyConversion(ops, target, patterns, config, OpConversionMode::Full);
4178}
4179LogicalResult mlir::applyFullConversion(Operation *op,
4180 const ConversionTarget &target,
4181 const FrozenRewritePatternSet &patterns,
4182 ConversionConfig config) {
4183 return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
4184}
4185
4186//===----------------------------------------------------------------------===//
4187// Analysis Conversion
4188//===----------------------------------------------------------------------===//
4189
4190/// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
4191/// op is a top-level module op (which is expected to be isolated from above),
4192/// return that op.
4194 // Check if there is a top-level operation within `ops`. If so, return that
4195 // op.
4196 for (Operation *op : ops) {
4197 if (!op->getParentOp()) {
4198#ifndef NDEBUG
4199 assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
4200 "expected top-level op to be isolated from above");
4201 for (Operation *other : ops)
4202 assert(op->isAncestor(other) &&
4203 "expected ops to have a common ancestor");
4204#endif // NDEBUG
4205 return op;
4206 }
4207 }
4208
4209 // No top-level op. Find a common ancestor.
4210 Operation *commonAncestor =
4211 ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
4212 for (Operation *op : ops.drop_front()) {
4213 while (!commonAncestor->isProperAncestor(op)) {
4214 commonAncestor =
4216 assert(commonAncestor &&
4217 "expected to find a common isolated from above ancestor");
4218 }
4219 }
4220
4221 return commonAncestor;
4222}
4223
4224LogicalResult mlir::applyAnalysisConversion(
4225 ArrayRef<Operation *> ops, ConversionTarget &target,
4226 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4227#ifndef NDEBUG
4228 if (config.legalizableOps)
4229 assert(config.legalizableOps->empty() && "expected empty set");
4230#endif // NDEBUG
4231
4232 // Clone closted common ancestor that is isolated from above.
4233 Operation *commonAncestor = findCommonAncestor(ops);
4234 IRMapping mapping;
4235 Operation *clonedAncestor = commonAncestor->clone(mapping);
4236 // Compute inverse IR mapping.
4237 DenseMap<Operation *, Operation *> inverseOperationMap;
4238 for (auto &it : mapping.getOperationMap())
4239 inverseOperationMap[it.second] = it.first;
4240
4241 // Convert the cloned operations. The original IR will remain unchanged.
4242 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4243 ops, [&](Operation *op) { return mapping.lookup(op); });
4244 LogicalResult status = applyConversion(opsToConvert, target, patterns, config,
4245 OpConversionMode::Analysis);
4246
4247 // Remap `legalizableOps`, so that they point to the original ops and not the
4248 // cloned ops.
4249 if (config.legalizableOps) {
4250 DenseSet<Operation *> originalLegalizableOps;
4251 for (Operation *op : *config.legalizableOps)
4252 originalLegalizableOps.insert(inverseOperationMap[op]);
4253 *config.legalizableOps = std::move(originalLegalizableOps);
4254 }
4255
4256 // Erase the cloned IR.
4257 clonedAncestor->erase();
4258 return status;
4259}
4260
4261LogicalResult
4262mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
4263 const FrozenRewritePatternSet &patterns,
4264 ConversionConfig config) {
4265 return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
4266}
return success()
static void setInsertionPointAfter(OpBuilder &b, Value value)
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl, function_ref< bool(OpOperand &)> functor=nullptr)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value > > &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
#define DEBUG_TYPE
b getContext())
static std::string diag(const llvm::Value &value)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition SCCP.cpp:67
This is the type of Action that is dispatched when a conversion is applied.
tracing::ActionImpl< ApplyConversionAction > Base
static constexpr StringLiteral desc
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
static constexpr StringLiteral tag
This class represents an argument of a Block.
Definition Value.h:309
Location getLoc() const
Return the location for this argument.
Definition Value.h:324
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:140
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:149
bool empty()
Definition Block.h:148
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:137
Operation & front()
Definition Block.h:153
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition Block.h:308
BlockArgListType getArguments()
Definition Block.h:87
iterator end()
Definition Block.h:144
iterator begin()
Definition Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
UnitAttr getUnitAttr()
Definition Builders.cpp:98
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
MLIRContext * context
Definition Builders.h:202
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
Definition Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition Builders.h:327
Block::iterator getPoint() const
Definition Builders.h:340
bool isSet() const
Returns true if this insert point is set.
Definition Builders.h:337
Block * getBlock() const
Definition Builders.h:339
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:320
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Definition Builders.cpp:473
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition Builders.cpp:421
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
This class provides the API for ops that are known to be isolated from above.
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
void destroyOpProperties(OpaqueProperties properties) const
This hooks destroy the op properties.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition Operation.h:226
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:852
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
void dropAllUses()
Drop all uses of results of this operation.
Definition Operation.h:834
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:560
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:248
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperandRange operand_range
Definition Operation.h:371
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:263
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
result_range getResults()
Definition Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition Operation.h:900
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
BlockListType::iterator iterator
Definition Region.h:52
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
CRTP Implementation of an action.
Definition Action.h:76
ArrayRef< IRUnit > irUnits
Set of IR units (operations, regions, blocks, values) that are associated with this action.
Definition Action.h:66
AttrTypeReplacer.
Kind
An enumeration of the kinds of predicates.
Definition Predicate.h:44
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
static void reconcileUnrealizedCasts(const DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
This iterator enumerates elements according to their dominance relationship.
Definition Iterators.h:48
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition Builders.h:308
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition Builders.h:298
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
LogicalResult convert(Operation *op, bool isRecursiveLegalization=false)
Converts a single operation.
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
llvm::impl::raw_ldbg_ostream os
A raw output stream used to prefix the debug log.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
void replaceValueUses(Value from, ValueRange to, const TypeConverter *converter, function_ref< bool(OpOperand &)> functor=nullptr)
Replace the uses of the given value with the given values.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config, OperationConverter &opConverter)
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
void replaceOp(Operation *op, SmallVector< SmallVector< Value > > &&newValues)
Replace the results of the given operation with the given values and erase the operation.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
OperationConverter & opConverter
The operation converter to use for recursive legalization.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.