MLIR 22.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/Config/mlir-config.h"
11#include "mlir/IR/Block.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "mlir/IR/Dominance.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Iterators.h"
17#include "mlir/IR/Operation.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/FormatVariadic.h"
25#include "llvm/Support/SaveAndRestore.h"
26#include "llvm/Support/ScopedPrinter.h"
27#include <optional>
28#include <utility>
29
30using namespace mlir;
31using namespace mlir::detail;
32
33#define DEBUG_TYPE "dialect-conversion"
34
35/// A utility function to log a successful result for the given reason.
36template <typename... Args>
37static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
38 LLVM_DEBUG({
39 os.unindent();
40 os.startLine() << "} -> SUCCESS";
41 if (!fmt.empty())
42 os.getOStream() << " : "
43 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
44 os.getOStream() << "\n";
45 });
46}
47
48/// A utility function to log a failure result for the given reason.
49template <typename... Args>
50static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
51 LLVM_DEBUG({
52 os.unindent();
53 os.startLine() << "} -> FAILURE : "
54 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
55 << "\n";
56 });
57}
58
59/// Helper function that computes an insertion point where the given value is
60/// defined and can be used without a dominance violation.
62 Block *insertBlock = value.getParentBlock();
63 Block::iterator insertPt = insertBlock->begin();
64 if (OpResult inputRes = dyn_cast<OpResult>(value))
65 insertPt = ++inputRes.getOwner()->getIterator();
66 return OpBuilder::InsertPoint(insertBlock, insertPt);
67}
68
69/// Helper function that computes an insertion point where the given values are
70/// defined and can be used without a dominance violation.
72 assert(!vals.empty() && "expected at least one value");
73 DominanceInfo domInfo;
75 for (Value v : vals.drop_front()) {
76 // Choose the "later" insertion point.
78 if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
79 nextPt.getPoint())) {
80 // pt is before nextPt => choose nextPt.
81 pt = nextPt;
82 } else {
83#ifndef NDEBUG
84 // nextPt should be before pt => choose pt.
85 // If pt, nextPt are no dominance relationship, then there is no valid
86 // insertion point at which all given values are defined.
87 bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
88 pt.getBlock(), pt.getPoint());
89 assert(dom && "unable to find valid insertion point");
90#endif // NDEBUG
91 }
92 }
93 return pt;
94}
95
namespace {
/// Selects how strictly the dialect conversion treats operations that could
/// not be legalized.
enum OpConversionMode {
  /// Failed conversions are ignored, so illegal operations are allowed to
  /// co-exist with legal ones in the IR.
  Partial,

  /// Every operation must be legal for the given target, otherwise the
  /// conversion fails.
  Full,

  /// Operations are only analyzed for legality; no actual rewrites are applied
  /// to the operations on success.
  Analysis,
};
} // namespace
111
112//===----------------------------------------------------------------------===//
113// ConversionValueMapping
114//===----------------------------------------------------------------------===//
115
116/// A vector of SSA values, optimized for the most common case of a single
117/// value.
119
120namespace {
121
122/// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
123struct ValueVectorMapInfo {
124 static ValueVector getEmptyKey() { return ValueVector{Value()}; }
125 static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
126 static ::llvm::hash_code getHashValue(const ValueVector &val) {
127 return ::llvm::hash_combine_range(val);
128 }
129 static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
130 return LHS == RHS;
131 }
132};
133
134/// This class wraps a IRMapping to provide recursive lookup
135/// functionality, i.e. we will traverse if the mapped value also has a mapping.
136struct ConversionValueMapping {
137 /// Return "true" if an SSA value is mapped to the given value. May return
138 /// false positives.
139 bool isMappedTo(Value value) const { return mappedTo.contains(value); }
140
141 /// Lookup a value in the mapping.
142 ValueVector lookup(const ValueVector &from) const;
143
144 template <typename T>
145 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
146
147 /// Map a value vector to the one provided.
148 template <typename OldVal, typename NewVal>
149 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
150 map(OldVal &&oldVal, NewVal &&newVal) {
151 LLVM_DEBUG({
152 ValueVector next(newVal);
153 while (true) {
154 assert(next != oldVal && "inserting cyclic mapping");
155 auto it = mapping.find(next);
156 if (it == mapping.end())
157 break;
158 next = it->second;
159 }
160 });
161 mappedTo.insert_range(newVal);
162
163 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
164 }
165
166 /// Map a value vector or single value to the one provided.
167 template <typename OldVal, typename NewVal>
168 std::enable_if_t<!IsValueVector<OldVal>::value ||
169 !IsValueVector<NewVal>::value>
170 map(OldVal &&oldVal, NewVal &&newVal) {
171 if constexpr (IsValueVector<OldVal>{}) {
172 map(std::forward<OldVal>(oldVal), ValueVector{newVal});
173 } else if constexpr (IsValueVector<NewVal>{}) {
174 map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
175 } else {
176 map(ValueVector{oldVal}, ValueVector{newVal});
177 }
178 }
179
180 void map(Value oldVal, SmallVector<Value> &&newVal) {
181 map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
182 }
183
184 /// Drop the last mapping for the given values.
185 void erase(const ValueVector &value) { mapping.erase(value); }
186
187private:
188 /// Current value mappings.
190
191 /// All SSA values that are mapped to. May contain false positives.
192 DenseSet<Value> mappedTo;
193};
194} // namespace
195
196/// Marker attribute for pure type conversions. I.e., mappings whose only
197/// purpose is to resolve a type mismatch. (In contrast, mappings that point to
198/// the replacement values of a "replaceOp" call, etc., are not pure type
199/// conversions.)
200static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
201
202/// Return the operation that defines all values in the vector. Return nullptr
203/// if the values are not defined by the same operation.
205 assert(!values.empty() && "expected non-empty value vector");
206 Operation *op = values.front().getDefiningOp();
207 for (Value v : llvm::drop_begin(values)) {
208 if (v.getDefiningOp() != op)
209 return nullptr;
210 }
211 return op;
212}
213
214/// A vector of values is a pure type conversion if all values are defined by
215/// the same operation and the operation has the `kPureTypeConversionMarker`
216/// attribute.
217static bool isPureTypeConversion(const ValueVector &values) {
218 assert(!values.empty() && "expected non-empty value vector");
219 Operation *op = getCommonDefiningOp(values);
220 return op && op->hasAttr(kPureTypeConversionMarker);
221}
222
223ValueVector ConversionValueMapping::lookup(const ValueVector &from) const {
224 auto it = mapping.find(from);
225 if (it == mapping.end()) {
226 // No mapping found: The lookup stops here.
227 return {};
228 }
229 return it->second;
230}
231
232//===----------------------------------------------------------------------===//
233// Rewriter and Translation State
234//===----------------------------------------------------------------------===//
235namespace {
/// Snapshot of the conversion rewriter's progress counters. Saving one and
/// later restoring it allows undoing every rewrite recorded in between.
struct RewriterState {
  RewriterState(unsigned rewrites, unsigned ignoredOps, unsigned replacedOps)
      : numRewrites(rewrites), numIgnoredOperations(ignoredOps),
        numReplacedOps(replacedOps) {}

  /// Number of rewrites recorded so far.
  unsigned numRewrites;

  /// Number of operations marked as ignored so far.
  unsigned numIgnoredOperations;

  /// Number of replaced ops currently scheduled for erasure.
  unsigned numReplacedOps;
};
253
254//===----------------------------------------------------------------------===//
255// IR rewrites
256//===----------------------------------------------------------------------===//
257
258static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
259
260/// Notify the listener that the given block and its contents are being erased.
261static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
262 for (Operation &op : b)
263 notifyIRErased(listener, op);
264 listener->notifyBlockErased(&b);
265}
266
267/// Notify the listener that the given operation and its contents are being
268/// erased.
269static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
270 for (Region &r : op.getRegions()) {
271 for (Block &b : r) {
272 notifyIRErased(listener, b);
273 }
274 }
275 listener->notifyOperationErased(&op);
276}
277
278/// An IR rewrite that can be committed (upon success) or rolled back (upon
279/// failure).
280///
281/// The dialect conversion keeps track of IR modifications (requested by the
282/// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
283/// are directly applied to the IR as the rewriter API is used, some are applied
284/// partially, and some are delayed until the `IRRewrite` objects are committed.
285class IRRewrite {
286public:
287 /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
288 /// Enum values are ordered, so that they can be used in `classof`: first all
289 /// block rewrites, then all operation rewrites.
290 enum class Kind {
291 // Block rewrites
292 CreateBlock,
293 EraseBlock,
294 InlineBlock,
295 MoveBlock,
296 BlockTypeConversion,
297 // Operation rewrites
298 MoveOperation,
299 ModifyOperation,
300 ReplaceOperation,
301 CreateOperation,
302 UnresolvedMaterialization,
303 // Value rewrites
304 ReplaceValue
305 };
306
307 virtual ~IRRewrite() = default;
308
309 /// Roll back the rewrite. Operations may be erased during rollback.
310 virtual void rollback() = 0;
311
312 /// Commit the rewrite. At this point, it is certain that the dialect
313 /// conversion will succeed. All IR modifications, except for operation/block
314 /// erasure, must be performed through the given rewriter.
315 ///
316 /// Instead of erasing operations/blocks, they should merely be unlinked
317 /// commit phase and finally be erased during the cleanup phase. This is
318 /// because internal dialect conversion state (such as `mapping`) may still
319 /// be using them.
320 ///
321 /// Any IR modification that was already performed before the commit phase
322 /// (e.g., insertion of an op) must be communicated to the listener that may
323 /// be attached to the given rewriter.
324 virtual void commit(RewriterBase &rewriter) {}
325
326 /// Cleanup operations/blocks. Cleanup is called after commit.
327 virtual void cleanup(RewriterBase &rewriter) {}
328
329 Kind getKind() const { return kind; }
330
331 static bool classof(const IRRewrite *rewrite) { return true; }
332
333protected:
334 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
335 : kind(kind), rewriterImpl(rewriterImpl) {}
336
337 const ConversionConfig &getConfig() const;
338
339 const Kind kind;
340 ConversionPatternRewriterImpl &rewriterImpl;
341};
342
343/// A block rewrite.
344class BlockRewrite : public IRRewrite {
345public:
346 /// Return the block that this rewrite operates on.
347 Block *getBlock() const { return block; }
348
349 static bool classof(const IRRewrite *rewrite) {
350 return rewrite->getKind() >= Kind::CreateBlock &&
351 rewrite->getKind() <= Kind::BlockTypeConversion;
352 }
353
354protected:
355 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
356 Block *block)
357 : IRRewrite(kind, rewriterImpl), block(block) {}
358
359 // The block that this rewrite operates on.
360 Block *block;
361};
362
363/// A value rewrite.
364class ValueRewrite : public IRRewrite {
365public:
366 /// Return the value that this rewrite operates on.
367 Value getValue() const { return value; }
368
369 static bool classof(const IRRewrite *rewrite) {
370 return rewrite->getKind() == Kind::ReplaceValue;
371 }
372
373protected:
374 ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
375 Value value)
376 : IRRewrite(kind, rewriterImpl), value(value) {}
377
378 // The value that this rewrite operates on.
379 Value value;
380};
381
382/// Creation of a block. Block creations are immediately reflected in the IR.
383/// There is no extra work to commit the rewrite. During rollback, the newly
384/// created block is erased.
385class CreateBlockRewrite : public BlockRewrite {
386public:
387 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
388 : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
389
390 static bool classof(const IRRewrite *rewrite) {
391 return rewrite->getKind() == Kind::CreateBlock;
392 }
393
394 void commit(RewriterBase &rewriter) override {
395 // The block was already created and inserted. Just inform the listener.
396 if (auto *listener = rewriter.getListener())
397 listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
398 }
399
400 void rollback() override {
401 // Unlink all of the operations within this block, they will be deleted
402 // separately.
403 auto &blockOps = block->getOperations();
404 while (!blockOps.empty())
405 blockOps.remove(blockOps.begin());
406 block->dropAllUses();
407 if (block->getParent())
408 block->erase();
409 else
410 delete block;
411 }
412};
413
414/// Erasure of a block. Block erasures are partially reflected in the IR. Erased
415/// blocks are immediately unlinked, but only erased during cleanup. This makes
416/// it easier to rollback a block erasure: the block is simply inserted into its
417/// original location.
418class EraseBlockRewrite : public BlockRewrite {
419public:
420 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
421 : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
422 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
423
424 static bool classof(const IRRewrite *rewrite) {
425 return rewrite->getKind() == Kind::EraseBlock;
426 }
427
428 ~EraseBlockRewrite() override {
429 assert(!block &&
430 "rewrite was neither rolled back nor committed/cleaned up");
431 }
432
433 void rollback() override {
434 // The block (owned by this rewrite) was not actually erased yet. It was
435 // just unlinked. Put it back into its original position.
436 assert(block && "expected block");
437 auto &blockList = region->getBlocks();
438 Region::iterator before = insertBeforeBlock
439 ? Region::iterator(insertBeforeBlock)
440 : blockList.end();
441 blockList.insert(before, block);
442 block = nullptr;
443 }
444
445 void commit(RewriterBase &rewriter) override {
446 assert(block && "expected block");
447
448 // Notify the listener that the block and its contents are being erased.
449 if (auto *listener =
450 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
451 notifyIRErased(listener, *block);
452 }
453
454 void cleanup(RewriterBase &rewriter) override {
455 // Erase the contents of the block.
456 for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
457 rewriter.eraseOp(&op);
458 assert(block->empty() && "expected empty block");
459
460 // Erase the block.
461 block->dropAllDefinedValueUses();
462 delete block;
463 block = nullptr;
464 }
465
466private:
467 // The region in which this block was previously contained.
468 Region *region;
469
470 // The original successor of this block before it was unlinked. "nullptr" if
471 // this block was the only block in the region.
472 Block *insertBeforeBlock;
473};
474
475/// Inlining of a block. This rewrite is immediately reflected in the IR.
476/// Note: This rewrite represents only the inlining of the operations. The
477/// erasure of the inlined block is a separate rewrite.
478class InlineBlockRewrite : public BlockRewrite {
479public:
480 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
481 Block *sourceBlock, Block::iterator before)
482 : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
483 sourceBlock(sourceBlock),
484 firstInlinedInst(sourceBlock->empty() ? nullptr
485 : &sourceBlock->front()),
486 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
487 // If a listener is attached to the dialect conversion, ops must be moved
488 // one-by-one. When they are moved in bulk, notifications cannot be sent
489 // because the ops that used to be in the source block at the time of the
490 // inlining (before the "commit" phase) are unknown at the time when
491 // notifications are sent (which is during the "commit" phase).
492 assert(!getConfig().listener &&
493 "InlineBlockRewrite not supported if listener is attached");
494 }
495
496 static bool classof(const IRRewrite *rewrite) {
497 return rewrite->getKind() == Kind::InlineBlock;
498 }
499
500 void rollback() override {
501 // Put the operations from the destination block (owned by the rewrite)
502 // back into the source block.
503 if (firstInlinedInst) {
504 assert(lastInlinedInst && "expected operation");
505 sourceBlock->getOperations().splice(sourceBlock->begin(),
506 block->getOperations(),
507 Block::iterator(firstInlinedInst),
508 ++Block::iterator(lastInlinedInst));
509 }
510 }
511
512private:
513 // The block that originally contained the operations.
514 Block *sourceBlock;
515
516 // The first inlined operation.
517 Operation *firstInlinedInst;
518
519 // The last inlined operation.
520 Operation *lastInlinedInst;
521};
522
523/// Moving of a block. This rewrite is immediately reflected in the IR.
524class MoveBlockRewrite : public BlockRewrite {
525public:
526 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
527 Region *previousRegion, Region::iterator previousIt)
528 : BlockRewrite(Kind::MoveBlock, rewriterImpl, block),
529 region(previousRegion),
530 insertBeforeBlock(previousIt == previousRegion->end() ? nullptr
531 : &*previousIt) {}
532
533 static bool classof(const IRRewrite *rewrite) {
534 return rewrite->getKind() == Kind::MoveBlock;
535 }
536
537 void commit(RewriterBase &rewriter) override {
538 // The block was already moved. Just inform the listener.
539 if (auto *listener = rewriter.getListener()) {
540 // Note: `previousIt` cannot be passed because this is a delayed
541 // notification and iterators into past IR state cannot be represented.
542 listener->notifyBlockInserted(block, /*previous=*/region,
543 /*previousIt=*/{});
544 }
545 }
546
547 void rollback() override {
548 // Move the block back to its original position.
549 Region::iterator before =
550 insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
551 region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
552 }
553
554private:
555 // The region in which this block was previously contained.
556 Region *region;
557
558 // The original successor of this block before it was moved. "nullptr" if
559 // this block was the only block in the region.
560 Block *insertBeforeBlock;
561};
562
563/// Block type conversion. This rewrite is partially reflected in the IR.
564class BlockTypeConversionRewrite : public BlockRewrite {
565public:
566 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
567 Block *origBlock, Block *newBlock)
568 : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
569 newBlock(newBlock) {}
570
571 static bool classof(const IRRewrite *rewrite) {
572 return rewrite->getKind() == Kind::BlockTypeConversion;
573 }
574
575 Block *getOrigBlock() const { return block; }
576
577 Block *getNewBlock() const { return newBlock; }
578
579 void commit(RewriterBase &rewriter) override;
580
581 void rollback() override;
582
583private:
584 /// The new block that was created as part of this signature conversion.
585 Block *newBlock;
586};
587
588/// Replacing a value. This rewrite is not immediately reflected in the
589/// IR. An internal IR mapping is updated, but the actual replacement is delayed
590/// until the rewrite is committed.
591class ReplaceValueRewrite : public ValueRewrite {
592public:
593 ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
594 const TypeConverter *converter)
595 : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
596 converter(converter) {}
597
598 static bool classof(const IRRewrite *rewrite) {
599 return rewrite->getKind() == Kind::ReplaceValue;
600 }
601
602 void commit(RewriterBase &rewriter) override;
603
604 void rollback() override;
605
606private:
607 /// The current type converter when the value was replaced.
608 const TypeConverter *converter;
609};
610
611/// An operation rewrite.
612class OperationRewrite : public IRRewrite {
613public:
614 /// Return the operation that this rewrite operates on.
615 Operation *getOperation() const { return op; }
616
617 static bool classof(const IRRewrite *rewrite) {
618 return rewrite->getKind() >= Kind::MoveOperation &&
619 rewrite->getKind() <= Kind::UnresolvedMaterialization;
620 }
621
622protected:
623 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
624 Operation *op)
625 : IRRewrite(kind, rewriterImpl), op(op) {}
626
627 // The operation that this rewrite operates on.
628 Operation *op;
629};
630
631/// Moving of an operation. This rewrite is immediately reflected in the IR.
632class MoveOperationRewrite : public OperationRewrite {
633public:
634 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
635 Operation *op, OpBuilder::InsertPoint previous)
636 : OperationRewrite(Kind::MoveOperation, rewriterImpl, op),
637 block(previous.getBlock()),
638 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
639 ? nullptr
640 : &*previous.getPoint()) {}
641
642 static bool classof(const IRRewrite *rewrite) {
643 return rewrite->getKind() == Kind::MoveOperation;
644 }
645
646 void commit(RewriterBase &rewriter) override {
647 // The operation was already moved. Just inform the listener.
648 if (auto *listener = rewriter.getListener()) {
649 // Note: `previousIt` cannot be passed because this is a delayed
650 // notification and iterators into past IR state cannot be represented.
651 listener->notifyOperationInserted(
652 op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
653 /*insertPt=*/{}));
654 }
655 }
656
657 void rollback() override {
658 // Move the operation back to its original position.
659 Block::iterator before =
660 insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
661 block->getOperations().splice(before, op->getBlock()->getOperations(), op);
662 }
663
664private:
665 // The block in which this operation was previously contained.
666 Block *block;
667
668 // The original successor of this operation before it was moved. "nullptr"
669 // if this operation was the only operation in the region.
670 Operation *insertBeforeOp;
671};
672
673/// In-place modification of an op. This rewrite is immediately reflected in
674/// the IR. The previous state of the operation is stored in this object.
675class ModifyOperationRewrite : public OperationRewrite {
676public:
677 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
678 Operation *op)
679 : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
680 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
681 operands(op->operand_begin(), op->operand_end()),
682 successors(op->successor_begin(), op->successor_end()) {
683 if (OpaqueProperties prop = op->getPropertiesStorage()) {
684 // Make a copy of the properties.
685 propertiesStorage = operator new(op->getPropertiesStorageSize());
686 OpaqueProperties propCopy(propertiesStorage);
687 name.initOpProperties(propCopy, /*init=*/prop);
688 }
689 }
690
691 static bool classof(const IRRewrite *rewrite) {
692 return rewrite->getKind() == Kind::ModifyOperation;
693 }
694
695 ~ModifyOperationRewrite() override {
696 assert(!propertiesStorage &&
697 "rewrite was neither committed nor rolled back");
698 }
699
700 void commit(RewriterBase &rewriter) override {
701 // Notify the listener that the operation was modified in-place.
702 if (auto *listener =
703 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
704 listener->notifyOperationModified(op);
705
706 if (propertiesStorage) {
707 OpaqueProperties propCopy(propertiesStorage);
708 // Note: The operation may have been erased in the mean time, so
709 // OperationName must be stored in this object.
710 name.destroyOpProperties(propCopy);
711 operator delete(propertiesStorage);
712 propertiesStorage = nullptr;
713 }
714 }
715
716 void rollback() override {
717 op->setLoc(loc);
718 op->setAttrs(attrs);
719 op->setOperands(operands);
720 for (const auto &it : llvm::enumerate(successors))
721 op->setSuccessor(it.value(), it.index());
722 if (propertiesStorage) {
723 OpaqueProperties propCopy(propertiesStorage);
724 op->copyProperties(propCopy);
725 name.destroyOpProperties(propCopy);
726 operator delete(propertiesStorage);
727 propertiesStorage = nullptr;
728 }
729 }
730
731private:
732 OperationName name;
733 LocationAttr loc;
734 DictionaryAttr attrs;
735 SmallVector<Value, 8> operands;
736 SmallVector<Block *, 2> successors;
737 void *propertiesStorage = nullptr;
738};
739
740/// Replacing an operation. Erasing an operation is treated as a special case
741/// with "null" replacements. This rewrite is not immediately reflected in the
742/// IR. An internal IR mapping is updated, but values are not replaced and the
743/// original op is not erased until the rewrite is committed.
744class ReplaceOperationRewrite : public OperationRewrite {
745public:
746 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
747 Operation *op, const TypeConverter *converter)
748 : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
749 converter(converter) {}
750
751 static bool classof(const IRRewrite *rewrite) {
752 return rewrite->getKind() == Kind::ReplaceOperation;
753 }
754
755 void commit(RewriterBase &rewriter) override;
756
757 void rollback() override;
758
759 void cleanup(RewriterBase &rewriter) override;
760
761private:
762 /// An optional type converter that can be used to materialize conversions
763 /// between the new and old values if necessary.
764 const TypeConverter *converter;
765};
766
767class CreateOperationRewrite : public OperationRewrite {
768public:
769 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
770 Operation *op)
771 : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
772
773 static bool classof(const IRRewrite *rewrite) {
774 return rewrite->getKind() == Kind::CreateOperation;
775 }
776
777 void commit(RewriterBase &rewriter) override {
778 // The operation was already created and inserted. Just inform the listener.
779 if (auto *listener = rewriter.getListener())
780 listener->notifyOperationInserted(op, /*previous=*/{});
781 }
782
783 void rollback() override;
784};
785
/// Direction of a type materialization.
enum MaterializationKind {
  /// Converts a value of an illegal type into a value of a legal type.
  Target,

  /// Converts a value of a legal type back into a value of an illegal type.
  Source
};
796
797/// Helper class that stores metadata about an unresolved materialization.
798class UnresolvedMaterializationInfo {
799public:
800 UnresolvedMaterializationInfo() = default;
801 UnresolvedMaterializationInfo(const TypeConverter *converter,
802 MaterializationKind kind, Type originalType)
803 : converterAndKind(converter, kind), originalType(originalType) {}
804
805 /// Return the type converter of this materialization (which may be null).
806 const TypeConverter *getConverter() const {
807 return converterAndKind.getPointer();
808 }
809
810 /// Return the kind of this materialization.
811 MaterializationKind getMaterializationKind() const {
812 return converterAndKind.getInt();
813 }
814
815 /// Return the original type of the SSA value.
816 Type getOriginalType() const { return originalType; }
817
818private:
819 /// The corresponding type converter to use when resolving this
820 /// materialization, and the kind of this materialization.
821 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
822 converterAndKind;
823
824 /// The original type of the SSA value. Only used for target
825 /// materializations.
826 Type originalType;
827};
828
829/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
830/// op. Unresolved materializations fold away or are replaced with
831/// source/target materializations at the end of the dialect conversion.
832class UnresolvedMaterializationRewrite : public OperationRewrite {
833public:
834 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
835 UnrealizedConversionCastOp op,
836 ValueVector mappedValues)
837 : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
838 mappedValues(std::move(mappedValues)) {}
839
840 static bool classof(const IRRewrite *rewrite) {
841 return rewrite->getKind() == Kind::UnresolvedMaterialization;
842 }
843
844 void rollback() override;
845
846 UnrealizedConversionCastOp getOperation() const {
847 return cast<UnrealizedConversionCastOp>(op);
848 }
849
850private:
851 /// The values in the conversion value mapping that are being replaced by the
852 /// results of this unresolved materialization.
853 ValueVector mappedValues;
854};
855} // namespace
856
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` whose
/// target operation is `op`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Operation *op) {
  auto targetsOp = [&](auto &rewrite) {
    auto *typedRewrite = dyn_cast<RewriteTy>(rewrite.get());
    return typedRewrite && typedRewrite->getOperation() == op;
  };
  return any_of(std::forward<R>(rewrites), targetsOp);
}

/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` whose
/// target block is `block`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Block *block) {
  auto targetsBlock = [&](auto &rewrite) {
    auto *typedRewrite = dyn_cast<RewriteTy>(rewrite.get());
    return typedRewrite && typedRewrite->getBlock() == block;
  };
  return any_of(std::forward<R>(rewrites), targetsBlock);
}
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
878
879//===----------------------------------------------------------------------===//
880// ConversionPatternRewriterImpl
881//===----------------------------------------------------------------------===//
882namespace mlir {
883namespace detail {
885 explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
886 const ConversionConfig &config,
890
891 //===--------------------------------------------------------------------===//
892 // State Management
893 //===--------------------------------------------------------------------===//
894
895 /// Return the current state of the rewriter.
896 RewriterState getCurrentState();
897
898 /// Apply all requested operation rewrites. This method is invoked when the
899 /// conversion process succeeds.
900 void applyRewrites();
901
902 /// Reset the state of the rewriter to a previously saved point. Optionally,
903 /// the name of the pattern that triggered the rollback can be specified for
904 /// debugging purposes.
905 void resetState(RewriterState state, StringRef patternName = "");
906
907 /// Append a rewrite. Rewrites are committed upon success and rolled back upon
908 /// failure.
///
/// Appending is only valid when `config.allowPatternRollback` is set (this is
/// asserted). The rewrite is constructed from a reference to this object
/// followed by the forwarded `args`.
909 template <typename RewriteTy, typename... Args>
910 void appendRewrite(Args &&...args) {
911 assert(config.allowPatternRollback && "appending rewrites is not allowed");
912 rewrites.push_back(
913 std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
914 }
915
916 /// Undo the rewrites (e.g., motions, splits) one by one in reverse order until
917 /// "numRewritesToKeep" rewrites remain. Optionally, the name of the pattern
918 /// that triggered the rollback can be specified for debugging purposes.
919 void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
920
921 /// Remap the given values to those with potentially different types. Returns
922 /// success if the values could be remapped, failure otherwise. `valueDiagTag`
923 /// is the tag used when describing a value within a diagnostic, e.g.
924 /// "operand".
///
/// If `inputLoc` is set, it is used as the location of every value (for
/// diagnostics and materializations); otherwise each value's own location is
/// used. On success, `remapped` receives one entry per value (an empty entry
/// if the value's type converts to zero types). On failure, `remapped` may
/// already contain entries for the values that were processed before the
/// failing one.
925 LogicalResult remapValues(StringRef valueDiagTag,
926 std::optional<Location> inputLoc, ValueRange values,
927 SmallVector<ValueVector> &remapped);
928
929 /// Return "true" if the given operation is ignored, and does not need to be
930 /// converted.
931 bool isOpIgnored(Operation *op) const;
932
933 /// Return "true" if the given operation was replaced or erased.
934 bool wasOpReplaced(Operation *op) const;
935
936 /// Lookup the most recently mapped values with the desired types in the
937 /// mapping, taking into account only replacements. Perform a best-effort
938 /// search for existing materializations with the desired types.
939 ///
940 /// If `skipPureTypeConversions` is "true", materializations that are pure
941 /// type conversions are not considered.
///
/// If `desiredTypes` is non-empty and no values with exactly those types are
/// found, an empty vector is returned. If `desiredTypes` is empty, the
/// deepest mapped values are returned (or, with `skipPureTypeConversions`, the
/// deepest values that are not pure type conversions).
942 ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
943 bool skipPureTypeConversions = false) const;
944
945 /// Lookup the given value within the map, or return an empty vector if the
946 /// value is not mapped. If it is mapped, this follows the same behavior
947 /// as `lookupOrDefault`.
948 ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
949
950 //===--------------------------------------------------------------------===//
951 // IR Rewrites / Type Conversion
952 //===--------------------------------------------------------------------===//
953
954 /// Convert the types of block arguments within the given region.
/// Returns the (possibly new) entry block, nullptr if the region is empty, or
/// failure if a block signature could not be computed.
955 FailureOr<Block *>
956 convertRegionTypes(Region *region, const TypeConverter &converter,
957 TypeConverter::SignatureConversion *entryConversion);
958
959 /// Apply the given signature conversion on the given block. The new block
960 /// containing the updated signature is returned. If no conversions were
961 /// necessary, e.g. if the block has no arguments, `block` is returned.
962 /// `converter` is used to generate any necessary cast operations that
963 /// translate between the origin argument types and those specified in the
964 /// signature conversion.
965 Block *applySignatureConversion(
966 Block *block, const TypeConverter *converter,
967 TypeConverter::SignatureConversion &signatureConversion);
968
969 /// Replace the results of the given operation with the given values and
970 /// erase the operation.
971 ///
972 /// There can be multiple replacement values for each result (1:N
973 /// replacement). If the replacement values are empty, the respective result
974 /// is dropped and a source materialization is built if the result still has
975 /// uses.
976 void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
977
978 /// Replace the uses of the given value with the given values. The specified
979 /// converter is used to build materializations (if necessary).
980 void replaceAllUsesWith(Value from, ValueRange to,
981 const TypeConverter *converter);
982
983 /// Erase the given block and its contents.
984 void eraseBlock(Block *block);
985
986 /// Inline the source block into the destination block before the given
987 /// iterator.
988 void inlineBlockBefore(Block *source, Block *dest, Block::iterator before);
989
990 //===--------------------------------------------------------------------===//
991 // Materializations
992 //===--------------------------------------------------------------------===//
993
994 /// Build an unresolved materialization operation given a range of output
995 /// types and a list of input operands. The input types must differ from the
996 /// output types; otherwise no materialization is necessary (asserted).
997 ///
998 /// The results of the newly created unrealized_conversion_cast op are
999 /// returned. `outputTypes` must be non-empty.
/// `originalType` may only be set for target materializations (asserted).
1000 ///
1001 /// If `valuesToMap` is non-empty and pattern rollback is enabled, those
1002 /// values are mapped to the results of the unresolved materialization in the
1003 /// conversion value mapping.
1004 ///
1005 /// If `isPureTypeConversion` is "true", the materialization is created only
1006 /// to resolve a type mismatch. That means it is not a regular value
1007 /// replacement issued by the user. (Replacement values that are created
1008 /// "out of thin air" appear like unresolved materializations because they are
1009 /// unrealized_conversion_cast ops. However, they must be treated like
1010 /// regular value replacements.)
1011 ValueRange buildUnresolvedMaterialization(
1012 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1013 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1014 Type originalType, const TypeConverter *converter,
1015 bool isPureTypeConversion = true);
1016
1017 /// Find a replacement value for the given SSA value in the conversion value
1018 /// mapping. The replacement value must have the same type as the given SSA
1019 /// value. If there is no replacement value with the correct type, find the
1020 /// latest replacement value (regardless of the type) and build a source
1021 /// materialization.
/// Only valid when pattern rollback is enabled (asserted). May return a null
/// Value when the value is found to be dead.
1022 Value findOrBuildReplacementValue(Value value,
1023 const TypeConverter *converter);
1024
1025 //===--------------------------------------------------------------------===//
1026 // Rewriter Notification Hooks
1027 //===--------------------------------------------------------------------===//
1028
1029 /// Notifies that an op was inserted.
1030 void notifyOperationInserted(Operation *op,
1031 OpBuilder::InsertPoint previous) override;
1032
1033 /// Notifies that a block was inserted.
1034 void notifyBlockInserted(Block *block, Region *previous,
1035 Region::iterator previousIt) override;
1036
1037 /// Notifies that a pattern match failed for the given reason.
1038 void
1039 notifyMatchFailure(Location loc,
1040 function_ref<void(Diagnostic &)> reasonCallback) override;
1041
1042 //===--------------------------------------------------------------------===//
1043 // IR Erasure
1044 //===--------------------------------------------------------------------===//
1045
1046 /// A rewriter that keeps track of erased ops and blocks. It ensures that no
1047 /// operation or block is erased multiple times. This rewriter assumes that
1048 /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
/// (Presumably because a freed pointer could be reused by newly created IR,
/// which would then wrongly be reported as already erased.)
///
/// The rewriter registers itself as its own listener, so the erase
/// notifications below populate the `erased` set.
1050 public:
1053 std::function<void(Operation *)> opErasedCallback = nullptr)
1054 : RewriterBase(context, /*listener=*/this),
1055 opErasedCallback(std::move(opErasedCallback)) {}
1056
1057 /// Erase the given op (unless it was already erased).
1058 void eraseOp(Operation *op) override {
1059 if (wasErased(op))
1060 return;
1061 op->dropAllUses();
1063 }
1064
1065 /// Erase the given block (unless it was already erased).
1066 void eraseBlock(Block *block) override {
1067 if (wasErased(block))
1068 return;
1069 assert(block->empty() && "expected empty block");
1070 block->dropAllDefinedValueUses();
1072 }
1073
/// Return "true" if `ptr` was recorded as erased. `ptr` is only used as a set
/// key and is never dereferenced, so querying with a dangling pointer is fine.
1074 bool wasErased(void *ptr) const { return erased.contains(ptr); }
1075
1077 erased.insert(op);
1078 if (opErasedCallback)
1079 opErasedCallback(op);
1080 }
1081
1082 void notifyBlockErased(Block *block) override { erased.insert(block); }
1083
1084 private:
1085 /// Pointers to all erased operations and blocks.
1086 DenseSet<void *> erased;
1087
1088 /// A callback that is invoked when an operation is erased.
1089 std::function<void(Operation *)> opErasedCallback;
1090 };
1091
1092 //===--------------------------------------------------------------------===//
1093 // State
1094 //===--------------------------------------------------------------------===//
1095
1096 /// The rewriter that is used to perform the conversion.
1097 ConversionPatternRewriter &rewriter;
1098
1099 /// Mapping between replaced values that differ in type. This happens when
1100 /// replacing a value with one of a different type.
1101 ConversionValueMapping mapping;
1102
1103 /// Ordered list of IR rewrites (e.g., creations, splits, motions, replacements).
1104 /// This vector is maintained only if `allowPatternRollback` is set to
1105 /// "true". Otherwise, all IR rewrites are materialized immediately and no
1106 /// bookkeeping is needed.
1108
1109 /// A set of operations that should no longer be considered for legalization.
1110 /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1111 /// tracked separately.
1113
1114 /// A set of operations that were replaced/erased. Such ops are not erased
1115 /// immediately but only when the dialect conversion succeeds. In the
1116 /// meantime, they should no longer be considered for legalization and any attempt
1117 /// to modify/access them is invalid rewriter API usage.
1119
1120 /// A set of operations that were created by the current pattern.
1122
1123 /// A set of operations that were modified by the current pattern.
1125
1126 /// A list of unresolved materializations that were created by the current
1127 /// pattern.
1129
1130 /// A mapping for looking up metadata of unresolved materializations.
1133
1134 /// The current type converter, or nullptr if no type converter is currently
1135 /// active.
1137
1138 /// A mapping of regions to type converters that should be used when
1139 /// converting the arguments of blocks within that region.
1141
1142 /// Dialect conversion configuration.
1143 const ConversionConfig &config;
1144
1145 /// The operation converter to use for recursive legalization.
1147
1148 /// A set of erased operations. This set is utilized only if
1149 /// `allowPatternRollback` is set to "false". Conceptually, this set is
1150 /// similar to `replacedOps` (which is maintained when the flag is set to
1151 /// "true"). However, erasing from a DenseSet is more efficient than erasing
1152 /// from a SetVector.
1154
1155 /// A set of erased blocks. This set is utilized only if
1156 /// `allowPatternRollback` is set to "false".
1158
1159 /// A rewriter that notifies the listener (if any) about all IR
1160 /// modifications. This rewriter is utilized only if `allowPatternRollback`
1161 /// is set to "false". If the flag is set to "true", the listener is notified
1162 /// with a separate mechanism (e.g., in `IRRewrite::commit`).
1164
1165#ifndef NDEBUG
1166 /// A set of replaced values. This set is for debugging purposes only and it
1167 /// is maintained only if `allowPatternRollback` is set to "true".
1169
1170 /// A set of operations that have pending updates. This tracking isn't
1171 /// strictly necessary, and is thus only active during debug builds for extra
1172 /// verification.
1174
1175 /// A raw output stream used to prefix the debug log.
1176 llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
1177 llvm::dbgs()};
1178
1179 /// A logger used to emit diagnostics during the conversion process.
1180 llvm::ScopedPrinter logger{os};
/// NOTE(review): presumably the indentation prefix used for log output;
/// confirm against its uses.
1181 std::string logPrefix;
1182#endif // NDEBUG
1183};
1184} // namespace detail
1185} // namespace mlir
1186
1187const ConversionConfig &IRRewrite::getConfig() const {
1188 return rewriterImpl.config;
1189}
1190
1191void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1192 // Inform the listener about all IR modifications that have already taken
1193 // place: References to the original block have been replaced with the new
1194 // block.
1195 if (auto *listener =
1196 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1197 for (Operation *op : getNewBlock()->getUsers())
1198 listener->notifyOperationModified(op);
1199}
1200
1201void BlockTypeConversionRewrite::rollback() {
1202 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1203}
1204
1205/// Replace all uses of `from` with `repl`.
1206static void performReplaceValue(RewriterBase &rewriter, Value from,
1207 Value repl) {
1208 if (isa<BlockArgument>(repl)) {
1209 // `repl` is a block argument. Directly replace all uses.
1210 rewriter.replaceAllUsesWith(from, repl);
1211 return;
1212 }
1213
1214 // If the replacement value is an operation, only replace those uses that:
1215 // - are in a different block than the replacement operation, or
1216 // - are in the same block but after the replacement operation.
1217 //
1218 // Example:
1219 // ^bb0(%arg0: i32):
1220 // %0 = "consumer"(%arg0) : (i32) -> (i32)
1221 // "another_consumer"(%arg0) : (i32) -> ()
1222 //
1223 // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
1224 // use in "another_consumer" but not the use in "consumer". When using the
1225 // normal RewriterBase API, this would typically be done with
1226 // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
1227 // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
1228 // it cannot be supported efficiently with `allowPatternRollback` set to
1229 // "true". Therefore, the conversion driver is trying to be smart and replaces
1230 // only those uses that do not lead to a dominance violation. E.g., the
1231 // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
1232 // behavior.
1233 //
1234 // TODO: As we move more and more towards `allowPatternRollback` set to
1235 // "false", we should remove this special handling, in order to align the
1236 // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
1237 Operation *replOp = repl.getDefiningOp();
1238 Block *replBlock = replOp->getBlock();
1239 rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
1240 Operation *user = operand.getOwner();
1241 return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1242 });
1243}
1244
1245void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1246 Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
1247 if (!repl)
1248 return;
1249 performReplaceValue(rewriter, value, repl);
1250}
1251
1252void ReplaceValueRewrite::rollback() {
1253 rewriterImpl.mapping.erase({value});
1254#ifndef NDEBUG
1255 rewriterImpl.replacedValues.erase(value);
1256#endif // NDEBUG
1257}
1258
1259void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1260 auto *listener =
1261 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1262
1263 // Compute replacement values.
1264 SmallVector<Value> replacements =
1265 llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1266 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1267 });
1268
1269 // Notify the listener that the operation is about to be replaced.
1270 if (listener)
1271 listener->notifyOperationReplaced(op, replacements);
1272
1273 // Replace all uses with the new values.
1274 for (auto [result, newValue] :
1275 llvm::zip_equal(op->getResults(), replacements))
1276 if (newValue)
1277 rewriter.replaceAllUsesWith(result, newValue);
1278
1279 // The original op will be erased, so remove it from the set of unlegalized
1280 // ops.
1281 if (getConfig().unlegalizedOps)
1282 getConfig().unlegalizedOps->erase(op);
1283
1284 // Notify the listener that the operation and its contents are being erased.
1285 if (listener)
1286 notifyIRErased(listener, *op);
1287
1288 // Do not erase the operation yet. It may still be referenced in `mapping`.
1289 // Just unlink it for now and erase it during cleanup.
1290 op->getBlock()->getOperations().remove(op);
1291}
1292
1293void ReplaceOperationRewrite::rollback() {
1294 for (auto result : op->getResults())
1295 rewriterImpl.mapping.erase({result});
1296}
1297
1298void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1299 rewriter.eraseOp(op);
1300}
1301
1302void CreateOperationRewrite::rollback() {
1303 for (Region &region : op->getRegions()) {
1304 while (!region.getBlocks().empty())
1305 region.getBlocks().remove(region.getBlocks().begin());
1306 }
1307 op->dropAllUses();
1308 op->erase();
1309}
1310
1311void UnresolvedMaterializationRewrite::rollback() {
1312 if (!mappedValues.empty())
1313 rewriterImpl.mapping.erase(mappedValues);
1314 rewriterImpl.unresolvedMaterializations.erase(getOperation());
1315 op->erase();
1316}
1317
1319 // Commit all rewrites. Use a new rewriter, so the modifications are not
1320 // tracked for rollback purposes etc.
1321 IRRewriter irRewriter(rewriter.getContext(), config.listener);
1322 // Note: New rewrites may be added during the "commit" phase and the
1323 // `rewrites` vector may reallocate.
// This is why the loop is index-based and re-reads `rewrites.size()` on every
// iteration instead of using (invalidatable) iterators.
1324 for (size_t i = 0; i < rewrites.size(); ++i)
1325 rewrites[i]->commit(irRewriter);
1326
1327 // Clean up all rewrites.
// Every unrealized_conversion_cast op erased during cleanup is also removed
// from `unresolvedMaterializations`, so that map keeps no dangling keys.
1328 SingleEraseRewriter eraseRewriter(
1329 rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
1330 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1331 unresolvedMaterializations.erase(castOp);
1332 });
1333 for (auto &rewrite : rewrites)
1334 rewrite->cleanup(eraseRewriter);
1335}
1336
1337//===----------------------------------------------------------------------===//
1338// State Management
1339//===----------------------------------------------------------------------===//
1340
1342 Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
1343 // Helper function that looks up a single value.
1344 auto lookup = [&](const ValueVector &values) -> ValueVector {
1345 assert(!values.empty() && "expected non-empty value vector");
1346
1347 // If the pattern rollback is enabled, use the mapping to look up the
1348 // values.
1349 if (config.allowPatternRollback)
1350 return mapping.lookup(values);
1351
1352 // Otherwise, look up values by examining the IR. All replacements have
1353 // already been materialized in IR.
// A mapping exists only if `values` are exactly the results of one
// unrealized_conversion_cast op that is tracked in
// `unresolvedMaterializations`. In that case, the cast's inputs are the
// values that were mapped.
1354 Operation *op = getCommonDefiningOp(values);
1355 if (!op)
1356 return {};
1357 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1358 if (!castOp)
1359 return {};
1360 if (!this->unresolvedMaterializations.contains(castOp))
1361 return {};
1362 if (castOp.getOutputs() != values)
1363 return {};
1364 return castOp.getInputs();
1365 };
1366
1367 // Helper function that looks up each value in `values` individually and then
1368 // composes the results. If that fails, it tries to look up the entire vector
1369 // at once.
1370 auto composedLookup = [&](const ValueVector &values) -> ValueVector {
1371 // If possible, replace each value with (one or multiple) mapped values.
1372 ValueVector next;
1373 for (Value v : values) {
1374 ValueVector r = lookup({v});
1375 if (!r.empty()) {
1376 llvm::append_range(next, r);
1377 } else {
1378 next.push_back(v);
1379 }
1380 }
1381 if (next != values) {
1382 // At least one value was replaced.
1383 return next;
1384 }
1385
1386 // Otherwise: Check if there is a mapping for the entire vector. Such
1387 // mappings are materializations. (N:M mappings are not supported for value
1388 // replacements.)
1389 //
1390 // Note: From a correctness point of view, materializations do not have to
1391 // be stored (and looked up) in the mapping. But for performance reasons,
1392 // we choose to reuse existing IR (when possible) instead of creating it
1393 // multiple times.
1394 ValueVector r = lookup(values);
1395 if (r.empty()) {
1396 // No mapping found: The lookup stops here.
1397 return {};
1398 }
1399 return r;
1400 };
1401
1402 // Try to find the deepest values that have the desired types. If there is no
1403 // such mapping, simply return the deepest values.
// The loop walks the mapping chain until no further lookup succeeds.
// `desiredValue` is overwritten on every match, so it ends up holding the
// deepest match.
1404 ValueVector desiredValue;
1405 ValueVector current{from};
1406 ValueVector lastNonMaterialization{from};
1407 do {
1408 // Store the current value if the types match.
1409 bool match = TypeRange(ValueRange(current)) == desiredTypes;
1410 if (skipPureTypeConversions) {
1411 // Skip pure type conversions, if requested.
1412 bool pureConversion = isPureTypeConversion(current);
1413 match &= !pureConversion;
1414 // Keep track of the last mapped value that was not a pure type
1415 // conversion.
1416 if (!pureConversion)
1417 lastNonMaterialization = current;
1418 }
1419 if (match)
1420 desiredValue = current;
1421
1422 // Lookup next value in the mapping.
1423 ValueVector next = composedLookup(current);
1424 if (next.empty())
1425 break;
1426 current = std::move(next);
1427 } while (true);
1428
1429 // If specific types were requested, return the deepest match (which is empty
1430 // if nothing matched). Otherwise return the leaf values, or, if requested,
// the deepest values that are not pure type conversions.
1431 if (!desiredTypes.empty())
1432 return desiredValue;
1433 if (skipPureTypeConversions)
1434 return lastNonMaterialization;
1435 return current;
1436}
1437
1440 TypeRange desiredTypes) const {
1441 ValueVector result = lookupOrDefault(from, desiredTypes);
// Report "not mapped" (empty) in these cases:
// - the lookup ended at `from` itself;
// - specific types were requested and the result does not have them.
1442 if (result == ValueVector{from} ||
1443 (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
1444 return {};
1445 return result;
1446}
1447
// Snapshot the sizes of the append-only bookkeeping lists; `resetState`
// truncates them back to these sizes.
1449 return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1450}
1451
1453 StringRef patternName) {
// NOTE(review): assumes `state` was captured earlier by `getCurrentState`, so
// the lists below only grew since then; otherwise these loops would pop from
// empty containers.
1454 // Undo any rewrites.
1455 undoRewrites(state.numRewrites, patternName);
1456
1457 // Pop all of the recorded ignored operations that are no longer valid.
1458 while (ignoredOps.size() != state.numIgnoredOperations)
1459 ignoredOps.pop_back();
1460
// Forget the ops that were marked as replaced/erased after the snapshot.
1461 while (replacedOps.size() != state.numReplacedOps)
1462 replacedOps.pop_back();
1463}
1464
1465void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1466 StringRef patternName) {
1467 for (auto &rewrite :
1468 llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1469 rewrite->rollback();
1470 rewrites.resize(numRewritesToKeep);
1471}
1472
1474 StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
1475 SmallVector<ValueVector> &remapped) {
1476 remapped.reserve(llvm::size(values));
1477
1478 for (const auto &it : llvm::enumerate(values)) {
1479 Value operand = it.value();
1480 Type origType = operand.getType();
// An explicit `inputLoc` overrides the per-value location.
1481 Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1482
1483 if (!currentTypeConverter) {
1484 // The current pattern does not have a type converter. Pass the most
1485 // recently mapped values, excluding materializations. Materializations
1486 // are intentionally excluded because their presence may depend on other
1487 // patterns. Including materializations would make the lookup fragile
1488 // and unpredictable.
1489 remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
1490 /*skipPureTypeConversions=*/true));
1491 continue;
1492 }
1493
1494 // If there is no legal conversion, fail to match this pattern.
// Note: entries already pushed to `remapped` for earlier values are left
// in place on failure.
1495 SmallVector<Type, 1> legalTypes;
1496 if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
1497 notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1498 diag << "unable to convert type for " << valueDiagTag << " #"
1499 << it.index() << ", type was " << origType;
1500 });
1501 return failure();
1502 }
1503 // If a type is converted to 0 types, there is nothing to do.
1504 if (legalTypes.empty()) {
1505 remapped.push_back({});
1506 continue;
1507 }
1508
1509 ValueVector repl = lookupOrDefault(operand, legalTypes);
1510 if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1511 // Mapped values have the correct type or there is an existing
1512 // materialization. Or the operand is not mapped at all and has the
1513 // correct type.
1514 remapped.push_back(std::move(repl));
1515 continue;
1516 }
1517
1518 // Create a materialization for the most recently mapped values.
// Pure type conversions are skipped. `repl` is also passed as `valuesToMap`,
// so in rollback mode the new cast results are recorded in the mapping and
// can be reused by later lookups.
1519 repl = lookupOrDefault(operand, /*desiredTypes=*/{},
1520 /*skipPureTypeConversions=*/true);
1522 MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1523 /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1524 /*originalType=*/origType, currentTypeConverter);
1525 remapped.push_back(castValues);
1526 }
1527 return success();
1528}
1529
1531 // Check to see if this operation is ignored or was replaced.
// Replaced/erased ops are reported as ignored too, even though they are
// tracked separately from `ignoredOps`.
1532 return wasOpReplaced(op) || ignoredOps.count(op);
1533}
1534
1536 // Check to see if this operation was replaced.
// `replacedOps` is maintained in rollback mode and `erasedOps` otherwise (see
// the member docs). Checking both makes this query independent of the mode.
1537 return replacedOps.count(op) || erasedOps.count(op);
1538}
1539
1540//===----------------------------------------------------------------------===//
1541// Type Conversion
1542//===----------------------------------------------------------------------===//
1543
1545 Region *region, const TypeConverter &converter,
1546 TypeConverter::SignatureConversion *entryConversion) {
// The converter is recorded for the region even if the region is empty.
1547 regionToConverter[region] = &converter;
1548 if (region->empty())
1549 return nullptr;
1550
1551 // Convert the arguments of each non-entry block within the region.
// `make_early_inc_range` is required because `applySignatureConversion`
// inserts a new block after `block` and erases (unlinks) `block`.
1552 for (Block &block :
1553 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1554 // Compute the signature for the block with the provided converter.
1555 std::optional<TypeConverter::SignatureConversion> conversion =
1556 converter.convertBlockSignature(&block);
// NOTE(review): blocks converted in earlier iterations are not undone here;
// presumably the caller rolls them back via the recorded rewrites.
1557 if (!conversion)
1558 return failure();
1559 // Convert the block with the computed signature.
1560 applySignatureConversion(&block, &converter, *conversion);
1561 }
1562
1563 // Convert the entry block. If an entry signature conversion was provided,
1564 // use that one. Otherwise, compute the signature with the type converter.
1565 if (entryConversion)
1566 return applySignatureConversion(&region->front(), &converter,
1567 *entryConversion);
1568 std::optional<TypeConverter::SignatureConversion> conversion =
1569 converter.convertBlockSignature(&region->front());
1570 if (!conversion)
1571 return failure();
1572 return applySignatureConversion(&region->front(), &converter, *conversion);
1573}
1574
1576 Block *block, const TypeConverter *converter,
1577 TypeConverter::SignatureConversion &signatureConversion) {
1578#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1579 // A block cannot be converted multiple times.
1580 if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1581 llvm::report_fatal_error("block was already converted");
1582#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1583
1585
1586 // If no arguments are being changed or added, there is nothing to do.
// Only the argument types are compared. If they already match, `block` is
// returned unchanged and no rewrite is recorded.
1587 unsigned origArgCount = block->getNumArguments();
1588 auto convertedTypes = signatureConversion.getConvertedTypes();
1589 if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1590 return block;
1591
1592 // Compute the locations of all block arguments in the new block.
// Each new argument inherits the location of the original argument it
// replaces. New arguments that do not replace an original argument keep an
// unknown location.
1593 SmallVector<Location> newLocs(convertedTypes.size(),
1594 rewriter.getUnknownLoc());
1595 for (unsigned i = 0; i < origArgCount; ++i) {
1596 auto inputMap = signatureConversion.getInputMapping(i);
1597 if (!inputMap || inputMap->replacedWithValues())
1598 continue;
1599 Location origLoc = block->getArgument(i).getLoc();
1600 for (unsigned j = 0; j < inputMap->size; ++j)
1601 newLocs[inputMap->inputNo + j] = origLoc;
1602 }
1603
1604 // Insert a new block with the converted block argument types and move all ops
1605 // from the old block to the new block.
1606 Block *newBlock =
1607 rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1608 convertedTypes, newLocs);
1609
1610 // If a listener is attached to the dialect conversion, ops cannot be moved
1611 // to the destination block in bulk ("fast path"). This is because at the time
1612 // the notifications are sent, it is unknown which ops were moved. Instead,
1613 // ops should be moved one-by-one ("slow path"), so that a separate
1614 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1615 // a bit more efficient, so we try to do that when possible.
1616 bool fastPath = !config.listener;
1617 if (fastPath) {
1618 if (config.allowPatternRollback)
1619 appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1620 newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1621 } else {
1622 while (!block->empty())
1623 rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1624 }
1625
1626 // Replace all uses of the old block with the new block.
1627 block->replaceAllUsesWith(newBlock);
1628
// Redirect the uses of each original argument. Only the ops were moved, so
// `block` still owns its original arguments at this point.
1629 for (unsigned i = 0; i != origArgCount; ++i) {
1630 BlockArgument origArg = block->getArgument(i);
1631 Type origArgType = origArg.getType();
1632
1633 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1634 signatureConversion.getInputMapping(i);
1635 if (!inputMap) {
1636 // This block argument was dropped and no replacement value was provided.
1637 // Materialize a replacement value "out of thin air".
1638 // Note: Materialization must be built here because we cannot find a
1639 // valid insertion point in the new block. (Will point to the old block.)
1640 Value mat =
1642 MaterializationKind::Source,
1643 OpBuilder::InsertPoint(newBlock, newBlock->begin()),
1644 origArg.getLoc(),
1645 /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1646 /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
1647 /*isPureTypeConversion=*/false)
1648 .front();
1649 replaceAllUsesWith(origArg, mat, converter);
1650 continue;
1651 }
1652
1653 if (inputMap->replacedWithValues()) {
1654 // This block argument was dropped and replacement values were provided.
1655 assert(inputMap->size == 0 &&
1656 "invalid to provide a replacement value when the argument isn't "
1657 "dropped");
1658 replaceAllUsesWith(origArg, inputMap->replacementValues, converter);
1659 continue;
1660 }
1661
1662 // This is a 1->1+ mapping.
1663 auto replArgs =
1664 newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1665 replaceAllUsesWith(origArg, replArgs, converter);
1666 }
1667
// Record the signature change so that, on rollback, references to `newBlock`
// are redirected back to `block`.
1668 if (config.allowPatternRollback)
1669 appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1670
1671 // Erase the old block. (It is just unlinked for now and will be erased during
1672 // cleanup.)
1673 rewriter.eraseBlock(block);
1674
1675 return newBlock;
1676}
1677
1678//===----------------------------------------------------------------------===//
1679// Materializations
1680//===----------------------------------------------------------------------===//
1681
1682/// Build an unresolved materialization operation given a range of output types
1683/// and a list of input operands.
///
/// Returns the results of the created unrealized_conversion_cast op.
/// `outputTypes` must be non-empty, because its first type provides the
/// MLIRContext.
1685 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1686 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1687 Type originalType, const TypeConverter *converter,
1688 bool isPureTypeConversion) {
1689 assert((!originalType || kind == MaterializationKind::Target) &&
1690 "original type is valid only for target materializations");
1691 assert(TypeRange(inputs) != outputTypes &&
1692 "materialization is not necessary");
1693
1694 // Create an unresolved materialization. We use a new OpBuilder to avoid
1695 // tracking the materialization like we do for other operations.
1696 OpBuilder builder(outputTypes.front().getContext());
1697 builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1698 UnrealizedConversionCastOp convertOp =
1699 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1700 if (config.attachDebugMaterializationKind) {
1701 StringRef kindStr =
1702 kind == MaterializationKind::Source ? "source" : "target";
1703 convertOp->setAttr("__kind__", builder.getStringAttr(kindStr));
1704 }
1706 convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
1707
1708 // Register the materialization.
// The metadata is registered in both modes. In rollback mode, a non-empty
// `valuesToMap` is additionally mapped to the cast results. Otherwise, the cast
// is recorded in `patternMaterializations`.
1709 unresolvedMaterializations[convertOp] =
1710 UnresolvedMaterializationInfo(converter, kind, originalType);
1711 if (config.allowPatternRollback) {
1712 if (!valuesToMap.empty())
1713 mapping.map(valuesToMap, convertOp.getResults());
1715 std::move(valuesToMap));
1716 } else {
1717 patternMaterializations.insert(convertOp);
1718 }
1719 return convertOp.getResults();
1720}
1721
1723 Value value, const TypeConverter *converter) {
1724 assert(config.allowPatternRollback &&
1725 "this code path is valid only in rollback mode");
1726
1727 // Try to find a replacement value with the same type in the conversion value
1728 // mapping. This includes cached materializations. We try to reuse those
1729 // instead of generating duplicate IR.
1730 ValueVector repl = lookupOrNull(value, value.getType());
1731 if (!repl.empty())
1732 return repl.front();
1733
1734 // Check if the value is dead. No replacement value is needed in that case.
1735 // This is an approximate check that may have false negatives but does not
1736 // require computing and traversing an inverse mapping. (We may end up
1737 // building source materializations that are never used and that fold away.)
1738 if (llvm::all_of(value.getUsers(),
1739 [&](Operation *op) { return replacedOps.contains(op); }) &&
1740 !mapping.isMappedTo(value))
1741 return Value();
1742
1743 // No replacement value was found. Get the latest replacement value
1744 // (regardless of the type) and build a source materialization to the
1745 // original type.
1746 repl = lookupOrNull(value);
1747
1748 // Compute the insertion point of the materialization.
1750 if (repl.empty()) {
1751 // The source materialization has no inputs. Insert it right before the
1752 // value that it is replacing.
1753 ip = computeInsertPoint(value);
1754 } else {
1755 // Compute the "earliest" insertion point at which all values in `repl` are
1756 // defined. It is important to emit the materialization at that location
1757 // because the same materialization may be reused in a different context.
1758 // (That's because materializations are cached in the conversion value
1759 // mapping.) The insertion point of the materialization must be valid for
1760 // all future users that may be created later in the conversion process.
1761 ip = computeInsertPoint(repl);
1762 }
1764 MaterializationKind::Source, ip, value.getLoc(),
1765 /*valuesToMap=*/repl, /*inputs=*/repl,
1766 /*outputTypes=*/value.getType(),
1767 /*originalType=*/Type(), converter,
1768 /*isPureTypeConversion=*/!repl.empty())
1769 .front();
1770 return castValue;
1771}
1772
1773//===----------------------------------------------------------------------===//
1774// Rewriter Notification Hooks
1775//===----------------------------------------------------------------------===//
1776
1778 Operation *op, OpBuilder::InsertPoint previous) {
1779 // If no previous insertion point is provided, the op used to be detached.
1780 bool wasDetached = !previous.isSet();
1781 LLVM_DEBUG({
1782 logger.startLine() << "** Insert : '" << op->getName() << "' (" << op
1783 << ")";
1784 if (wasDetached)
1785 logger.getOStream() << " (was detached)";
1786 logger.getOStream() << "\n";
1787 });
1788
1789 // In rollback mode, it is easier to misuse the API, so perform extra error
1790 // checking.
1791 assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
1792 "attempting to insert into a block within a replaced/erased op");
1793
1794 // In "no rollback" mode, the listener is always notified immediately.
1795 if (!config.allowPatternRollback && config.listener)
1796 config.listener->notifyOperationInserted(op, previous);
1797
1798 if (wasDetached) {
1799 // If the op was detached, it is most likely a newly created op. Add it the
1800 // set of newly created ops, so that it will be legalized. If this op is
1801 // not a newly created op, it will be legalized a second time, which is
1802 // inefficient but harmless.
1803 patternNewOps.insert(op);
1804
1805 if (config.allowPatternRollback) {
1806 // TODO: If the same op is inserted multiple times from a detached
1807 // state, the rollback mechanism may erase the same op multiple times.
1808 // This is a bug in the rollback-based dialect conversion driver.
1810 } else {
1811 // In "no rollback" mode, there is an extra data structure for tracking
1812 // erased operations that must be kept up to date.
1813 erasedOps.erase(op);
1814 }
1815 return;
1816 }
1817
1818 // The op was moved from one place to another.
1819 if (config.allowPatternRollback)
1821}
1822
1823/// Given that `fromRange` is about to be replaced with `toRange`, compute
1824/// replacement values with the types of `fromRange`.
1825static SmallVector<Value>
1827 const SmallVector<SmallVector<Value>> &toRange,
1828 const TypeConverter *converter) {
1829 assert(!impl.config.allowPatternRollback &&
1830 "this code path is valid only in 'no rollback' mode");
1831 SmallVector<Value> repls;
1832 for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1833 if (from.use_empty()) {
1834 // The replaced value is dead. No replacement value is needed.
1835 repls.push_back(Value());
1836 continue;
1837 }
1838
1839 if (to.empty()) {
1840 // The replaced value is dropped. Materialize a replacement value "out of
1841 // thin air".
1842 Value srcMat = impl.buildUnresolvedMaterialization(
1843 MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
1844 /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1845 /*outputTypes=*/from.getType(), /*originalType=*/Type(),
1846 converter)[0];
1847 repls.push_back(srcMat);
1848 continue;
1849 }
1850
1851 if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
1852 // The replacement value already has the correct type. Use it directly.
1853 repls.push_back(to[0]);
1854 continue;
1855 }
1856
1857 // The replacement value has the wrong type. Build a source materialization
1858 // to the original type.
1859 // TODO: This is a bit inefficient. We should try to reuse existing
1860 // materializations if possible. This would require an extension of the
1861 // `lookupOrDefault` API.
1862 Value srcMat = impl.buildUnresolvedMaterialization(
1863 MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
1864 /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
1865 /*originalType=*/Type(), converter)[0];
1866 repls.push_back(srcMat);
1867 }
1868
1869 return repls;
1870}
1871
1873 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1874 assert(newValues.size() == op->getNumResults() &&
1875 "incorrect number of replacement values");
1876 LLVM_DEBUG({
1877 logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
1878 << ")\n";
1880 // If the user-provided replacement types are different from the
1881 // legalized types, as per the current type converter, print a note.
1882 // In most cases, the replacement types are expected to match the types
1883 // produced by the type converter, so this could indicate a bug in the
1884 // user code.
1885 for (auto [result, repls] :
1886 llvm::zip_equal(op->getResults(), newValues)) {
1887 Type resultType = result.getType();
1888 auto logProlog = [&, repls = repls]() {
1889 logger.startLine() << " Note: Replacing op result of type "
1890 << resultType << " with value(s) of type (";
1891 llvm::interleaveComma(repls, logger.getOStream(), [&](Value v) {
1892 logger.getOStream() << v.getType();
1893 });
1894 logger.getOStream() << ")";
1895 };
1896 SmallVector<Type> convertedTypes;
1897 if (failed(currentTypeConverter->convertTypes(resultType,
1898 convertedTypes))) {
1899 logProlog();
1900 logger.getOStream() << ", but the type converter failed to legalize "
1901 "the original type.\n";
1902 continue;
1903 }
1904 if (TypeRange(convertedTypes) != TypeRange(ValueRange(repls))) {
1905 logProlog();
1906 logger.getOStream() << ", but the legalized type(s) is/are (";
1907 llvm::interleaveComma(convertedTypes, logger.getOStream(),
1908 [&](Type t) { logger.getOStream() << t; });
1909 logger.getOStream() << ")\n";
1910 }
1911 }
1912 }
1913 });
1914
1915 if (!config.allowPatternRollback) {
1916 // Pattern rollback is not allowed: materialize all IR changes immediately.
1918 *this, op->getResults(), newValues, currentTypeConverter);
1919 // Update internal data structures, so that there are no dangling pointers
1920 // to erased IR.
1921 op->walk([&](Operation *op) {
1922 erasedOps.insert(op);
1923 ignoredOps.remove(op);
1924 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1925 unresolvedMaterializations.erase(castOp);
1926 patternMaterializations.erase(castOp);
1927 }
1928 // The original op will be erased, so remove it from the set of
1929 // unlegalized ops.
1930 if (config.unlegalizedOps)
1931 config.unlegalizedOps->erase(op);
1932 });
1933 op->walk([&](Block *block) { erasedBlocks.insert(block); });
1934 // Replace the op with the replacement values and notify the listener.
1935 notifyingRewriter.replaceOp(op, repls);
1936 return;
1937 }
1938
1939 assert(!ignoredOps.contains(op) && "operation was already replaced");
1940#ifndef NDEBUG
1941 for (Value v : op->getResults())
1942 assert(!replacedValues.contains(v) &&
1943 "attempting to replace a value that was already replaced");
1944#endif // NDEBUG
1945
1946 // Check if replaced op is an unresolved materialization, i.e., an
1947 // unrealized_conversion_cast op that was created by the conversion driver.
1948 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1949 // Make sure that the user does not mess with unresolved materializations
1950 // that were inserted by the conversion driver. We keep track of these
1951 // ops in internal data structures.
1952 assert(!unresolvedMaterializations.contains(castOp) &&
1953 "attempting to replace/erase an unresolved materialization");
1954 }
1955
1956 // Create mappings for each of the new result values.
1957 for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults()))
1958 mapping.map(static_cast<Value>(result), std::move(repl));
1959
1961 // Mark this operation and all nested ops as replaced.
1962 op->walk([&](Operation *op) { replacedOps.insert(op); });
1963}
1964
1966 Value from, ValueRange to, const TypeConverter *converter) {
1967 if (!config.allowPatternRollback) {
1968 SmallVector<Value> toConv = llvm::to_vector(to);
1969 SmallVector<Value> repls =
1970 getReplacementValues(*this, from, {toConv}, converter);
1971 IRRewriter r(from.getContext());
1972 Value repl = repls.front();
1973 if (!repl)
1974 return;
1975
1976 performReplaceValue(r, from, repl);
1977 return;
1978 }
1979
1980#ifndef NDEBUG
1981 // Make sure that a value is not replaced multiple times. In rollback mode,
1982 // `replaceAllUsesWith` replaces not only all current uses of the given value,
1983 // but also all future uses that may be introduced by future pattern
1984 // applications. Therefore, it does not make sense to call
1985 // `replaceAllUsesWith` multiple times with the same value. Doing so would
1986 // overwrite the mapping and mess with the internal state of the dialect
1987 // conversion driver.
1988 assert(!replacedValues.contains(from) &&
1989 "attempting to replace a value that was already replaced");
1990 assert(!wasOpReplaced(from.getDefiningOp()) &&
1991 "attempting to replace a op result that was already replaced");
1992 replacedValues.insert(from);
1993#endif // NDEBUG
1994
1995 mapping.map(from, to);
1996 appendRewrite<ReplaceValueRewrite>(from, converter);
1997}
1998
2000 if (!config.allowPatternRollback) {
2001 // Pattern rollback is not allowed: materialize all IR changes immediately.
2002 // Update internal data structures, so that there are no dangling pointers
2003 // to erased IR.
2004 block->walk([&](Operation *op) {
2005 erasedOps.insert(op);
2006 ignoredOps.remove(op);
2007 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
2008 unresolvedMaterializations.erase(castOp);
2009 patternMaterializations.erase(castOp);
2010 }
2011 // The original op will be erased, so remove it from the set of
2012 // unlegalized ops.
2013 if (config.unlegalizedOps)
2014 config.unlegalizedOps->erase(op);
2015 });
2016 block->walk([&](Block *block) { erasedBlocks.insert(block); });
2017 // Erase the block and notify the listener.
2018 notifyingRewriter.eraseBlock(block);
2019 return;
2020 }
2021
2022 assert(!wasOpReplaced(block->getParentOp()) &&
2023 "attempting to erase a block within a replaced/erased op");
2025
2026 // Unlink the block from its parent region. The block is kept in the rewrite
2027 // object and will be actually destroyed when rewrites are applied. This
2028 // allows us to keep the operations in the block live and undo the removal by
2029 // re-inserting the block.
2030 block->getParent()->getBlocks().remove(block);
2031
2032 // Mark all nested ops as erased.
2033 block->walk([&](Operation *op) { replacedOps.insert(op); });
2034}
2035
2037 Block *block, Region *previous, Region::iterator previousIt) {
2038 // If no previous insertion point is provided, the block used to be detached.
2039 bool wasDetached = !previous;
2040 Operation *newParentOp = block->getParentOp();
2041 LLVM_DEBUG(
2042 {
2043 Operation *parent = newParentOp;
2044 if (parent) {
2045 logger.startLine() << "** Insert Block into : '" << parent->getName()
2046 << "' (" << parent << ")";
2047 } else {
2048 logger.startLine()
2049 << "** Insert Block into detached Region (nullptr parent op)";
2050 }
2051 if (wasDetached)
2052 logger.getOStream() << " (was detached)";
2053 logger.getOStream() << "\n";
2054 });
2055
2056 // In rollback mode, it is easier to misuse the API, so perform extra error
2057 // checking.
2058 assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) &&
2059 "attempting to insert into a region within a replaced/erased op");
2060 (void)newParentOp;
2061
2062 // In "no rollback" mode, the listener is always notified immediately.
2063 if (!config.allowPatternRollback && config.listener)
2064 config.listener->notifyBlockInserted(block, previous, previousIt);
2065
2066 if (wasDetached) {
2067 // If the block was detached, it is most likely a newly created block.
2068 if (config.allowPatternRollback) {
2069 // TODO: If the same block is inserted multiple times from a detached
2070 // state, the rollback mechanism may erase the same block multiple times.
2071 // This is a bug in the rollback-based dialect conversion driver.
2073 } else {
2074 // In "no rollback" mode, there is an extra data structure for tracking
2075 // erased blocks that must be kept up to date.
2076 erasedBlocks.erase(block);
2077 }
2078 return;
2079 }
2080
2081 // The block was moved from one place to another.
2082 if (config.allowPatternRollback)
2083 appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
2084}
2085
2087 Block *dest,
2088 Block::iterator before) {
2089 appendRewrite<InlineBlockRewrite>(dest, source, before);
2090}
2091
2093 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
2094 LLVM_DEBUG({
2096 reasonCallback(diag);
2097 logger.startLine() << "** Failure : " << diag.str() << "\n";
2098 if (config.notifyCallback)
2099 config.notifyCallback(diag);
2100 });
2101}
2102
2103//===----------------------------------------------------------------------===//
2104// ConversionPatternRewriter
2105//===----------------------------------------------------------------------===//
2106
2107ConversionPatternRewriter::ConversionPatternRewriter(
2108 MLIRContext *ctx, const ConversionConfig &config,
2109 OperationConverter &opConverter)
2111 *this, config, opConverter)) {
2112 setListener(impl.get());
2113}
2114
/// Destroys the rewriter together with the implementation object it owns via
/// `impl`. NOTE(review): the destructor is presumably defaulted out of line so
/// that the implementation type only needs to be complete in this file;
/// confirm against the header.
2115ConversionPatternRewriter::~ConversionPatternRewriter() = default;
2116
2117const ConversionConfig &ConversionPatternRewriter::getConfig() const {
2118 return impl->config;
2119}
2120
2121void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
2122 assert(op && newOp && "expected non-null op");
2123 replaceOp(op, newOp->getResults());
2124}
2125
2126void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
2127 assert(op->getNumResults() == newValues.size() &&
2128 "incorrect # of replacement values");
2129
2130 // If the current insertion point is before the erased operation, we adjust
2131 // the insertion point to be after the operation.
2132 if (getInsertionPoint() == op->getIterator())
2134
2135 SmallVector<SmallVector<Value>> newVals =
2136 llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2137 return v ? SmallVector<Value>{v} : SmallVector<Value>();
2138 });
2139 impl->replaceOp(op, std::move(newVals));
2140}
2141
2142void ConversionPatternRewriter::replaceOpWithMultiple(
2143 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2144 assert(op->getNumResults() == newValues.size() &&
2145 "incorrect # of replacement values");
2146
2147 // If the current insertion point is before the erased operation, we adjust
2148 // the insertion point to be after the operation.
2149 if (getInsertionPoint() == op->getIterator())
2151
2152 impl->replaceOp(op, std::move(newValues));
2153}
2154
2155void ConversionPatternRewriter::eraseOp(Operation *op) {
2156 LLVM_DEBUG({
2157 impl->logger.startLine()
2158 << "** Erase : '" << op->getName() << "'(" << op << ")\n";
2159 });
2160
2161 // If the current insertion point is before the erased operation, we adjust
2162 // the insertion point to be after the operation.
2163 if (getInsertionPoint() == op->getIterator())
2165
2166 SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
2167 impl->replaceOp(op, std::move(nullRepls));
2168}
2169
2170void ConversionPatternRewriter::eraseBlock(Block *block) {
2171 impl->eraseBlock(block);
2172}
2173
2174Block *ConversionPatternRewriter::applySignatureConversion(
2175 Block *block, TypeConverter::SignatureConversion &conversion,
2176 const TypeConverter *converter) {
2177 assert(!impl->wasOpReplaced(block->getParentOp()) &&
2178 "attempting to apply a signature conversion to a block within a "
2179 "replaced/erased op");
2180 return impl->applySignatureConversion(block, converter, conversion);
2181}
2182
2183FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
2184 Region *region, const TypeConverter &converter,
2185 TypeConverter::SignatureConversion *entryConversion) {
2186 assert(!impl->wasOpReplaced(region->getParentOp()) &&
2187 "attempting to apply a signature conversion to a block within a "
2188 "replaced/erased op");
2189 return impl->convertRegionTypes(region, converter, entryConversion);
2190}
2191
2192void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
2193 LLVM_DEBUG({
2194 impl->logger.startLine() << "** Replace Value : '" << from << "'";
2195 if (auto blockArg = dyn_cast<BlockArgument>(from)) {
2196 if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
2197 impl->logger.getOStream() << " (in region of '" << parentOp->getName()
2198 << "' (" << parentOp << ")\n";
2199 } else {
2200 impl->logger.getOStream() << " (unlinked block)\n";
2201 }
2202 }
2203 });
2204 impl->replaceAllUsesWith(from, to, impl->currentTypeConverter);
2205}
2206
2207Value ConversionPatternRewriter::getRemappedValue(Value key) {
2208 SmallVector<ValueVector> remappedValues;
2209 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
2210 remappedValues)))
2211 return nullptr;
2212 assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
2213 return remappedValues.front().front();
2214}
2215
2216LogicalResult
2217ConversionPatternRewriter::getRemappedValues(ValueRange keys,
2218 SmallVectorImpl<Value> &results) {
2219 if (keys.empty())
2220 return success();
2221 SmallVector<ValueVector> remapped;
2222 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
2223 remapped)))
2224 return failure();
2225 for (const auto &values : remapped) {
2226 assert(values.size() == 1 && "1:N conversion not supported");
2227 results.push_back(values.front());
2228 }
2229 return success();
2230}
2231
2232LogicalResult ConversionPatternRewriter::legalize(Region *r) {
2233 // Fast path: If the region is empty, there is nothing to legalize.
2234 if (r->empty())
2235 return success();
2236
2237 // Gather a list of all operations to legalize. This is done before
2238 // converting the entry block signature because unrealized_conversion_cast
2239 // ops should not be included.
2240 SmallVector<Operation *> ops;
2241 for (Block &b : *r)
2242 for (Operation &op : b)
2243 ops.push_back(&op);
2244
2245 // If the current pattern runs with a type converter, convert the entry block
2246 // signature.
2247 if (const TypeConverter *converter = impl->currentTypeConverter) {
2248 std::optional<TypeConverter::SignatureConversion> conversion =
2249 converter->convertBlockSignature(&r->front());
2250 if (!conversion)
2251 return failure();
2252 applySignatureConversion(&r->front(), *conversion, converter);
2253 }
2254
2255 // Legalize all operations in the region.
2256 for (Operation *op : ops)
2257 if (failed(legalize(op)))
2258 return failure();
2259
2260 return success();
2261}
2262
2263void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
2264 Block::iterator before,
2265 ValueRange argValues) {
2266#ifndef NDEBUG
2267 assert(argValues.size() == source->getNumArguments() &&
2268 "incorrect # of argument replacement values");
2269 assert(!impl->wasOpReplaced(source->getParentOp()) &&
2270 "attempting to inline a block from a replaced/erased op");
2271 assert(!impl->wasOpReplaced(dest->getParentOp()) &&
2272 "attempting to inline a block into a replaced/erased op");
2273 auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
2274 // The source block will be deleted, so it should not have any users (i.e.,
2275 // there should be no predecessors).
2276 assert(llvm::all_of(source->getUsers(), opIgnored) &&
2277 "expected 'source' to have no predecessors");
2278#endif // NDEBUG
2279
2280 // If a listener is attached to the dialect conversion, ops cannot be moved
2281 // to the destination block in bulk ("fast path"). This is because at the time
2282 // the notifications are sent, it is unknown which ops were moved. Instead,
2283 // ops should be moved one-by-one ("slow path"), so that a separate
2284 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
2285 // a bit more efficient, so we try to do that when possible.
2286 bool fastPath = !getConfig().listener;
2287
2288 if (fastPath && impl->config.allowPatternRollback)
2289 impl->inlineBlockBefore(source, dest, before);
2290
2291 // Replace all uses of block arguments.
2292 for (auto it : llvm::zip(source->getArguments(), argValues))
2293 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2294
2295 if (fastPath) {
2296 // Move all ops at once.
2297 dest->getOperations().splice(before, source->getOperations());
2298 } else {
2299 // Move op by op.
2300 while (!source->empty())
2301 moveOpBefore(&source->front(), dest, before);
2302 }
2303
2304 // If the current insertion point is within the source block, adjust the
2305 // insertion point to the destination block.
2306 if (getInsertionBlock() == source)
2307 setInsertionPoint(dest, getInsertionPoint());
2308
2309 // Erase the source block.
2310 eraseBlock(source);
2311}
2312
2313void ConversionPatternRewriter::startOpModification(Operation *op) {
2314 if (!impl->config.allowPatternRollback) {
2315 // Pattern rollback is not allowed: no extra bookkeeping is needed.
2317 return;
2318 }
2319 assert(!impl->wasOpReplaced(op) &&
2320 "attempting to modify a replaced/erased op");
2321#ifndef NDEBUG
2322 impl->pendingRootUpdates.insert(op);
2323#endif
2324 impl->appendRewrite<ModifyOperationRewrite>(op);
2325}
2326
2327void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
2328 impl->patternModifiedOps.insert(op);
2329 if (!impl->config.allowPatternRollback) {
2331 if (getConfig().listener)
2332 getConfig().listener->notifyOperationModified(op);
2333 return;
2334 }
2335
2336 // There is nothing to do here, we only need to track the operation at the
2337 // start of the update.
2338#ifndef NDEBUG
2339 assert(!impl->wasOpReplaced(op) &&
2340 "attempting to modify a replaced/erased op");
2341 assert(impl->pendingRootUpdates.erase(op) &&
2342 "operation did not have a pending in-place update");
2343#endif
2344}
2345
2346void ConversionPatternRewriter::cancelOpModification(Operation *op) {
2347 if (!impl->config.allowPatternRollback) {
2349 return;
2350 }
2351#ifndef NDEBUG
2352 assert(impl->pendingRootUpdates.erase(op) &&
2353 "operation did not have a pending in-place update");
2354#endif
2355 // Erase the last update for this operation.
2356 auto it = llvm::find_if(
2357 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
2358 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2359 return modifyRewrite && modifyRewrite->getOperation() == op;
2360 });
2361 assert(it != impl->rewrites.rend() && "no root update started on op");
2362 (*it)->rollback();
2363 int updateIdx = std::prev(impl->rewrites.rend()) - it;
2364 impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2365}
2366
2367detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
2368 return *impl;
2369}
2370
2371//===----------------------------------------------------------------------===//
2372// ConversionPattern
2373//===----------------------------------------------------------------------===//
2374
2375FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
2376 ArrayRef<ValueRange> operands) const {
2377 SmallVector<Value> oneToOneOperands;
2378 oneToOneOperands.reserve(operands.size());
2379 for (ValueRange operand : operands) {
2380 if (operand.size() != 1)
2381 return failure();
2382
2383 oneToOneOperands.push_back(operand.front());
2384 }
2385 return std::move(oneToOneOperands);
2386}
2387
2388LogicalResult
2389ConversionPattern::matchAndRewrite(Operation *op,
2390 PatternRewriter &rewriter) const {
2391 auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
2392 auto &rewriterImpl = dialectRewriter.getImpl();
2393
2394 // Track the current conversion pattern type converter in the rewriter.
2395 llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
2396 getTypeConverter());
2397
2398 // Remap the operands of the operation.
2399 SmallVector<ValueVector> remapped;
2400 if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
2401 op->getOperands(), remapped))) {
2402 return failure();
2403 }
2404 SmallVector<ValueRange> remappedAsRange =
2405 llvm::to_vector_of<ValueRange>(remapped);
2406 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2407}
2408
2409//===----------------------------------------------------------------------===//
2410// OperationLegalizer
2411//===----------------------------------------------------------------------===//
2412
2413namespace {
2414/// A set of rewrite patterns that can be used to legalize a given operation.
2415using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2416
2417/// This class defines a recursive operation legalizer.
2418class OperationLegalizer {
2419public:
2420 using LegalizationAction = ConversionTarget::LegalizationAction;
2421
2422 OperationLegalizer(ConversionPatternRewriter &rewriter,
2423 const ConversionTarget &targetInfo,
2424 const FrozenRewritePatternSet &patterns);
2425
2426 /// Returns true if the given operation is known to be illegal on the target.
2427 bool isIllegal(Operation *op) const;
2428
2429 /// Attempt to legalize the given operation. Returns success if the operation
2430 /// was legalized, failure otherwise.
2431 LogicalResult legalize(Operation *op);
2432
2433 /// Returns the conversion target in use by the legalizer.
2434 const ConversionTarget &getTarget() { return target; }
2435
2436private:
2437 /// Attempt to legalize the given operation by folding it.
2438 LogicalResult legalizeWithFold(Operation *op);
2439
2440 /// Attempt to legalize the given operation by applying a pattern. Returns
2441 /// success if the operation was legalized, failure otherwise.
2442 LogicalResult legalizeWithPattern(Operation *op);
2443
2444 /// Return true if the given pattern may be applied to the given operation,
2445 /// false otherwise.
2446 bool canApplyPattern(Operation *op, const Pattern &pattern);
2447
2448 /// Legalize the resultant IR after successfully applying the given pattern.
2449 LogicalResult
2450 legalizePatternResult(Operation *op, const Pattern &pattern,
2451 const RewriterState &curState,
2452 const SetVector<Operation *> &newOps,
2453 const SetVector<Operation *> &modifiedOps);
2454
2455 LogicalResult
2456 legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
2457 LogicalResult
2458 legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
2459
2460 //===--------------------------------------------------------------------===//
2461 // Cost Model
2462 //===--------------------------------------------------------------------===//
2463
2464 /// Build an optimistic legalization graph given the provided patterns. This
2465 /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
2466 /// patterns for operations that are not directly legal, but may be
2467 /// transitively legal for the current target given the provided patterns.
2468 void buildLegalizationGraph(
2469 LegalizationPatterns &anyOpLegalizerPatterns,
2471
2472 /// Compute the benefit of each node within the computed legalization graph.
2473 /// This orders the patterns within 'legalizerPatterns' based upon two
2474 /// criteria:
2475 /// 1) Prefer patterns that have the lowest legalization depth, i.e.
2476 /// represent the more direct mapping to the target.
2477 /// 2) When comparing patterns with the same legalization depth, prefer the
2478 /// pattern with the highest PatternBenefit. This allows for users to
2479 /// prefer specific legalizations over others.
2480 void computeLegalizationGraphBenefit(
2481 LegalizationPatterns &anyOpLegalizerPatterns,
2483
2484 /// Compute the legalization depth when legalizing an operation of the given
2485 /// type.
2486 unsigned computeOpLegalizationDepth(
2487 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2489
2490 /// Apply the conversion cost model to the given set of patterns, and return
2491 /// the smallest legalization depth of any of the patterns. See
2492 /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
2493 unsigned applyCostModelToPatterns(
2494 LegalizationPatterns &patterns,
2495 DenseMap<OperationName, unsigned> &minOpPatternDepth,
2497
2498 /// The current set of patterns that have been applied.
2499 SmallPtrSet<const Pattern *, 8> appliedPatterns;
2500
2501 /// The rewriter to use when converting operations.
2502 ConversionPatternRewriter &rewriter;
2503
2504 /// The legalization information provided by the target.
2505 const ConversionTarget &target;
2506
2507 /// The pattern applicator to use for conversions.
2508 PatternApplicator applicator;
2509};
2510} // namespace
2511
2512OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2513 const ConversionTarget &targetInfo,
2514 const FrozenRewritePatternSet &patterns)
2515 : rewriter(rewriter), target(targetInfo), applicator(patterns) {
2516 // The set of patterns that can be applied to illegal operations to transform
2517 // them into legal ones.
2519 LegalizationPatterns anyOpLegalizerPatterns;
2520
2521 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2522 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2523}
2524
2525bool OperationLegalizer::isIllegal(Operation *op) const {
2526 return target.isIllegal(op);
2527}
2528
2529LogicalResult OperationLegalizer::legalize(Operation *op) {
2530#ifndef NDEBUG
2531 const char *logLineComment =
2532 "//===-------------------------------------------===//\n";
2533
2534 auto &logger = rewriter.getImpl().logger;
2535#endif
2536
2537 // Check to see if the operation is ignored and doesn't need to be converted.
2538 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2539
2540 LLVM_DEBUG({
2541 logger.getOStream() << "\n";
2542 logger.startLine() << logLineComment;
2543 logger.startLine() << "Legalizing operation : ";
2544 // Do not print the operation name if the operation is ignored. Ignored ops
2545 // may have been erased and should not be accessed. The pointer can be
2546 // printed safely.
2547 if (!isIgnored)
2548 logger.getOStream() << "'" << op->getName() << "' ";
2549 logger.getOStream() << "(" << op << ") {\n";
2550 logger.indent();
2551
2552 // If the operation has no regions, just print it here.
2553 if (!isIgnored && op->getNumRegions() == 0) {
2554 logger.startLine() << OpWithFlags(op,
2555 OpPrintingFlags().printGenericOpForm())
2556 << "\n";
2557 }
2558 });
2559
2560 if (isIgnored) {
2561 LLVM_DEBUG({
2562 logSuccess(logger, "operation marked 'ignored' during conversion");
2563 logger.startLine() << logLineComment;
2564 });
2565 return success();
2566 }
2567
2568 // Check if this operation is legal on the target.
2569 if (auto legalityInfo = target.isLegal(op)) {
2570 LLVM_DEBUG({
2571 logSuccess(
2572 logger, "operation marked legal by the target{0}",
2573 legalityInfo->isRecursivelyLegal
2574 ? "; NOTE: operation is recursively legal; skipping internals"
2575 : "");
2576 logger.startLine() << logLineComment;
2577 });
2578
2579 // If this operation is recursively legal, mark its children as ignored so
2580 // that we don't consider them for legalization.
2581 if (legalityInfo->isRecursivelyLegal) {
2582 op->walk([&](Operation *nested) {
2583 if (op != nested)
2584 rewriter.getImpl().ignoredOps.insert(nested);
2585 });
2586 }
2587
2588 return success();
2589 }
2590
2591 // If the operation is not legal, try to fold it in-place if the folding mode
2592 // is 'BeforePatterns'. 'Never' will skip this.
2593 const ConversionConfig &config = rewriter.getConfig();
2594 if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2595 if (succeeded(legalizeWithFold(op))) {
2596 LLVM_DEBUG({
2597 logSuccess(logger, "operation was folded");
2598 logger.startLine() << logLineComment;
2599 });
2600 return success();
2601 }
2602 }
2603
2604 // Otherwise, we need to apply a legalization pattern to this operation.
2605 if (succeeded(legalizeWithPattern(op))) {
2606 LLVM_DEBUG({
2607 logSuccess(logger, "");
2608 logger.startLine() << logLineComment;
2609 });
2610 return success();
2611 }
2612
2613 // If the operation can't be legalized via patterns, try to fold it in-place
2614 // if the folding mode is 'AfterPatterns'.
2615 if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2616 if (succeeded(legalizeWithFold(op))) {
2617 LLVM_DEBUG({
2618 logSuccess(logger, "operation was folded");
2619 logger.startLine() << logLineComment;
2620 });
2621 return success();
2622 }
2623 }
2624
2625 LLVM_DEBUG({
2626 logFailure(logger, "no matched legalization pattern");
2627 logger.startLine() << logLineComment;
2628 });
2629 return failure();
2630}
2631
/// Moves the contents out of `obj` and returns them, leaving `obj` reset to a
/// freshly value-initialized (valid, empty) object.
template <typename T>
static T moveAndReset(T &obj) {
  // std::exchange move-constructs the old value and move-assigns T() into obj.
  return std::exchange(obj, T());
}
2640
2641LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2642 auto &rewriterImpl = rewriter.getImpl();
2643 LLVM_DEBUG({
2644 rewriterImpl.logger.startLine() << "* Fold {\n";
2645 rewriterImpl.logger.indent();
2646 });
2647
2648 // Clear pattern state, so that the next pattern application starts with a
2649 // clean slate. (The op/block sets are populated by listener notifications.)
2650 auto cleanup = llvm::make_scope_exit([&]() {
2651 rewriterImpl.patternNewOps.clear();
2652 rewriterImpl.patternModifiedOps.clear();
2653 });
2654
2655 // Upon failure, undo all changes made by the folder.
2656 RewriterState curState = rewriterImpl.getCurrentState();
2657
2658 // Try to fold the operation.
2659 StringRef opName = op->getName().getStringRef();
2660 SmallVector<Value, 2> replacementValues;
2661 SmallVector<Operation *, 2> newOps;
2662 rewriter.setInsertionPoint(op);
2663 rewriter.startOpModification(op);
2664 if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
2665 LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2666 rewriter.cancelOpModification(op);
2667 return failure();
2668 }
2669 rewriter.finalizeOpModification(op);
2670
2671 // An empty list of replacement values indicates that the fold was in-place.
2672 // As the operation changed, a new legalization needs to be attempted.
2673 if (replacementValues.empty())
2674 return legalize(op);
2675
2676 // Insert a replacement for 'op' with the folded replacement values.
2677 rewriter.replaceOp(op, replacementValues);
2678
2679 // Recursively legalize any new constant operations.
2680 for (Operation *newOp : newOps) {
2681 if (failed(legalize(newOp))) {
2682 LLVM_DEBUG(logFailure(rewriterImpl.logger,
2683 "failed to legalize generated constant '{0}'",
2684 newOp->getName()));
2685 if (!rewriter.getConfig().allowPatternRollback) {
2686 // Rolling back a folder is like rolling back a pattern.
2687 llvm::report_fatal_error(
2688 "op '" + opName +
2689 "' folder rollback of IR modifications requested");
2690 }
2691 rewriterImpl.resetState(
2692 curState, std::string(op->getName().getStringRef()) + " folder");
2693 return failure();
2694 }
2695 }
2696
2697 LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2698 return success();
2699}
2700
/// Report a fatal error indicating that newly produced or modified IR could
/// not be legalized.
///
/// The message names `pattern` by its debug name and lists the op names in
/// `newOps` and `modifiedOps` as comma-separated lists. This function does not
/// return; it ends in `llvm::report_fatal_error`.
static void
                                  const SetVector<Operation *> &newOps,
                                  const SetVector<Operation *> &modifiedOps) {
  // Lazily map each op to its name so the lists can be joined below.
  auto newOpNames = llvm::map_range(
      newOps, [](Operation *op) { return op->getName().getStringRef(); });
  auto modifiedOpNames = llvm::map_range(
      modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
  llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
                           "' produced IR that could not be legalized. " +
                           "new ops: {" + llvm::join(newOpNames, ", ") + "}, " +
                           "modified ops: {" +
                           llvm::join(modifiedOpNames, ", ") + "}");
}
2717
2718LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2719 auto &rewriterImpl = rewriter.getImpl();
2720 const ConversionConfig &config = rewriter.getConfig();
2721
2722#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2723 Operation *checkOp;
2724 std::optional<OperationFingerPrint> topLevelFingerPrint;
2725 if (!rewriterImpl.config.allowPatternRollback) {
2726 // The op may be getting erased, so we have to check the parent op.
2727 // (In rare cases, a pattern may even erase the parent op, which will cause
2728 // a crash here. Expensive checks are "best effort".) Skip the check if the
2729 // op does not have a parent op.
2730 if ((checkOp = op->getParentOp())) {
2731 if (!op->getContext()->isMultithreadingEnabled()) {
2732 topLevelFingerPrint = OperationFingerPrint(checkOp);
2733 } else {
2734 // Another thread may be modifying a sibling operation. Therefore, the
2735 // fingerprinting mechanism of the parent op works only in
2736 // single-threaded mode.
2737 LLVM_DEBUG({
2738 rewriterImpl.logger.startLine()
2739 << "WARNING: Multi-threadeding is enabled. Some dialect "
2740 "conversion expensive checks are skipped in multithreading "
2741 "mode!\n";
2742 });
2743 }
2744 }
2745 }
2746#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2747
2748 // Functor that returns if the given pattern may be applied.
2749 auto canApply = [&](const Pattern &pattern) {
2750 bool canApply = canApplyPattern(op, pattern);
2751 if (canApply && config.listener)
2752 config.listener->notifyPatternBegin(pattern, op);
2753 return canApply;
2754 };
2755
2756 // Functor that cleans up the rewriter state after a pattern failed to match.
2757 RewriterState curState = rewriterImpl.getCurrentState();
2758 auto onFailure = [&](const Pattern &pattern) {
2759 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2760 if (!rewriterImpl.config.allowPatternRollback) {
2761 // Erase all unresolved materializations.
2762 for (auto op : rewriterImpl.patternMaterializations) {
2763 rewriterImpl.unresolvedMaterializations.erase(op);
2764 op.erase();
2765 }
2766 rewriterImpl.patternMaterializations.clear();
2767#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2768 // Expensive pattern check that can detect API violations.
2769 if (checkOp && topLevelFingerPrint) {
2770 OperationFingerPrint fingerPrintAfterPattern(checkOp);
2771 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2772 llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
2773 "' returned failure but IR did change");
2774 }
2775#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2776 }
2777 rewriterImpl.patternNewOps.clear();
2778 rewriterImpl.patternModifiedOps.clear();
2779 LLVM_DEBUG({
2780 logFailure(rewriterImpl.logger, "pattern failed to match");
2781 if (rewriterImpl.config.notifyCallback) {
2782 Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
2783 diag << "Failed to apply pattern \"" << pattern.getDebugName()
2784 << "\" on op:\n"
2785 << *op;
2786 rewriterImpl.config.notifyCallback(diag);
2787 }
2788 });
2789 if (config.listener)
2790 config.listener->notifyPatternEnd(pattern, failure());
2791 rewriterImpl.resetState(curState, pattern.getDebugName());
2792 appliedPatterns.erase(&pattern);
2793 };
2794
2795 // Functor that performs additional legalization when a pattern is
2796 // successfully applied.
2797 auto onSuccess = [&](const Pattern &pattern) {
2798 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2799 if (!rewriterImpl.config.allowPatternRollback) {
2800 // Eagerly erase unused materializations.
2801 for (auto op : rewriterImpl.patternMaterializations) {
2802 if (op->use_empty()) {
2803 rewriterImpl.unresolvedMaterializations.erase(op);
2804 op.erase();
2805 }
2806 }
2807 rewriterImpl.patternMaterializations.clear();
2808 }
2809 SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
2810 SetVector<Operation *> modifiedOps =
2811 moveAndReset(rewriterImpl.patternModifiedOps);
2812 auto result =
2813 legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
2814 appliedPatterns.erase(&pattern);
2815 if (failed(result)) {
2816 if (!rewriterImpl.config.allowPatternRollback)
2817 reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);
2818 rewriterImpl.resetState(curState, pattern.getDebugName());
2819 }
2820 if (config.listener)
2821 config.listener->notifyPatternEnd(pattern, result);
2822 return result;
2823 };
2824
2825 // Try to match and rewrite a pattern on this operation.
2826 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2827 onSuccess);
2828}
2829
2830bool OperationLegalizer::canApplyPattern(Operation *op,
2831 const Pattern &pattern) {
2832 LLVM_DEBUG({
2833 auto &os = rewriter.getImpl().logger;
2834 os.getOStream() << "\n";
2835 os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2836 llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2837 os.getOStream() << ")' {\n";
2838 os.indent();
2839 });
2840
2841 // Ensure that we don't cycle by not allowing the same pattern to be
2842 // applied twice in the same recursion stack if it is not known to be safe.
2843 if (!pattern.hasBoundedRewriteRecursion() &&
2844 !appliedPatterns.insert(&pattern).second) {
2845 LLVM_DEBUG(
2846 logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2847 return false;
2848 }
2849 return true;
2850}
2851
2852LogicalResult OperationLegalizer::legalizePatternResult(
2853 Operation *op, const Pattern &pattern, const RewriterState &curState,
2854 const SetVector<Operation *> &newOps,
2855 const SetVector<Operation *> &modifiedOps) {
2856 [[maybe_unused]] auto &impl = rewriter.getImpl();
2857 assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2858
2859#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2860 if (impl.config.allowPatternRollback) {
2861 // Check that the root was either replaced or updated in place.
2862 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2863 auto replacedRoot = [&] {
2864 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2865 };
2866 auto updatedRootInPlace = [&] {
2867 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2868 };
2869 if (!replacedRoot() && !updatedRootInPlace())
2870 llvm::report_fatal_error("expected pattern to replace the root operation "
2871 "or modify it in place");
2872 }
2873#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2874
2875 // Legalize each of the actions registered during application.
2876 if (failed(legalizePatternRootUpdates(modifiedOps)) ||
2877 failed(legalizePatternCreatedOperations(newOps))) {
2878 return failure();
2879 }
2880
2881 LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2882 return success();
2883}
2884
2885LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2886 const SetVector<Operation *> &newOps) {
2887 for (Operation *op : newOps) {
2888 if (failed(legalize(op))) {
2889 LLVM_DEBUG(logFailure(rewriter.getImpl().logger,
2890 "failed to legalize generated operation '{0}'({1})",
2891 op->getName(), op));
2892 return failure();
2893 }
2894 }
2895 return success();
2896}
2897
2898LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2899 const SetVector<Operation *> &modifiedOps) {
2900 for (Operation *op : modifiedOps) {
2901 if (failed(legalize(op))) {
2902 LLVM_DEBUG(
2903 logFailure(rewriter.getImpl().logger,
2904 "failed to legalize operation updated in-place '{0}'",
2905 op->getName()));
2906 return failure();
2907 }
2908 }
2909 return success();
2910}
2911
2912//===----------------------------------------------------------------------===//
2913// Cost Model
2914//===----------------------------------------------------------------------===//
2915
/// Builds an optimistic legalization graph from the applicator's patterns.
///
/// Patterns without a root kind go to `anyOpLegalizerPatterns`. Patterns whose
/// root op is always `Legal` on the target are dropped. A remaining pattern is
/// added to `legalizerPatterns[root]` once every op it generates either already
/// has legalizer patterns or has a target action other than `Illegal`; this is
/// iterated to a fixed point with a worklist. If any rootless pattern exists,
/// that filtering is skipped and every rooted candidate is kept.
void OperationLegalizer::buildLegalizationGraph(
    LegalizationPatterns &anyOpLegalizerPatterns,
  // A mapping between an operation and a set of operations that can be used to
  // generate it.
  // A mapping between an operation and any currently invalid patterns it has.
  // A worklist of patterns to consider for legality.
  SetVector<const Pattern *> patternWorklist;

  // Build the mapping from operations to the parent ops that may generate them.
  applicator.walkAllPatterns([&](const Pattern &pattern) {
    std::optional<OperationName> root = pattern.getRootKind();

    // If the pattern has no specific root, we can't analyze the relationship
    // between the root op and generated operations. Given that, add all such
    // patterns to the legalization set.
    if (!root) {
      anyOpLegalizerPatterns.push_back(&pattern);
      return;
    }

    // Skip operations that are always known to be legal.
    if (target.getOpAction(*root) == LegalizationAction::Legal)
      return;

    // Add this pattern to the invalid set for the root op and record this root
    // as a parent for any generated operations.
    invalidPatterns[*root].insert(&pattern);
    for (auto op : pattern.getGeneratedOps())
      parentOps[op].insert(*root);

    // Add this pattern to the worklist.
    patternWorklist.insert(&pattern);
  });

  // If there are any patterns that don't have a specific root kind, we can't
  // make direct assumptions about what operations will never be legalized.
  // Note: Technically we could, but it would require an analysis that may
  // recurse into itself. It would be better to perform this kind of filtering
  // at a higher level than here anyways.
  if (!anyOpLegalizerPatterns.empty()) {
    for (const Pattern *pattern : patternWorklist)
      legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
    return;
  }

  while (!patternWorklist.empty()) {
    auto *pattern = patternWorklist.pop_back_val();

    // Check to see if any of the generated operations are invalid, i.e. have
    // no legalizer patterns yet and are either unknown to the target or
    // explicitly illegal.
    if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
          std::optional<LegalizationAction> action = target.getOpAction(op);
          return !legalizerPatterns.count(op) &&
                 (!action || action == LegalizationAction::Illegal);
        }))
      continue;

    // Otherwise, if all of the generated operations are valid, this op is now
    // legal so add all of the child patterns to the worklist.
    legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
    invalidPatterns[*pattern->getRootKind()].erase(pattern);

    // Add any invalid patterns of the parent operations to see if they have now
    // become legal.
    for (auto op : parentOps[*pattern->getRootKind()])
      patternWorklist.set_union(invalidPatterns[op]);
  }
}
2986
/// Orders patterns according to the legalization cost model.
///
/// The minimum legalization depth of every op with legalizer patterns is
/// computed (memoized in a local map); as a side effect each op's pattern list
/// is sorted by depth, then benefit. Rootless patterns in
/// `anyOpLegalizerPatterns` are sorted the same way. Finally, the applicator
/// assigns each pattern a benefit derived from its position in the applicable
/// sorted list, with earlier entries receiving a higher benefit.
void OperationLegalizer::computeLegalizationGraphBenefit(
    LegalizationPatterns &anyOpLegalizerPatterns,
  // The smallest pattern depth, when legalizing an operation.
  DenseMap<OperationName, unsigned> minOpPatternDepth;

  // For each operation that is transitively legal, compute a cost for it.
  for (auto &opIt : legalizerPatterns)
    if (!minOpPatternDepth.count(opIt.first))
      computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
                                 legalizerPatterns);

  // Apply the cost model to the patterns that can match any operation. Those
  // with a specific operation type are already resolved when computing the op
  // legalization depth.
  if (!anyOpLegalizerPatterns.empty())
    applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
                             legalizerPatterns);

  // Apply a cost model to the pattern applicator. We order patterns first by
  // depth then benefit. `legalizerPatterns` contains per-op patterns by
  // decreasing benefit.
  applicator.applyCostModel([&](const Pattern &pattern) {
    ArrayRef<const Pattern *> orderedPatternList;
    if (std::optional<OperationName> rootName = pattern.getRootKind())
      orderedPatternList = legalizerPatterns[*rootName];
    else
      orderedPatternList = anyOpLegalizerPatterns;

    // If the pattern is not found, then it was removed and cannot be matched.
    auto *it = llvm::find(orderedPatternList, &pattern);
    if (it == orderedPatternList.end())

    // Patterns found earlier in the list have higher benefit.
    return PatternBenefit(std::distance(it, orderedPatternList.end()));
  });
}
3025
/// Returns the minimum legalization depth for ops named `op`, memoizing the
/// result in `minOpPatternDepth`.
///
/// Ops with no (or an empty list of) legalizer patterns have depth 0; that
/// result is not memoized. Before recursing, a placeholder of
/// `std::numeric_limits<unsigned>::max()` is recorded for `op`, so a cycle back
/// to `op` returns that placeholder instead of recursing forever. As a side
/// effect, `op`'s pattern list is re-sorted by `applyCostModelToPatterns`.
unsigned OperationLegalizer::computeOpLegalizationDepth(
    OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
  // Check for existing depth.
  auto depthIt = minOpPatternDepth.find(op);
  if (depthIt != minOpPatternDepth.end())
    return depthIt->second;

  // If a mapping for this operation does not exist, then this operation
  // is always legal. Return 0 as the depth for a directly legal operation.
  auto opPatternsIt = legalizerPatterns.find(op);
  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
    return 0u;

  // Record this initial depth in case we encounter this op again when
  // recursively computing the depth.
  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());

  // Apply the cost model to the operation patterns, and update the minimum
  // depth.
  unsigned minDepth = applyCostModelToPatterns(
      opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
  minOpPatternDepth[op] = minDepth;
  return minDepth;
}
3051
/// Sorts `patterns` in place by increasing legalization depth, breaking ties by
/// decreasing pattern benefit (the sort is stable), and returns the smallest
/// depth among them.
///
/// A pattern's depth is 1 + the largest legalization depth of the ops it
/// generates, or 1 if it generates none. An empty `patterns` list yields
/// `std::numeric_limits<unsigned>::max()`.
unsigned OperationLegalizer::applyCostModelToPatterns(
    LegalizationPatterns &patterns,
    DenseMap<OperationName, unsigned> &minOpPatternDepth,
  unsigned minDepth = std::numeric_limits<unsigned>::max();

  // Compute the depth for each pattern within the set.
  SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
  patternsByDepth.reserve(patterns.size());
  for (const Pattern *pattern : patterns) {
    unsigned depth = 1;
    for (auto generatedOp : pattern->getGeneratedOps()) {
      unsigned generatedOpDepth = computeOpLegalizationDepth(
          generatedOp, minOpPatternDepth, legalizerPatterns);
      // NOTE(review): if `generatedOpDepth` is the UINT_MAX placeholder that
      // computeOpLegalizationDepth records for an op whose depth is still
      // being computed (a cycle), `generatedOpDepth + 1` wraps to 0 and that
      // op contributes nothing to `depth`. Confirm this is the intended
      // treatment of cycles.
      depth = std::max(depth, generatedOpDepth + 1);
    }
    patternsByDepth.emplace_back(pattern, depth);

    // Update the minimum depth of the pattern list.
    minDepth = std::min(minDepth, depth);
  }

  // If the operation only has one legalization pattern, there is no need to
  // sort them.
  if (patternsByDepth.size() == 1)
    return minDepth;

  // Sort the patterns by those likely to be the most beneficial.
  llvm::stable_sort(patternsByDepth,
                    [](const std::pair<const Pattern *, unsigned> &lhs,
                       const std::pair<const Pattern *, unsigned> &rhs) {
                      // First sort by the smaller pattern legalization
                      // depth.
                      if (lhs.second != rhs.second)
                        return lhs.second < rhs.second;

                      // Then sort by the larger pattern benefit.
                      auto lhsBenefit = lhs.first->getBenefit();
                      auto rhsBenefit = rhs.first->getBenefit();
                      return lhsBenefit > rhsBenefit;
                    });

  // Update the legalization pattern to use the new sorted list.
  patterns.clear();
  for (auto &patternIt : patternsByDepth)
    patterns.push_back(patternIt.first);
  return minDepth;
}
3100
3101//===----------------------------------------------------------------------===//
3102// Reconcile Unrealized Casts
3103//===----------------------------------------------------------------------===//
3104
/// Try to reconcile all given UnrealizedConversionCastOps and store the
/// left-over ops in `remainingCastOps` (if provided). See documentation in
/// DialectConversion.h for more details.
/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the
/// algorithm may visit an operand (or user) which is a cast op, but will not
/// try to reconcile it if not in the filtered set.
///
/// The algorithm runs in two phases:
///  1. Folding: for each cast, walk the chain of producer casts, starting with
///     the cast itself. Each producer in the chain must define exactly the
///     current cast's inputs, in order. The walk looks for a cast whose input
///     types equal this cast's result types; if one is found (and no cycle is
///     detected), all uses of this cast are redirected to those inputs.
///  2. Dead-cast elimination: a cast is live if some user is not a cast of
///     interest. Liveness propagates to cast-of-interest producers of its
///     operands. Casts in `castOps` that are not live are erased; live ones
///     are appended to `remainingCastOps` in `castOps` iteration order.
template <typename RangeT>
    RangeT castOps,
    function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
  // A worklist of cast ops to process.
  SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);

  // Helper function that returns the unrealized_conversion_cast op that
  // defines all inputs of the given op (in the same order). Return "nullptr"
  // if there is no such op.
  auto getInputCast =
      [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
    if (castOp.getInputs().empty())
      return {};
    auto inputCastOp =
        castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
    if (!inputCastOp)
      return {};
    if (inputCastOp.getOutputs() != castOp.getInputs())
      return {};
    return inputCastOp;
  };

  // Process ops in the worklist bottom-to-top.
  while (!worklist.empty()) {
    UnrealizedConversionCastOp castOp = worklist.pop_back_val();

    // Traverse the chain of input cast ops to see if an op with the same
    // input types can be found.
    UnrealizedConversionCastOp nextCast = castOp;
    while (nextCast) {
      if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
        if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
              return v.getDefiningOp() == castOp;
            })) {
          // Ran into a cycle.
          break;
        }

        // Found a cast where the input types match the output types of the
        // matched op. We can directly use those inputs.
        castOp.replaceAllUsesWith(nextCast.getInputs());
        break;
      }
      nextCast = getInputCast(nextCast);
    }
  }

  // A set of all alive cast ops. I.e., ops whose results are (transitively)
  // used by an op that is not a cast op.
  DenseSet<Operation *> liveOps;

  // Helper function that marks the given op and transitively reachable input
  // cast ops as alive.
  auto markOpLive = [&](Operation *rootOp) {
    SmallVector<Operation *> worklist;
    worklist.push_back(rootOp);
    while (!worklist.empty()) {
      Operation *op = worklist.pop_back_val();
      if (liveOps.insert(op).second) {
        // Successfully inserted: process reachable input cast ops.
        for (Value v : op->getOperands())
          if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
            if (isCastOpOfInterestFn(castOp))
              worklist.push_back(castOp);
      }
    }
  };

  // Find all alive cast ops.
  for (UnrealizedConversionCastOp op : castOps) {
    // The op may have been marked live already as being an operand of another
    // live cast op.
    if (liveOps.contains(op.getOperation()))
      continue;
    // If any of the users is not a cast op, mark the current op (and its
    // input ops) as live.
    if (llvm::any_of(op->getUsers(), [&](Operation *user) {
          auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
          return !castOp || !isCastOpOfInterestFn(castOp);
        }))
      markOpLive(op);
  }

  // Erase all dead cast ops.
  for (UnrealizedConversionCastOp op : castOps) {
    if (liveOps.contains(op)) {
      // Op is alive and was not erased. Add it to the remaining cast ops.
      if (remainingCastOps)
        remainingCastOps->push_back(op);
      continue;
    }

    // Op is dead. Erase it. Its remaining uses (if any) can only come from
    // other dead casts, so they are dropped first.
    op->dropAllUses();
    op->erase();
  }
}
3210
ArrayRef<UnrealizedConversionCastOp> castOps,
    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
  // Convenience overload for an ArrayRef of casts: copy them into a DenseSet
  // (duplicates collapse) and forward to the DenseSet-based overload.
  // NOTE(review): because the set is what gets iterated downstream,
  // `remainingCastOps` follows DenseSet iteration order, not `castOps` order.
  // Set of all cast ops for faster lookups.
  DenseSet<UnrealizedConversionCastOp> castOpSet;
  for (UnrealizedConversionCastOp op : castOps)
    castOpSet.insert(op);
  reconcileUnrealizedCasts(castOpSet, remainingCastOps);
}
3220
const DenseSet<UnrealizedConversionCastOp> &castOps,
    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
      // The set serves both as the range of casts to process and, via the
      // filter below, as the set of casts of interest.
      llvm::make_range(castOps.begin(), castOps.end()),
      [&](UnrealizedConversionCastOp castOp) {
        return castOps.contains(castOp);
      },
      remainingCastOps);
}
3231
// Overload over a map-like container of casts (queried through `keys()` and
// `contains()`): its keys are both the casts to process and the casts of
// interest. Used by OperationConverter::convertOperations with the map of
// unresolved materializations.
namespace mlir {
    &castOps,
      castOps.keys(),
      [&](UnrealizedConversionCastOp castOp) {
        return castOps.contains(castOp);
      },
      remainingCastOps);
}
} // namespace mlir
3245
3246//===----------------------------------------------------------------------===//
3247// OperationConverter
3248//===----------------------------------------------------------------------===//
3249
namespace mlir {
// This class converts operations to a given conversion target via a set of
// rewrite patterns. The conversion behaves differently depending on the
// conversion mode.
                              const ConversionConfig &config,
                              OpConversionMode mode)
      : rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
        mode(mode) {}

  /// Converts the given operations to the conversion target.
  LogicalResult convertOperations(ArrayRef<Operation *> ops);

  /// Converts a single operation. If `isRecursiveLegalization` is "true", the
  /// conversion is a recursive legalization request, triggered from within a
  /// pattern. In that case, do not emit errors because there will be another
  /// attempt at legalizing the operation later (via the regular pre-order
  /// legalization mechanism).
  LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);

private:
  /// The rewriter to use when converting operations.
  /// It receives a back-reference to this converter (`*this`) and must be
  /// declared before `opLegalizer`, which is constructed from it: members are
  /// initialized in declaration order.
  ConversionPatternRewriter rewriter;

  /// The legalizer to use when converting operations.
  OperationLegalizer opLegalizer;

  /// The conversion mode to use when legalizing operations.
  OpConversionMode mode;
};
} // namespace mlir
3283
3284LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
3285 return impl->opConverter.convert(op, /*isRecursiveLegalization=*/true);
3286}
3287
bool isRecursiveLegalization) {
  // Legalizes `op` and maps a legalization failure onto the conversion mode:
  //  - Full: any failure is returned as failure.
  //  - Partial: failure is returned only for ops the target explicitly marks
  //    illegal; other failures are recorded in `config.unlegalizedOps` (if
  //    set) and treated as success.
  //  - Analysis: failures are ignored; successfully legalized ops are
  //    recorded in `config.legalizableOps` (if set).
  // Recursive requests never emit errors and never update these op sets.
  const ConversionConfig &config = rewriter.getConfig();

  // Legalize the given operation.
  if (failed(opLegalizer.legalize(op))) {
    // Handle the case of a failed conversion for each of the different modes.
    // Full conversions expect all operations to be converted.
    if (mode == OpConversionMode::Full) {
      if (!isRecursiveLegalization)
        op->emitError() << "failed to legalize operation '" << op->getName()
                        << "'";
      return failure();
    }
    // Partial conversions allow conversions to fail iff the operation was not
    // explicitly marked as illegal. If the user provided a `unlegalizedOps`
    // set, non-legalizable ops are added to that set.
    if (mode == OpConversionMode::Partial) {
      if (opLegalizer.isIllegal(op)) {
        if (!isRecursiveLegalization)
          op->emitError() << "failed to legalize operation '" << op->getName()
                          << "' that was explicitly marked illegal";
        return failure();
      }
      if (config.unlegalizedOps && !isRecursiveLegalization)
        config.unlegalizedOps->insert(op);
    }
  } else if (mode == OpConversionMode::Analysis) {
    // Analysis conversions don't fail if any operations fail to legalize,
    // they are only interested in the operations that were successfully
    // legalized.
    if (config.legalizableOps && !isRecursiveLegalization)
      config.legalizableOps->insert(op);
  }
  return success();
}
3324
/// Replaces the unresolved materialization `op`, which must still have uses,
/// with the result of the type converter's target or source materialization
/// hook, as selected by `info`. If `info` has no converter, or the hook yields
/// no values, this emits an error on `op` with a note at its first user and
/// returns failure.
static LogicalResult
                                  UnrealizedConversionCastOp op,
                                  const UnresolvedMaterializationInfo &info) {
  assert(!op.use_empty() &&
         "expected that dead materializations have already been DCE'd");
  Operation::operand_range inputOperands = op.getOperands();

  // Try to materialize the conversion.
  if (const TypeConverter *converter = info.getConverter()) {
    rewriter.setInsertionPoint(op);
    SmallVector<Value> newMaterialization;
    switch (info.getMaterializationKind()) {
    case MaterializationKind::Target:
      newMaterialization = converter->materializeTargetConversion(
          rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
          info.getOriginalType());
      break;
    case MaterializationKind::Source:
      // Source materializations produce a single value. (Declaring
      // `sourceMat` without braces is only valid because this is the last
      // case label.)
      assert(op->getNumResults() == 1 && "expected single result");
      Value sourceMat = converter->materializeSourceConversion(
          rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
      if (sourceMat)
        newMaterialization.push_back(sourceMat);
      break;
    }
    if (!newMaterialization.empty()) {
#ifndef NDEBUG
      ValueRange newMaterializationRange(newMaterialization);
      assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
             "materialization callback produced value of incorrect type");
#endif // NDEBUG
      rewriter.replaceOp(op, newMaterialization);
      return success();
    }
  }

  // No materialization could be built: report it, pointing at one live user.
  InFlightDiagnostic diag = op->emitError()
                            << "failed to legalize unresolved materialization "
                               "from ("
                            << inputOperands.getTypes() << ") to ("
                            << op.getResultTypes()
                            << ") that remained live after conversion";
  diag.attachNote(op->getUsers().begin()->getLoc())
      << "see existing live user here: " << *op->getUsers().begin();
  return failure();
}
3372
3374 const ConversionTarget &target = opLegalizer.getTarget();
3375
3376 // Compute the set of operations and blocks to convert.
3377 SmallVector<Operation *> toConvert;
3378 for (auto *op : ops) {
3380 [&](Operation *op) {
3381 toConvert.push_back(op);
3382 // Don't check this operation's children for conversion if the
3383 // operation is recursively legal.
3384 auto legalityInfo = target.isLegal(op);
3385 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3386 return WalkResult::skip();
3387 return WalkResult::advance();
3388 });
3389 }
3390
3391 // Convert each operation and discard rewrites on failure.
3392 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
3393
3394 for (auto *op : toConvert) {
3395 if (failed(convert(op))) {
3396 // Dialect conversion failed.
3397 if (rewriterImpl.config.allowPatternRollback) {
3398 // Rollback is allowed: restore the original IR.
3399 rewriterImpl.undoRewrites();
3400 } else {
3401 // Rollback is not allowed: apply all modifications that have been
3402 // performed so far.
3403 rewriterImpl.applyRewrites();
3404 }
3405 return failure();
3406 }
3407 }
3408
3409 // After a successful conversion, apply rewrites.
3410 rewriterImpl.applyRewrites();
3411
3412 // Reconcile all UnrealizedConversionCastOps that were inserted by the
3413 // dialect conversion frameworks. (Not the ones that were inserted by
3414 // patterns.)
3416 &materializations = rewriterImpl.unresolvedMaterializations;
3418 reconcileUnrealizedCasts(materializations, &remainingCastOps);
3419
3420 // Drop markers.
3421 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3422 castOp->removeAttr(kPureTypeConversionMarker);
3423
3424 // Try to legalize all unresolved materializations.
3425 if (rewriter.getConfig().buildMaterializations) {
3426 // Use a new rewriter, so the modifications are not tracked for rollback
3427 // purposes etc.
3428 IRRewriter irRewriter(rewriterImpl.rewriter.getContext(),
3429 rewriter.getConfig().listener);
3430 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3431 auto it = materializations.find(castOp);
3432 assert(it != materializations.end() && "inconsistent state");
3433 if (failed(legalizeUnresolvedMaterialization(irRewriter, castOp,
3434 it->second)))
3435 return failure();
3436 }
3437 }
3438
3439 return success();
3440}
3441
3442//===----------------------------------------------------------------------===//
3443// Type Conversion
3444//===----------------------------------------------------------------------===//
3445
3446void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
3447 ArrayRef<Type> types) {
3448 assert(!types.empty() && "expected valid types");
3449 remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
3450 addInputs(types);
3451}
3452
3453void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
3454 assert(!types.empty() &&
3455 "1->0 type remappings don't need to be added explicitly");
3456 argTypes.append(types.begin(), types.end());
3457}
3458
3459void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
3460 unsigned newInputNo,
3461 unsigned newInputCount) {
3462 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3463 assert(newInputCount != 0 && "expected valid input count");
3464 remappedInputs[origInputNo] =
3465 InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
3466}
3467
3468void TypeConverter::SignatureConversion::remapInput(
3469 unsigned origInputNo, ArrayRef<Value> replacements) {
3470 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3471 remappedInputs[origInputNo] = InputMapping{
3472 origInputNo, /*size=*/0,
3473 SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3474}
3475
3476/// Internal implementation of the type conversion.
3477/// This is used with either a Type or a Value as the first argument.
3478/// - we can cache the context-free conversions until the last registered
3479/// context-aware conversion.
3480/// - we can't cache the result of type conversion happening after context-aware
3481/// conversions, because the type converter may return different results for the
3482/// same input type.
3483LogicalResult
3484TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3485 SmallVectorImpl<Type> &results) const {
3486 assert(typeOrValue && "expected non-null type");
3487 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3488 : cast<Type>(typeOrValue);
3489 {
3490 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3491 std::defer_lock);
3493 cacheReadLock.lock();
3494 auto existingIt = cachedDirectConversions.find(t);
3495 if (existingIt != cachedDirectConversions.end()) {
3496 if (existingIt->second)
3497 results.push_back(existingIt->second);
3498 return success(existingIt->second != nullptr);
3499 }
3500 auto multiIt = cachedMultiConversions.find(t);
3501 if (multiIt != cachedMultiConversions.end()) {
3502 results.append(multiIt->second.begin(), multiIt->second.end());
3503 return success();
3504 }
3505 }
3506 // Walk the added converters in reverse order to apply the most recently
3507 // registered first.
3508 size_t currentCount = results.size();
3509
3510 // We can cache the context-free conversions until the last registered
3511 // context-aware conversion. But only if we're processing a Value right now.
3512 auto isCacheable = [&](int index) {
3513 int numberOfConversionsUntilContextAware =
3514 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3515 return index < numberOfConversionsUntilContextAware;
3516 };
3517
3518 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3519 std::defer_lock);
3520
3521 for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3522 const ConversionCallbackFn &converter = indexedConverter.value();
3523 std::optional<LogicalResult> result = converter(typeOrValue, results);
3524 if (!result) {
3525 assert(results.size() == currentCount &&
3526 "failed type conversion should not change results");
3527 continue;
3528 }
3529 if (!isCacheable(indexedConverter.index()))
3530 return success();
3532 cacheWriteLock.lock();
3533 if (!succeeded(*result)) {
3534 assert(results.size() == currentCount &&
3535 "failed type conversion should not change results");
3536 cachedDirectConversions.try_emplace(t, nullptr);
3537 return failure();
3538 }
3539 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3540 if (newTypes.size() == 1)
3541 cachedDirectConversions.try_emplace(t, newTypes.front());
3542 else
3543 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3544 return success();
3545 }
3546 return failure();
3547}
3548
3549LogicalResult TypeConverter::convertType(Type t,
3550 SmallVectorImpl<Type> &results) const {
3551 return convertTypeImpl(t, results);
3552}
3553
3554LogicalResult TypeConverter::convertType(Value v,
3555 SmallVectorImpl<Type> &results) const {
3556 return convertTypeImpl(v, results);
3557}
3558
3559Type TypeConverter::convertType(Type t) const {
3560 // Use the multi-type result version to convert the type.
3561 SmallVector<Type, 1> results;
3562 if (failed(convertType(t, results)))
3563 return nullptr;
3564
3565 // Check to ensure that only one type was produced.
3566 return results.size() == 1 ? results.front() : nullptr;
3567}
3568
3569Type TypeConverter::convertType(Value v) const {
3570 // Use the multi-type result version to convert the type.
3571 SmallVector<Type, 1> results;
3572 if (failed(convertType(v, results)))
3573 return nullptr;
3574
3575 // Check to ensure that only one type was produced.
3576 return results.size() == 1 ? results.front() : nullptr;
3577}
3578
3579LogicalResult
3580TypeConverter::convertTypes(TypeRange types,
3581 SmallVectorImpl<Type> &results) const {
3582 for (Type type : types)
3583 if (failed(convertType(type, results)))
3584 return failure();
3585 return success();
3586}
3587
3588LogicalResult
3589TypeConverter::convertTypes(ValueRange values,
3590 SmallVectorImpl<Type> &results) const {
3591 for (Value value : values)
3592 if (failed(convertType(value, results)))
3593 return failure();
3594 return success();
3595}
3596
3597bool TypeConverter::isLegal(Type type) const {
3598 return convertType(type) == type;
3599}
3600
3601bool TypeConverter::isLegal(Value value) const {
3602 return convertType(value) == value.getType();
3603}
3604
3605bool TypeConverter::isLegal(Operation *op) const {
3606 return isLegal(op->getOperands()) && isLegal(op->getResults());
3607}
3608
3609bool TypeConverter::isLegal(Region *region) const {
3610 return llvm::all_of(
3611 *region, [this](Block &block) { return isLegal(block.getArguments()); });
3612}
3613
3614bool TypeConverter::isSignatureLegal(FunctionType ty) const {
3615 if (!isLegal(ty.getInputs()))
3616 return false;
3617 if (!isLegal(ty.getResults()))
3618 return false;
3619 return true;
3620}
3621
3622LogicalResult
3623TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
3624 SignatureConversion &result) const {
3625 // Try to convert the given input type.
3626 SmallVector<Type, 1> convertedTypes;
3627 if (failed(convertType(type, convertedTypes)))
3628 return failure();
3629
3630 // If this argument is being dropped, there is nothing left to do.
3631 if (convertedTypes.empty())
3632 return success();
3633
3634 // Otherwise, add the new inputs.
3635 result.addInputs(inputNo, convertedTypes);
3636 return success();
3637}
3638LogicalResult
3639TypeConverter::convertSignatureArgs(TypeRange types,
3640 SignatureConversion &result,
3641 unsigned origInputOffset) const {
3642 for (unsigned i = 0, e = types.size(); i != e; ++i)
3643 if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3644 return failure();
3645 return success();
3646}
3647LogicalResult
3648TypeConverter::convertSignatureArg(unsigned inputNo, Value value,
3649 SignatureConversion &result) const {
3650 // Try to convert the given input type.
3651 SmallVector<Type, 1> convertedTypes;
3652 if (failed(convertType(value, convertedTypes)))
3653 return failure();
3654
3655 // If this argument is being dropped, there is nothing left to do.
3656 if (convertedTypes.empty())
3657 return success();
3658
3659 // Otherwise, add the new inputs.
3660 result.addInputs(inputNo, convertedTypes);
3661 return success();
3662}
3663LogicalResult
3664TypeConverter::convertSignatureArgs(ValueRange values,
3665 SignatureConversion &result,
3666 unsigned origInputOffset) const {
3667 for (unsigned i = 0, e = values.size(); i != e; ++i)
3668 if (failed(convertSignatureArg(origInputOffset + i, values[i], result)))
3669 return failure();
3670 return success();
3671}
3672
3673Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
3674 Location loc, Type resultType,
3675 ValueRange inputs) const {
3676 for (const SourceMaterializationCallbackFn &fn :
3677 llvm::reverse(sourceMaterializations))
3678 if (Value result = fn(builder, resultType, inputs, loc))
3679 return result;
3680 return nullptr;
3681}
3682
3683Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
3684 Location loc, Type resultType,
3685 ValueRange inputs,
3686 Type originalType) const {
3687 SmallVector<Value> result = materializeTargetConversion(
3688 builder, loc, TypeRange(resultType), inputs, originalType);
3689 if (result.empty())
3690 return nullptr;
3691 assert(result.size() == 1 && "expected single result");
3692 return result.front();
3693}
3694
3695SmallVector<Value> TypeConverter::materializeTargetConversion(
3696 OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
3697 Type originalType) const {
3698 for (const TargetMaterializationCallbackFn &fn :
3699 llvm::reverse(targetMaterializations)) {
3700 SmallVector<Value> result =
3701 fn(builder, resultTypes, inputs, loc, originalType);
3702 if (result.empty())
3703 continue;
3704 assert(TypeRange(ValueRange(result)) == resultTypes &&
3705 "callback produced incorrect number of values or values with "
3706 "incorrect types");
3707 return result;
3708 }
3709 return {};
3710}
3711
3712std::optional<TypeConverter::SignatureConversion>
3713TypeConverter::convertBlockSignature(Block *block) const {
3714 SignatureConversion conversion(block->getNumArguments());
3715 if (failed(convertSignatureArgs(block->getArguments(), conversion)))
3716 return std::nullopt;
3717 return conversion;
3718}
3719
3720//===----------------------------------------------------------------------===//
3721// Type attribute conversion
3722//===----------------------------------------------------------------------===//
3723TypeConverter::AttributeConversionResult
3724TypeConverter::AttributeConversionResult::result(Attribute attr) {
3725 return AttributeConversionResult(attr, resultTag);
3726}
3727
3728TypeConverter::AttributeConversionResult
3729TypeConverter::AttributeConversionResult::na() {
3730 return AttributeConversionResult(nullptr, naTag);
3731}
3732
3733TypeConverter::AttributeConversionResult
3734TypeConverter::AttributeConversionResult::abort() {
3735 return AttributeConversionResult(nullptr, abortTag);
3736}
3737
3738bool TypeConverter::AttributeConversionResult::hasResult() const {
3739 return impl.getInt() == resultTag;
3740}
3741
3742bool TypeConverter::AttributeConversionResult::isNa() const {
3743 return impl.getInt() == naTag;
3744}
3745
3746bool TypeConverter::AttributeConversionResult::isAbort() const {
3747 return impl.getInt() == abortTag;
3748}
3749
3750Attribute TypeConverter::AttributeConversionResult::getResult() const {
3751 assert(hasResult() && "Cannot get result from N/A or abort");
3752 return impl.getPointer();
3753}
3754
3755std::optional<Attribute>
3756TypeConverter::convertTypeAttribute(Type type, Attribute attr) const {
3757 for (const TypeAttributeConversionCallbackFn &fn :
3758 llvm::reverse(typeAttributeConversions)) {
3759 AttributeConversionResult res = fn(type, attr);
3760 if (res.hasResult())
3761 return res.getResult();
3762 if (res.isAbort())
3763 return std::nullopt;
3764 }
3765 return std::nullopt;
3766}
3767
3768//===----------------------------------------------------------------------===//
3769// FunctionOpInterfaceSignatureConversion
3770//===----------------------------------------------------------------------===//
3771
3772static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3773 const TypeConverter &typeConverter,
3774 ConversionPatternRewriter &rewriter) {
3775 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3776 if (!type)
3777 return failure();
3778
3779 // Convert the original function types.
3780 TypeConverter::SignatureConversion result(type.getNumInputs());
3781 SmallVector<Type, 1> newResults;
3782 if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3783 failed(typeConverter.convertTypes(type.getResults(), newResults)))
3784 return failure();
3785 if (!funcOp.getFunctionBody().empty())
3786 rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result,
3787 &typeConverter);
3788
3789 // Update the function signature in-place.
3790 auto newType = FunctionType::get(rewriter.getContext(),
3791 result.getConvertedTypes(), newResults);
3792
3793 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3794
3795 return success();
3796}
3797
3798/// Create a default conversion pattern that rewrites the type signature of a
3799/// FunctionOpInterface op. This only supports ops which use FunctionType to
3800/// represent their type.
3801namespace {
3802struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3803 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3804 MLIRContext *ctx,
3805 const TypeConverter &converter,
3806 PatternBenefit benefit)
3807 : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3808
3809 LogicalResult
3810 matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3811 ConversionPatternRewriter &rewriter) const override {
3812 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3813 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3814 }
3815};
3816
3817struct AnyFunctionOpInterfaceSignatureConversion
3818 : public OpInterfaceConversionPattern<FunctionOpInterface> {
3819 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3820
3821 LogicalResult
3822 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3823 ConversionPatternRewriter &rewriter) const override {
3824 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3825 }
3826};
3827} // namespace
3828
3829FailureOr<Operation *>
3830mlir::convertOpResultTypes(Operation *op, ValueRange operands,
3831 const TypeConverter &converter,
3832 ConversionPatternRewriter &rewriter) {
3833 assert(op && "Invalid op");
3834 Location loc = op->getLoc();
3835 if (converter.isLegal(op))
3836 return rewriter.notifyMatchFailure(loc, "op already legal");
3837
3838 OperationState newOp(loc, op->getName());
3839 newOp.addOperands(operands);
3840
3841 SmallVector<Type> newResultTypes;
3842 if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
3843 return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3844
3845 newOp.addTypes(newResultTypes);
3846 newOp.addAttributes(op->getAttrs());
3847 return rewriter.create(newOp);
3848}
3849
3850void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3851 StringRef functionLikeOpName, RewritePatternSet &patterns,
3852 const TypeConverter &converter, PatternBenefit benefit) {
3853 patterns.add<FunctionOpInterfaceSignatureConversion>(
3854 functionLikeOpName, patterns.getContext(), converter, benefit);
3855}
3856
3857void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3858 RewritePatternSet &patterns, const TypeConverter &converter,
3859 PatternBenefit benefit) {
3860 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3861 converter, patterns.getContext(), benefit);
3862}
3863
3864//===----------------------------------------------------------------------===//
3865// ConversionTarget
3866//===----------------------------------------------------------------------===//
3867
3868void ConversionTarget::setOpAction(OperationName op,
3869 LegalizationAction action) {
3870 legalOperations[op].action = action;
3871}
3872
3873void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3874 LegalizationAction action) {
3875 for (StringRef dialect : dialectNames)
3876 legalDialects[dialect] = action;
3877}
3878
3879auto ConversionTarget::getOpAction(OperationName op) const
3880 -> std::optional<LegalizationAction> {
3881 std::optional<LegalizationInfo> info = getOpInfo(op);
3882 return info ? info->action : std::optional<LegalizationAction>();
3883}
3884
3885auto ConversionTarget::isLegal(Operation *op) const
3886 -> std::optional<LegalOpDetails> {
3887 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3888 if (!info)
3889 return std::nullopt;
3890
3891 // Returns true if this operation instance is known to be legal.
3892 auto isOpLegal = [&] {
3893 // Handle dynamic legality either with the provided legality function.
3894 if (info->action == LegalizationAction::Dynamic) {
3895 std::optional<bool> result = info->legalityFn(op);
3896 if (result)
3897 return *result;
3898 }
3899
3900 // Otherwise, the operation is only legal if it was marked 'Legal'.
3901 return info->action == LegalizationAction::Legal;
3902 };
3903 if (!isOpLegal())
3904 return std::nullopt;
3905
3906 // This operation is legal, compute any additional legality information.
3907 LegalOpDetails legalityDetails;
3908 if (info->isRecursivelyLegal) {
3909 auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3910 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3911 legalityDetails.isRecursivelyLegal =
3912 legalityFnIt->second(op).value_or(true);
3913 } else {
3914 legalityDetails.isRecursivelyLegal = true;
3915 }
3916 }
3917 return legalityDetails;
3918}
3919
3920bool ConversionTarget::isIllegal(Operation *op) const {
3921 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3922 if (!info)
3923 return false;
3924
3925 if (info->action == LegalizationAction::Dynamic) {
3926 std::optional<bool> result = info->legalityFn(op);
3927 if (!result)
3928 return false;
3929
3930 return !(*result);
3931 }
3932
3933 return info->action == LegalizationAction::Illegal;
3934}
3935
3936static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
3937 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
3938 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
3939 if (!oldCallback)
3940 return newCallback;
3941
3942 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
3943 Operation *op) -> std::optional<bool> {
3944 if (std::optional<bool> result = newCl(op))
3945 return *result;
3946
3947 return oldCl(op);
3948 };
3949 return chain;
3950}
3951
3952void ConversionTarget::setLegalityCallback(
3953 OperationName name, const DynamicLegalityCallbackFn &callback) {
3954 assert(callback && "expected valid legality callback");
3955 auto *infoIt = legalOperations.find(name);
3956 assert(infoIt != legalOperations.end() &&
3957 infoIt->second.action == LegalizationAction::Dynamic &&
3958 "expected operation to already be marked as dynamically legal");
3959 infoIt->second.legalityFn =
3960 composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
3961}
3962
3963void ConversionTarget::markOpRecursivelyLegal(
3964 OperationName name, const DynamicLegalityCallbackFn &callback) {
3965 auto *infoIt = legalOperations.find(name);
3966 assert(infoIt != legalOperations.end() &&
3967 infoIt->second.action != LegalizationAction::Illegal &&
3968 "expected operation to already be marked as legal");
3969 infoIt->second.isRecursivelyLegal = true;
3970 if (callback)
3971 opRecursiveLegalityFns[name] = composeLegalityCallbacks(
3972 std::move(opRecursiveLegalityFns[name]), callback);
3973 else
3974 opRecursiveLegalityFns.erase(name);
3975}
3976
3977void ConversionTarget::setLegalityCallback(
3978 ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
3979 assert(callback && "expected valid legality callback");
3980 for (StringRef dialect : dialects)
3981 dialectLegalityFns[dialect] = composeLegalityCallbacks(
3982 std::move(dialectLegalityFns[dialect]), callback);
3983}
3984
3985void ConversionTarget::setLegalityCallback(
3986 const DynamicLegalityCallbackFn &callback) {
3987 assert(callback && "expected valid legality callback");
3988 unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
3989}
3990
3991auto ConversionTarget::getOpInfo(OperationName op) const
3992 -> std::optional<LegalizationInfo> {
3993 // Check for info for this specific operation.
3994 const auto *it = legalOperations.find(op);
3995 if (it != legalOperations.end())
3996 return it->second;
3997 // Check for info for the parent dialect.
3998 auto dialectIt = legalDialects.find(op.getDialectNamespace());
3999 if (dialectIt != legalDialects.end()) {
4000 DynamicLegalityCallbackFn callback;
4001 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4002 if (dialectFn != dialectLegalityFns.end())
4003 callback = dialectFn->second;
4004 return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
4005 callback};
4006 }
4007 // Otherwise, check if we mark unknown operations as dynamic.
4008 if (unknownLegalityFn)
4009 return LegalizationInfo{LegalizationAction::Dynamic,
4010 /*isRecursivelyLegal=*/false, unknownLegalityFn};
4011 return std::nullopt;
4012}
4013
4014#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4015//===----------------------------------------------------------------------===//
4016// PDL Configuration
4017//===----------------------------------------------------------------------===//
4018
4019void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
4020 auto &rewriterImpl =
4021 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4022 rewriterImpl.currentTypeConverter = getTypeConverter();
4023}
4024
4025void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
4026 auto &rewriterImpl =
4027 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4028 rewriterImpl.currentTypeConverter = nullptr;
4029}
4030
4031/// Remap the given value using the rewriter and the type converter in the
4032/// provided config.
4033static FailureOr<SmallVector<Value>>
4034pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) {
4035 SmallVector<Value> mappedValues;
4036 if (failed(rewriter.getRemappedValues(values, mappedValues)))
4037 return failure();
4038 return std::move(mappedValues);
4039}
4040
4041void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
4042 patterns.getPDLPatterns().registerRewriteFunction(
4043 "convertValue",
4044 [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
4045 auto results = pdllConvertValues(
4046 static_cast<ConversionPatternRewriter &>(rewriter), value);
4047 if (failed(results))
4048 return failure();
4049 return results->front();
4050 });
4051 patterns.getPDLPatterns().registerRewriteFunction(
4052 "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
4053 return pdllConvertValues(
4054 static_cast<ConversionPatternRewriter &>(rewriter), values);
4055 });
4056 patterns.getPDLPatterns().registerRewriteFunction(
4057 "convertType",
4058 [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
4059 auto &rewriterImpl =
4060 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4061 if (const TypeConverter *converter =
4062 rewriterImpl.currentTypeConverter) {
4063 if (Type newType = converter->convertType(type))
4064 return newType;
4065 return failure();
4066 }
4067 return type;
4068 });
4069 patterns.getPDLPatterns().registerRewriteFunction(
4070 "convertTypes",
4071 [](PatternRewriter &rewriter,
4072 TypeRange types) -> FailureOr<SmallVector<Type>> {
4073 auto &rewriterImpl =
4074 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4075 const TypeConverter *converter = rewriterImpl.currentTypeConverter;
4076 if (!converter)
4077 return SmallVector<Type>(types);
4078
4079 SmallVector<Type> remappedTypes;
4080 if (failed(converter->convertTypes(types, remappedTypes)))
4081 return failure();
4082 return std::move(remappedTypes);
4083 });
4084}
4085#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
4086
4087//===----------------------------------------------------------------------===//
4088// Op Conversion Entry Points
4089//===----------------------------------------------------------------------===//
4090
4091/// This is the type of Action that is dispatched when a conversion is applied.
4093 : public tracing::ActionImpl<ApplyConversionAction> {
4094public:
4097 static constexpr StringLiteral tag = "apply-conversion";
4098 static constexpr StringLiteral desc =
4099 "Encapsulate the application of a dialect conversion";
4100
4101 void print(raw_ostream &os) const override { os << tag; }
4102};
4103
4105 const ConversionTarget &target,
4107 ConversionConfig config,
4108 OpConversionMode mode) {
4109 if (ops.empty())
4110 return success();
4111 MLIRContext *ctx = ops.front()->getContext();
4112 LogicalResult status = success();
4113 SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
4115 [&] {
4116 OperationConverter opConverter(ops.front()->getContext(), target,
4117 patterns, config, mode);
4118 status = opConverter.convertOperations(ops);
4119 },
4120 irUnits);
4121 return status;
4122}
4123
4124//===----------------------------------------------------------------------===//
4125// Partial Conversion
4126//===----------------------------------------------------------------------===//
4127
4128LogicalResult mlir::applyPartialConversion(
4129 ArrayRef<Operation *> ops, const ConversionTarget &target,
4130 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4131 return applyConversion(ops, target, patterns, config,
4132 OpConversionMode::Partial);
4133}
4134LogicalResult
4135mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
4136 const FrozenRewritePatternSet &patterns,
4137 ConversionConfig config) {
4138 return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
4139}
4140
4141//===----------------------------------------------------------------------===//
4142// Full Conversion
4143//===----------------------------------------------------------------------===//
4144
4145LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
4146 const ConversionTarget &target,
4147 const FrozenRewritePatternSet &patterns,
4148 ConversionConfig config) {
4149 return applyConversion(ops, target, patterns, config, OpConversionMode::Full);
4150}
4151LogicalResult mlir::applyFullConversion(Operation *op,
4152 const ConversionTarget &target,
4153 const FrozenRewritePatternSet &patterns,
4154 ConversionConfig config) {
4155 return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
4156}
4157
4158//===----------------------------------------------------------------------===//
4159// Analysis Conversion
4160//===----------------------------------------------------------------------===//
4161
4162/// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
4163/// op is a top-level module op (which is expected to be isolated from above),
4164/// return that op.
4166 // Check if there is a top-level operation within `ops`. If so, return that
4167 // op.
4168 for (Operation *op : ops) {
4169 if (!op->getParentOp()) {
4170#ifndef NDEBUG
4171 assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
4172 "expected top-level op to be isolated from above");
4173 for (Operation *other : ops)
4174 assert(op->isAncestor(other) &&
4175 "expected ops to have a common ancestor");
4176#endif // NDEBUG
4177 return op;
4178 }
4179 }
4180
4181 // No top-level op. Find a common ancestor.
4182 Operation *commonAncestor =
4183 ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
4184 for (Operation *op : ops.drop_front()) {
4185 while (!commonAncestor->isProperAncestor(op)) {
4186 commonAncestor =
4188 assert(commonAncestor &&
4189 "expected to find a common isolated from above ancestor");
4190 }
4191 }
4192
4193 return commonAncestor;
4194}
4195
4196LogicalResult mlir::applyAnalysisConversion(
4197 ArrayRef<Operation *> ops, ConversionTarget &target,
4198 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4199#ifndef NDEBUG
4200 if (config.legalizableOps)
4201 assert(config.legalizableOps->empty() && "expected empty set");
4202#endif // NDEBUG
4203
4204 // Clone closted common ancestor that is isolated from above.
4205 Operation *commonAncestor = findCommonAncestor(ops);
4206 IRMapping mapping;
4207 Operation *clonedAncestor = commonAncestor->clone(mapping);
4208 // Compute inverse IR mapping.
4209 DenseMap<Operation *, Operation *> inverseOperationMap;
4210 for (auto &it : mapping.getOperationMap())
4211 inverseOperationMap[it.second] = it.first;
4212
4213 // Convert the cloned operations. The original IR will remain unchanged.
4214 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4215 ops, [&](Operation *op) { return mapping.lookup(op); });
4216 LogicalResult status = applyConversion(opsToConvert, target, patterns, config,
4217 OpConversionMode::Analysis);
4218
4219 // Remap `legalizableOps`, so that they point to the original ops and not the
4220 // cloned ops.
4221 if (config.legalizableOps) {
4222 DenseSet<Operation *> originalLegalizableOps;
4223 for (Operation *op : *config.legalizableOps)
4224 originalLegalizableOps.insert(inverseOperationMap[op]);
4225 *config.legalizableOps = std::move(originalLegalizableOps);
4226 }
4227
4228 // Erase the cloned IR.
4229 clonedAncestor->erase();
4230 return status;
4231}
4232
4233LogicalResult
4234mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
4235 const FrozenRewritePatternSet &patterns,
4236 ConversionConfig config) {
4237 return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
4238}
return success()
static void setInsertionPointAfter(OpBuilder &b, Value value)
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl)
Replace all uses of from with repl.
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value > > &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
#define DEBUG_TYPE
b getContext())
static std::string diag(const llvm::Value &value)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Definition SCCP.cpp:67
This is the type of Action that is dispatched when a conversion is applied.
tracing::ActionImpl< ApplyConversionAction > Base
static constexpr StringLiteral desc
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
static constexpr StringLiteral tag
This class represents an argument of a Block.
Definition Value.h:309
Location getLoc() const
Return the location for this argument.
Definition Value.h:324
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:140
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:149
bool empty()
Definition Block.h:148
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:137
Operation & front()
Definition Block.h:153
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition Block.h:308
BlockArgListType getArguments()
Definition Block.h:87
iterator end()
Definition Block.h:144
iterator begin()
Definition Block.h:143
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
UnitAttr getUnitAttr()
Definition Builders.cpp:98
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
MLIRContext * context
Definition Builders.h:202
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
Definition Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition Builders.h:327
Block::iterator getPoint() const
Definition Builders.h:340
bool isSet() const
Returns true if this insert point is set.
Definition Builders.h:337
Block * getBlock() const
Definition Builders.h:339
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:320
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Definition Builders.cpp:473
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition Builders.cpp:421
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
This class provides the API for ops that are known to be isolated from above.
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
void destroyOpProperties(OpaqueProperties properties) const
This hook destroys the op properties.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition Operation.h:226
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:852
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
void dropAllUses()
Drop all uses of results of this operation.
Definition Operation.h:834
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:560
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:248
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperandRange operand_range
Definition Operation.h:371
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:263
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
result_range getResults()
Definition Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition Operation.h:900
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
BlockListType::iterator iterator
Definition Region.h:52
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
CRTP Implementation of an action.
Definition Action.h:76
ArrayRef< IRUnit > irUnits
Set of IR units (operations, regions, blocks, values) that are associated with this action.
Definition Action.h:66
AttrTypeReplacer.
Kind
An enumeration of the kinds of predicates.
Definition Predicate.h:44
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
static void reconcileUnrealizedCasts(const DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
This iterator enumerates elements according to their dominance relationship.
Definition Iterators.h:48
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition Builders.h:308
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition Builders.h:298
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
LogicalResult convertOperations(ArrayRef< Operation * > ops)
Converts the given operations to the conversion target.
LogicalResult convert(Operation *op, bool isRecursiveLegalization=false)
Converts a single operation.
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
llvm::impl::raw_ldbg_ostream os
A raw output stream used to prefix the debug log.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config, OperationConverter &opConverter)
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
void replaceOp(Operation *op, SmallVector< SmallVector< Value > > &&newValues)
Replace the results of the given operation with the given values and erase the operation.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
void replaceAllUsesWith(Value from, ValueRange to, const TypeConverter *converter)
Replace the uses of the given value with the given values.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
OperationConverter & opConverter
The operation converter to use for recursive legalization.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.