MLIR 23.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/Config/mlir-config.h"
11#include "mlir/IR/Block.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "mlir/IR/Dominance.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Iterators.h"
17#include "mlir/IR/Operation.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/ErrorHandling.h"
25#include "llvm/Support/FormatVariadic.h"
26#include "llvm/Support/SaveAndRestore.h"
27#include "llvm/Support/ScopedPrinter.h"
28#include <optional>
29#include <utility>
30
31using namespace mlir;
32using namespace mlir::detail;
33
34#define DEBUG_TYPE "dialect-conversion"
35
36/// A utility function to log a successful result for the given reason.
37template <typename... Args>
38static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
39 LLVM_DEBUG({
40 os.unindent();
41 os.startLine() << "} -> SUCCESS";
42 if (!fmt.empty())
43 os.getOStream() << " : "
44 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
45 os.getOStream() << "\n";
46 });
47}
48
49/// A utility function to log a failure result for the given reason.
50template <typename... Args>
51static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
52 LLVM_DEBUG({
53 os.unindent();
54 os.startLine() << "} -> FAILURE : "
55 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
56 << "\n";
57 });
58}
59
60/// Helper function that computes an insertion point where the given value is
61/// defined and can be used without a dominance violation.
63 Block *insertBlock = value.getParentBlock();
64 Block::iterator insertPt = insertBlock->begin();
65 if (OpResult inputRes = dyn_cast<OpResult>(value))
66 insertPt = ++inputRes.getOwner()->getIterator();
67 return OpBuilder::InsertPoint(insertBlock, insertPt);
68}
69
70/// Helper function that computes an insertion point where the given values are
71/// defined and can be used without a dominance violation.
73 assert(!vals.empty() && "expected at least one value");
74 DominanceInfo domInfo;
76 for (Value v : vals.drop_front()) {
77 // Choose the "later" insertion point.
79 if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
80 nextPt.getPoint())) {
81 // pt is before nextPt => choose nextPt.
82 pt = nextPt;
83 } else {
84#ifndef NDEBUG
85 // nextPt should be before pt => choose pt.
86 // If pt and nextPt have no dominance relationship, then there is no valid
87 // insertion point at which all given values are defined.
88 bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
89 pt.getBlock(), pt.getPoint());
90 assert(dom && "unable to find valid insertion point");
91#endif // NDEBUG
92 }
93 }
94 return pt;
95}
96
namespace {
/// The mode in which a conversion is run.
enum OpConversionMode {
  /// Failed conversions are ignored, so illegal operations may remain in the
  /// IR alongside converted ones.
  Partial,

  /// The conversion only succeeds if every operation is legal for the given
  /// target.
  Full,

  /// Operations are only analyzed for legality; on success, no rewrites are
  /// actually applied to them.
  Analysis,
};
} // namespace
112
113//===----------------------------------------------------------------------===//
114// ConversionValueMapping
115//===----------------------------------------------------------------------===//
116
117/// A vector of SSA values, optimized for the most common case of a single
118/// value.
120
121namespace {
122
123/// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
124struct ValueVectorMapInfo {
125 static ValueVector getEmptyKey() { return ValueVector{Value()}; }
126 static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
127 static ::llvm::hash_code getHashValue(const ValueVector &val) {
128 return ::llvm::hash_combine_range(val);
129 }
130 static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
131 return LHS == RHS;
132 }
133};
134
135/// This class wraps an IRMapping to provide recursive lookup
136/// functionality, i.e. we will traverse if the mapped value also has a mapping.
137struct ConversionValueMapping {
138 /// Return "true" if an SSA value is mapped to the given value. May return
139 /// false positives.
140 bool isMappedTo(Value value) const { return mappedTo.contains(value); }
141
142 /// Lookup a value in the mapping.
143 ValueVector lookup(const ValueVector &from) const;
144
145 template <typename T>
146 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
147
148 /// Map a value vector to the one provided.
149 template <typename OldVal, typename NewVal>
150 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
151 map(OldVal &&oldVal, NewVal &&newVal) {
152 LLVM_DEBUG({
153 ValueVector next(newVal);
154 while (true) {
155 assert(next != oldVal && "inserting cyclic mapping");
156 auto it = mapping.find(next);
157 if (it == mapping.end())
158 break;
159 next = it->second;
160 }
161 });
162 mappedTo.insert_range(newVal);
163
164 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
165 }
166
167 /// Map a value vector or single value to the one provided.
168 template <typename OldVal, typename NewVal>
169 std::enable_if_t<!IsValueVector<OldVal>::value ||
170 !IsValueVector<NewVal>::value>
171 map(OldVal &&oldVal, NewVal &&newVal) {
172 if constexpr (IsValueVector<OldVal>{}) {
173 map(std::forward<OldVal>(oldVal), ValueVector{newVal});
174 } else if constexpr (IsValueVector<NewVal>{}) {
175 map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
176 } else {
177 map(ValueVector{oldVal}, ValueVector{newVal});
178 }
179 }
180
181 void map(Value oldVal, SmallVector<Value> &&newVal) {
182 map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
183 }
184
185 /// Drop the last mapping for the given values.
186 void erase(const ValueVector &value) { mapping.erase(value); }
187
188private:
189 /// Current value mappings.
191
192 /// All SSA values that are mapped to. May contain false positives.
193 DenseSet<Value> mappedTo;
194};
195} // namespace
196
197/// Marker attribute for pure type conversions. I.e., mappings whose only
198/// purpose is to resolve a type mismatch. (In contrast, mappings that point to
199/// the replacement values of a "replaceOp" call, etc., are not pure type
200/// conversions.)
201static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
202
203/// Return the operation that defines all values in the vector. Return nullptr
204/// if the values are not defined by the same operation.
206 assert(!values.empty() && "expected non-empty value vector");
207 Operation *op = values.front().getDefiningOp();
208 for (Value v : llvm::drop_begin(values)) {
209 if (v.getDefiningOp() != op)
210 return nullptr;
211 }
212 return op;
213}
214
215/// A vector of values is a pure type conversion if all values are defined by
216/// the same operation and the operation has the `kPureTypeConversionMarker`
217/// attribute.
218static bool isPureTypeConversion(const ValueVector &values) {
219 assert(!values.empty() && "expected non-empty value vector");
220 Operation *op = getCommonDefiningOp(values);
221 return op && op->hasAttr(kPureTypeConversionMarker);
222}
223
224ValueVector ConversionValueMapping::lookup(const ValueVector &from) const {
225 auto it = mapping.find(from);
226 if (it == mapping.end()) {
227 // No mapping found: The lookup stops here.
228 return {};
229 }
230 return it->second;
231}
232
233//===----------------------------------------------------------------------===//
234// Rewriter and Translation State
235//===----------------------------------------------------------------------===//
236namespace {
/// A snapshot of the conversion rewriter's counters. Saving such a snapshot
/// makes it possible to later undo everything that happened after it.
struct RewriterState {
  RewriterState(unsigned rewriteCount, unsigned ignoredOpCount,
                unsigned replacedOpCount)
      : numRewrites(rewriteCount), numIgnoredOperations(ignoredOpCount),
        numReplacedOps(replacedOpCount) {}

  /// Number of rewrites performed at the time of the snapshot.
  unsigned numRewrites;

  /// Number of ignored operations at the time of the snapshot.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure at the time of the snapshot.
  unsigned numReplacedOps;
};
254
255//===----------------------------------------------------------------------===//
256// IR rewrites
257//===----------------------------------------------------------------------===//
258
259static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
260
261/// Notify the listener that the given block and its contents are being erased.
262static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
263 for (Operation &op : b)
264 notifyIRErased(listener, op);
265 listener->notifyBlockErased(&b);
266}
267
268/// Notify the listener that the given operation and its contents are being
269/// erased.
270static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
271 for (Region &r : op.getRegions()) {
272 for (Block &b : r) {
273 notifyIRErased(listener, b);
274 }
275 }
276 listener->notifyOperationErased(&op);
277}
278
279/// An IR rewrite that can be committed (upon success) or rolled back (upon
280/// failure).
281///
282/// The dialect conversion keeps track of IR modifications (requested by the
283/// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
284/// are directly applied to the IR as the rewriter API is used, some are applied
285/// partially, and some are delayed until the `IRRewrite` objects are committed.
286class IRRewrite {
287public:
288 /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
289 /// Enum values are ordered, so that they can be used in `classof`: first all
290 /// block rewrites, then all operation rewrites.
291 enum class Kind {
292 // Block rewrites
293 CreateBlock,
294 EraseBlock,
295 InlineBlock,
296 MoveBlock,
297 BlockTypeConversion,
298 // Operation rewrites
299 MoveOperation,
300 ModifyOperation,
301 ReplaceOperation,
302 CreateOperation,
303 UnresolvedMaterialization,
304 // Value rewrites
305 ReplaceValue
306 };
307
308 virtual ~IRRewrite() = default;
309
310 /// Roll back the rewrite. Operations may be erased during rollback.
311 virtual void rollback() = 0;
312
313 /// Commit the rewrite. At this point, it is certain that the dialect
314 /// conversion will succeed. All IR modifications, except for operation/block
315 /// erasure, must be performed through the given rewriter.
316 ///
317 /// Instead of erasing operations/blocks, they should merely be unlinked
318 /// commit phase and finally be erased during the cleanup phase. This is
319 /// because internal dialect conversion state (such as `mapping`) may still
320 /// be using them.
321 ///
322 /// Any IR modification that was already performed before the commit phase
323 /// (e.g., insertion of an op) must be communicated to the listener that may
324 /// be attached to the given rewriter.
325 virtual void commit(RewriterBase &rewriter) {}
326
327 /// Cleanup operations/blocks. Cleanup is called after commit.
328 virtual void cleanup(RewriterBase &rewriter) {}
329
330 Kind getKind() const { return kind; }
331
332 static bool classof(const IRRewrite *rewrite) { return true; }
333
334protected:
335 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
336 : kind(kind), rewriterImpl(rewriterImpl) {}
337
338 const ConversionConfig &getConfig() const;
339
340 const Kind kind;
341 ConversionPatternRewriterImpl &rewriterImpl;
342};
343
344/// A block rewrite.
345class BlockRewrite : public IRRewrite {
346public:
347 /// Return the block that this rewrite operates on.
348 Block *getBlock() const { return block; }
349
350 static bool classof(const IRRewrite *rewrite) {
351 return rewrite->getKind() >= Kind::CreateBlock &&
352 rewrite->getKind() <= Kind::BlockTypeConversion;
353 }
354
355protected:
356 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
357 Block *block)
358 : IRRewrite(kind, rewriterImpl), block(block) {}
359
360 // The block that this rewrite operates on.
361 Block *block;
362};
363
364/// A value rewrite.
365class ValueRewrite : public IRRewrite {
366public:
367 /// Return the value that this rewrite operates on.
368 Value getValue() const { return value; }
369
370 static bool classof(const IRRewrite *rewrite) {
371 return rewrite->getKind() == Kind::ReplaceValue;
372 }
373
374protected:
375 ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
376 Value value)
377 : IRRewrite(kind, rewriterImpl), value(value) {}
378
379 // The value that this rewrite operates on.
380 Value value;
381};
382
383/// Creation of a block. Block creations are immediately reflected in the IR.
384/// There is no extra work to commit the rewrite. During rollback, the newly
385/// created block is erased.
386class CreateBlockRewrite : public BlockRewrite {
387public:
388 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
389 : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
390
391 static bool classof(const IRRewrite *rewrite) {
392 return rewrite->getKind() == Kind::CreateBlock;
393 }
394
395 void commit(RewriterBase &rewriter) override {
396 // The block was already created and inserted. Just inform the listener.
397 if (auto *listener = rewriter.getListener())
398 listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
399 }
400
401 void rollback() override {
402 // Unlink all of the operations within this block, they will be deleted
403 // separately.
404 auto &blockOps = block->getOperations();
405 while (!blockOps.empty())
406 blockOps.remove(blockOps.begin());
407 block->dropAllUses();
408 if (block->getParent())
409 block->erase();
410 else
411 delete block;
412 }
413};
414
415/// Erasure of a block. Block erasures are partially reflected in the IR. Erased
416/// blocks are immediately unlinked, but only erased during cleanup. This makes
417/// it easier to rollback a block erasure: the block is simply inserted into its
418/// original location.
419class EraseBlockRewrite : public BlockRewrite {
420public:
421 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
422 : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
423 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
424
425 static bool classof(const IRRewrite *rewrite) {
426 return rewrite->getKind() == Kind::EraseBlock;
427 }
428
429 ~EraseBlockRewrite() override {
430 assert(!block &&
431 "rewrite was neither rolled back nor committed/cleaned up");
432 }
433
434 void rollback() override {
435 // The block (owned by this rewrite) was not actually erased yet. It was
436 // just unlinked. Put it back into its original position.
437 assert(block && "expected block");
438 auto &blockList = region->getBlocks();
439 Region::iterator before = insertBeforeBlock
440 ? Region::iterator(insertBeforeBlock)
441 : blockList.end();
442 blockList.insert(before, block);
443 block = nullptr;
444 }
445
446 void commit(RewriterBase &rewriter) override {
447 assert(block && "expected block");
448
449 // Notify the listener that the block and its contents are being erased.
450 if (auto *listener =
451 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
452 notifyIRErased(listener, *block);
453 }
454
455 void cleanup(RewriterBase &rewriter) override {
456 // Erase the contents of the block.
457 for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
458 rewriter.eraseOp(&op);
459 assert(block->empty() && "expected empty block");
460
461 // Erase the block.
462 block->dropAllDefinedValueUses();
463 delete block;
464 block = nullptr;
465 }
466
467private:
468 // The region in which this block was previously contained.
469 Region *region;
470
471 // The original successor of this block before it was unlinked. "nullptr" if
472 // this block was the only block in the region.
473 Block *insertBeforeBlock;
474};
475
476/// Inlining of a block. This rewrite is immediately reflected in the IR.
477/// Note: This rewrite represents only the inlining of the operations. The
478/// erasure of the inlined block is a separate rewrite.
479class InlineBlockRewrite : public BlockRewrite {
480public:
481 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
482 Block *sourceBlock, Block::iterator before)
483 : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
484 sourceBlock(sourceBlock),
485 firstInlinedInst(sourceBlock->empty() ? nullptr
486 : &sourceBlock->front()),
487 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
488 // If a listener is attached to the dialect conversion, ops must be moved
489 // one-by-one. When they are moved in bulk, notifications cannot be sent
490 // because the ops that used to be in the source block at the time of the
491 // inlining (before the "commit" phase) are unknown at the time when
492 // notifications are sent (which is during the "commit" phase).
493 assert(!getConfig().listener &&
494 "InlineBlockRewrite not supported if listener is attached");
495 }
496
497 static bool classof(const IRRewrite *rewrite) {
498 return rewrite->getKind() == Kind::InlineBlock;
499 }
500
501 void rollback() override {
502 // Put the operations from the destination block (owned by the rewrite)
503 // back into the source block.
504 if (firstInlinedInst) {
505 assert(lastInlinedInst && "expected operation");
506 sourceBlock->getOperations().splice(sourceBlock->begin(),
507 block->getOperations(),
508 Block::iterator(firstInlinedInst),
509 ++Block::iterator(lastInlinedInst));
510 }
511 }
512
513private:
514 // The block that originally contained the operations.
515 Block *sourceBlock;
516
517 // The first inlined operation.
518 Operation *firstInlinedInst;
519
520 // The last inlined operation.
521 Operation *lastInlinedInst;
522};
523
524/// Moving of a block. This rewrite is immediately reflected in the IR.
525class MoveBlockRewrite : public BlockRewrite {
526public:
527 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
528 Region *previousRegion, Region::iterator previousIt)
529 : BlockRewrite(Kind::MoveBlock, rewriterImpl, block),
530 region(previousRegion),
531 insertBeforeBlock(previousIt == previousRegion->end() ? nullptr
532 : &*previousIt) {}
533
534 static bool classof(const IRRewrite *rewrite) {
535 return rewrite->getKind() == Kind::MoveBlock;
536 }
537
538 void commit(RewriterBase &rewriter) override {
539 // The block was already moved. Just inform the listener.
540 if (auto *listener = rewriter.getListener()) {
541 // Note: `previousIt` cannot be passed because this is a delayed
542 // notification and iterators into past IR state cannot be represented.
543 listener->notifyBlockInserted(block, /*previous=*/region,
544 /*previousIt=*/{});
545 }
546 }
547
548 void rollback() override {
549 // Move the block back to its original position.
550 Region::iterator before =
551 insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
552 if (Region *currentParent = block->getParent()) {
553 // Block is still in a region, use cheap splice to move it back.
554 region->getBlocks().splice(before, currentParent->getBlocks(), block);
555 return;
556 }
557 // Block was orphaned by a prior rollback, can't splice.
558 region->getBlocks().insert(before, block);
559 }
560
561private:
562 // The region in which this block was previously contained.
563 Region *region;
564
565 // The original successor of this block before it was moved. "nullptr" if
566 // this block was the only block in the region.
567 Block *insertBeforeBlock;
568};
569
570/// Block type conversion. This rewrite is partially reflected in the IR.
571class BlockTypeConversionRewrite : public BlockRewrite {
572public:
573 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
574 Block *origBlock, Block *newBlock)
575 : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
576 newBlock(newBlock) {}
577
578 static bool classof(const IRRewrite *rewrite) {
579 return rewrite->getKind() == Kind::BlockTypeConversion;
580 }
581
582 Block *getOrigBlock() const { return block; }
583
584 Block *getNewBlock() const { return newBlock; }
585
586 void commit(RewriterBase &rewriter) override;
587
588 void rollback() override;
589
590private:
591 /// The new block that was created as part of this signature conversion.
592 Block *newBlock;
593};
594
595/// Replacing a value. This rewrite is not immediately reflected in the
596/// IR. An internal IR mapping is updated, but the actual replacement is delayed
597/// until the rewrite is committed.
598class ReplaceValueRewrite : public ValueRewrite {
599public:
600 ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
601 const TypeConverter *converter)
602 : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
603 converter(converter) {}
604
605 static bool classof(const IRRewrite *rewrite) {
606 return rewrite->getKind() == Kind::ReplaceValue;
607 }
608
609 void commit(RewriterBase &rewriter) override;
610
611 void rollback() override;
612
613private:
614 /// The current type converter when the value was replaced.
615 const TypeConverter *converter;
616};
617
618/// An operation rewrite.
619class OperationRewrite : public IRRewrite {
620public:
621 /// Return the operation that this rewrite operates on.
622 Operation *getOperation() const { return op; }
623
624 static bool classof(const IRRewrite *rewrite) {
625 return rewrite->getKind() >= Kind::MoveOperation &&
626 rewrite->getKind() <= Kind::UnresolvedMaterialization;
627 }
628
629protected:
630 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
631 Operation *op)
632 : IRRewrite(kind, rewriterImpl), op(op) {}
633
634 // The operation that this rewrite operates on.
635 Operation *op;
636};
637
638/// Moving of an operation. This rewrite is immediately reflected in the IR.
639class MoveOperationRewrite : public OperationRewrite {
640public:
641 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
642 Operation *op, OpBuilder::InsertPoint previous)
643 : OperationRewrite(Kind::MoveOperation, rewriterImpl, op),
644 block(previous.getBlock()),
645 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
646 ? nullptr
647 : &*previous.getPoint()) {}
648
649 static bool classof(const IRRewrite *rewrite) {
650 return rewrite->getKind() == Kind::MoveOperation;
651 }
652
653 void commit(RewriterBase &rewriter) override {
654 // The operation was already moved. Just inform the listener.
655 if (auto *listener = rewriter.getListener()) {
656 // Note: `previousIt` cannot be passed because this is a delayed
657 // notification and iterators into past IR state cannot be represented.
658 listener->notifyOperationInserted(
659 op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
660 /*insertPt=*/{}));
661 }
662 }
663
664 void rollback() override {
665 // Move the operation back to its original position.
666 Block::iterator before =
667 insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
668 block->getOperations().splice(before, op->getBlock()->getOperations(), op);
669 }
670
671private:
672 // The block in which this operation was previously contained.
673 Block *block;
674
675 // The original successor of this operation before it was moved. "nullptr"
676 // if this operation was the only operation in the region.
677 Operation *insertBeforeOp;
678};
679
680/// In-place modification of an op. This rewrite is immediately reflected in
681/// the IR. The previous state of the operation is stored in this object.
682class ModifyOperationRewrite : public OperationRewrite {
683public:
684 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
685 Operation *op)
686 : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
687 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
688 operands(op->operand_begin(), op->operand_end()),
689 successors(op->successor_begin(), op->successor_end()) {
690 if (OpaqueProperties prop = op->getPropertiesStorage()) {
691 // Make a copy of the properties.
692 propertiesStorage = operator new(op->getPropertiesStorageSize());
693 OpaqueProperties propCopy(propertiesStorage);
694 name.initOpProperties(propCopy, /*init=*/prop);
695 }
696 }
697
698 static bool classof(const IRRewrite *rewrite) {
699 return rewrite->getKind() == Kind::ModifyOperation;
700 }
701
702 ~ModifyOperationRewrite() override {
703 assert(!propertiesStorage &&
704 "rewrite was neither committed nor rolled back");
705 }
706
707 void commit(RewriterBase &rewriter) override {
708 // Notify the listener that the operation was modified in-place.
709 if (auto *listener =
710 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
711 listener->notifyOperationModified(op);
712
713 if (propertiesStorage) {
714 OpaqueProperties propCopy(propertiesStorage);
715 // Note: The operation may have been erased in the mean time, so
716 // OperationName must be stored in this object.
717 name.destroyOpProperties(propCopy);
718 operator delete(propertiesStorage);
719 propertiesStorage = nullptr;
720 }
721 }
722
723 void rollback() override {
724 op->setLoc(loc);
725 op->setAttrs(attrs);
726 op->setOperands(operands);
727 for (const auto &it : llvm::enumerate(successors))
728 op->setSuccessor(it.value(), it.index());
729 if (propertiesStorage) {
730 OpaqueProperties propCopy(propertiesStorage);
731 op->copyProperties(propCopy);
732 name.destroyOpProperties(propCopy);
733 operator delete(propertiesStorage);
734 propertiesStorage = nullptr;
735 }
736 }
737
738private:
739 OperationName name;
740 LocationAttr loc;
741 DictionaryAttr attrs;
742 SmallVector<Value, 8> operands;
743 SmallVector<Block *, 2> successors;
744 void *propertiesStorage = nullptr;
745};
746
747/// Replacing an operation. Erasing an operation is treated as a special case
748/// with "null" replacements. This rewrite is not immediately reflected in the
749/// IR. An internal IR mapping is updated, but values are not replaced and the
750/// original op is not erased until the rewrite is committed.
751class ReplaceOperationRewrite : public OperationRewrite {
752public:
753 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
754 Operation *op, const TypeConverter *converter)
755 : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
756 converter(converter) {}
757
758 static bool classof(const IRRewrite *rewrite) {
759 return rewrite->getKind() == Kind::ReplaceOperation;
760 }
761
762 void commit(RewriterBase &rewriter) override;
763
764 void rollback() override;
765
766 void cleanup(RewriterBase &rewriter) override;
767
768private:
769 /// An optional type converter that can be used to materialize conversions
770 /// between the new and old values if necessary.
771 const TypeConverter *converter;
772};
773
774class CreateOperationRewrite : public OperationRewrite {
775public:
776 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
777 Operation *op)
778 : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
779
780 static bool classof(const IRRewrite *rewrite) {
781 return rewrite->getKind() == Kind::CreateOperation;
782 }
783
784 void commit(RewriterBase &rewriter) override {
785 // The operation was already created and inserted. Just inform the listener.
786 if (auto *listener = rewriter.getListener())
787 listener->notifyOperationInserted(op, /*previous=*/{});
788 }
789
790 void rollback() override;
791};
792
/// The direction of a materialization.
enum MaterializationKind {
  /// A conversion from an illegal type to a legal one.
  Target,

  /// A conversion from a legal type back to an illegal one.
  Source
};
803
804/// Helper class that stores metadata about an unresolved materialization.
805class UnresolvedMaterializationInfo {
806public:
807 UnresolvedMaterializationInfo() = default;
808 UnresolvedMaterializationInfo(const TypeConverter *converter,
809 MaterializationKind kind, Type originalType)
810 : converterAndKind(converter, kind), originalType(originalType) {}
811
812 /// Return the type converter of this materialization (which may be null).
813 const TypeConverter *getConverter() const {
814 return converterAndKind.getPointer();
815 }
816
817 /// Return the kind of this materialization.
818 MaterializationKind getMaterializationKind() const {
819 return converterAndKind.getInt();
820 }
821
822 /// Return the original type of the SSA value.
823 Type getOriginalType() const { return originalType; }
824
825private:
826 /// The corresponding type converter to use when resolving this
827 /// materialization, and the kind of this materialization.
828 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
829 converterAndKind;
830
831 /// The original type of the SSA value. Only used for target
832 /// materializations.
833 Type originalType;
834};
835
836/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
837/// op. Unresolved materializations fold away or are replaced with
838/// source/target materializations at the end of the dialect conversion.
839class UnresolvedMaterializationRewrite : public OperationRewrite {
840public:
841 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
842 UnrealizedConversionCastOp op,
843 ValueVector mappedValues)
844 : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
845 mappedValues(std::move(mappedValues)) {}
846
847 static bool classof(const IRRewrite *rewrite) {
848 return rewrite->getKind() == Kind::UnresolvedMaterialization;
849 }
850
851 void rollback() override;
852
853 UnrealizedConversionCastOp getOperation() const {
854 return cast<UnrealizedConversionCastOp>(op);
855 }
856
857private:
858 /// The values in the conversion value mapping that are being replaced by the
859 /// results of this unresolved materialization.
860 ValueVector mappedValues;
861};
862} // namespace
863
864#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
865/// Return "true" if there is an operation rewrite that matches the specified
866/// rewrite type and operation among the given rewrites.
867template <typename RewriteTy, typename R>
868static bool hasRewrite(R &&rewrites, Operation *op) {
869 return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
870 auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
871 return rewriteTy && rewriteTy->getOperation() == op;
872 });
873}
874
875/// Return "true" if there is a block rewrite that matches the specified
876/// rewrite type and block among the given rewrites.
877template <typename RewriteTy, typename R>
878static bool hasRewrite(R &&rewrites, Block *block) {
879 return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
880 auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
881 return rewriteTy && rewriteTy->getBlock() == block;
882 });
883}
884#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
885
886//===----------------------------------------------------------------------===//
887// ConversionPatternRewriterImpl
888//===----------------------------------------------------------------------===//
889namespace mlir {
890namespace detail {
892 explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
893 const ConversionConfig &config,
897
898 //===--------------------------------------------------------------------===//
899 // State Management
900 //===--------------------------------------------------------------------===//
901
902 /// Return the current state of the rewriter.
903 RewriterState getCurrentState();
904
905 /// Apply all requested operation rewrites. This method is invoked when the
906 /// conversion process succeeds.
907 void applyRewrites();
908
909 /// Reset the state of the rewriter to a previously saved point. Optionally,
910 /// the name of the pattern that triggered the rollback can be specified for
911 /// debugging purposes.
912 void resetState(RewriterState state, StringRef patternName = "");
913
914 /// Append a rewrite. Rewrites are committed upon success and rolled back upon
915 /// failure.
916 template <typename RewriteTy, typename... Args>
917 void appendRewrite(Args &&...args) {
918 assert(config.allowPatternRollback && "appending rewrites is not allowed");
919 rewrites.push_back(
920 std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
921 }
922
923 /// Undo the rewrites (motions, splits) one by one in reverse order until
924 /// "numRewritesToKeep" rewrites remain. Optionally, the name of the pattern
925 /// that triggered the rollback can be specified for debugging purposes.
926 void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
927
928 /// Remap the given values to those with potentially different types. Returns
929 /// success if the values could be remapped, failure otherwise. `valueDiagTag`
930 /// is the tag used when describing a value within a diagnostic, e.g.
931 /// "operand".
932 LogicalResult remapValues(StringRef valueDiagTag,
933 std::optional<Location> inputLoc, ValueRange values,
934 SmallVector<ValueVector> &remapped);
935
936 /// Return "true" if the given operation is ignored, and does not need to be
937 /// converted.
938 bool isOpIgnored(Operation *op) const;
939
940 /// Return "true" if the given operation was replaced or erased.
941 bool wasOpReplaced(Operation *op) const;
942
943 /// Lookup the most recently mapped values with the desired types in the
944 /// mapping, taking into account only replacements. Perform a best-effort
945 /// search for existing materializations with the desired types.
946 ///
947 /// If `skipPureTypeConversions` is "true", materializations that are pure
948 /// type conversions are not considered.
949 ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
950 bool skipPureTypeConversions = false) const;
951
952 /// Lookup the given value within the map, or return an empty vector if the
953 /// value is not mapped. If it is mapped, this follows the same behavior
954 /// as `lookupOrDefault`.
955 ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
956
957 //===--------------------------------------------------------------------===//
958 // IR Rewrites / Type Conversion
959 //===--------------------------------------------------------------------===//
960
961 /// Convert the types of block arguments within the given region.
962 FailureOr<Block *>
963 convertRegionTypes(Region *region, const TypeConverter &converter,
964 TypeConverter::SignatureConversion *entryConversion);
965
966 /// Apply the given signature conversion on the given block. The new block
967 /// containing the updated signature is returned. If no conversions were
968 /// necessary, e.g. if the block has no arguments, `block` is returned.
969 /// `converter` is used to generate any necessary cast operations that
970 /// translate between the origin argument types and those specified in the
971 /// signature conversion.
972 Block *applySignatureConversion(
973 Block *block, const TypeConverter *converter,
974 TypeConverter::SignatureConversion &signatureConversion);
975
976 /// Replace the results of the given operation with the given values and
977 /// erase the operation.
978 ///
979 /// There can be multiple replacement values for each result (1:N
980 /// replacement). If the replacement values are empty, the respective result
981 /// is dropped and a source materialization is built if the result still has
982 /// uses.
983 void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
984
985 /// Replace the uses of the given value with the given values. The specified
986 /// converter is used to build materializations (if necessary). If `functor`
987 /// is specified, only the uses that the functor returns "true" for are
988 /// replaced.
989 void replaceValueUses(Value from, ValueRange to,
990 const TypeConverter *converter,
991 function_ref<bool(OpOperand &)> functor = nullptr);
992
993 /// Erase the given block and its contents.
994 void eraseBlock(Block *block);
995
996 /// Inline the source block into the destination block before the given
997 /// iterator.
998 void inlineBlockBefore(Block *source, Block *dest, Block::iterator before);
999
1000 //===--------------------------------------------------------------------===//
1001 // Materializations
1002 //===--------------------------------------------------------------------===//
1003
1004 /// Build an unresolved materialization operation given a range of output
1005 /// types and a list of input operands. Returns the inputs if their types
1006 /// match the output types.
1007 ///
1008 /// If a cast op was built, it can optionally be returned with the `castOp`
1009 /// output argument.
1010 ///
1011 /// If `valuesToMap` is non-empty, then those values are mapped to the
1012 /// results of the unresolved materialization in the conversion value
1013 /// mapping.
1014 ///
1015 /// If `isPureTypeConversion` is "true", the materialization is created only
1016 /// to resolve a type mismatch. That means it is not a regular value
1017 /// replacement issued by the user. (Replacement values that are created
1018 /// "out of thin air" appear like unresolved materializations because they are
1019 /// unrealized_conversion_cast ops. However, they must be treated like
1020 /// regular value replacements.)
1021 ValueRange buildUnresolvedMaterialization(
1022 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1023 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1024 Type originalType, const TypeConverter *converter,
1025 bool isPureTypeConversion = true);
1026
1027 /// Find a replacement value for the given SSA value in the conversion value
1028 /// mapping. The replacement value must have the same type as the given SSA
1029 /// value. If there is no replacement value with the correct type, find the
1030 /// latest replacement value (regardless of the type) and build a source
1031 /// materialization.
1032 Value findOrBuildReplacementValue(Value value,
1033 const TypeConverter *converter);
1034
1035 //===--------------------------------------------------------------------===//
1036 // Rewriter Notification Hooks
1037 //===--------------------------------------------------------------------===//
1038
1039 /// Notifies that an op was inserted.
1040 void notifyOperationInserted(Operation *op,
1041 OpBuilder::InsertPoint previous) override;
1042
1043 /// Notifies that a block was inserted.
1044 void notifyBlockInserted(Block *block, Region *previous,
1045 Region::iterator previousIt) override;
1046
1047 /// Notifies that a pattern match failed for the given reason.
1048 void
1049 notifyMatchFailure(Location loc,
1050 function_ref<void(Diagnostic &)> reasonCallback) override;
1051
1052 //===--------------------------------------------------------------------===//
1053 // IR Erasure
1054 //===--------------------------------------------------------------------===//
1055
1056 /// A rewriter that keeps track of erased ops and blocks. It ensures that no
1057 /// operation or block is erased multiple times. This rewriter assumes that
1058 /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
 ///
 /// Erased ops and blocks are recorded by pointer in the `erased` set. An
 /// optional callback is invoked for every op that is recorded as erased.
1060 public:
1063 std::function<void(Operation *)> opErasedCallback = nullptr)
1064 : RewriterBase(context, /*listener=*/this),
1065 opErasedCallback(std::move(opErasedCallback)) {}
1066
1067 /// Erase the given op (unless it was already erased).
1068 void eraseOp(Operation *op) override {
 // Nothing to do if this pointer was already recorded as erased.
1069 if (wasErased(op))
1070 return;
1071 op->dropAllUses();
1073 }
1074
1075 /// Erase the given block (unless it was already erased).
1076 void eraseBlock(Block *block) override {
 // Nothing to do if this pointer was already recorded as erased.
1077 if (wasErased(block))
1078 return;
 // The block is expected to contain no operations at this point.
1079 assert(block->empty() && "expected empty block");
1080 block->dropAllDefinedValueUses();
1082 }
1083
 /// Return "true" if the given pointer was recorded as erased.
1084 bool wasErased(void *ptr) const { return erased.contains(ptr); }
1085
 // Record the erased op and forward it to the callback, if one was provided.
1087 erased.insert(op);
1088 if (opErasedCallback)
1089 opErasedCallback(op);
1090 }
1091
 /// Record the given block as erased.
1092 void notifyBlockErased(Block *block) override { erased.insert(block); }
1093
1094 private:
1095 /// Pointers to all erased operations and blocks.
1096 DenseSet<void *> erased;
1097
1098 /// A callback that is invoked when an operation is erased.
1099 std::function<void(Operation *)> opErasedCallback;
1100 };
1101
1102 //===--------------------------------------------------------------------===//
1103 // State
1104 //===--------------------------------------------------------------------===//
1105
1106 /// The rewriter that is used to perform the conversion.
1107 ConversionPatternRewriter &rewriter;
1108
1109 // Mapping between replaced values that differ in type. This happens when
1110 // replacing a value with one of a different type.
1111 ConversionValueMapping mapping;
1112
1113 /// Ordered list of block operations (creations, splits, motions).
1114 /// This vector is maintained only if `allowPatternRollback` is set to
1115 /// "true". Otherwise, all IR rewrites are materialized immediately and no
1116 /// bookkeeping is needed.
1118
1119 /// A set of operations that should no longer be considered for legalization.
1120 /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1121 /// tracked separately.
1123
1124 /// A set of operations that were replaced/erased. Such ops are not erased
1125 /// immediately but only when the dialect conversion succeeds. In the
1126 /// meantime, they should no longer be considered for legalization and any
1127 /// attempt to modify/access them is invalid rewriter API usage.
1129
1130 /// A set of operations that were created by the current pattern.
1132
1133 /// A set of operations that were modified by the current pattern.
1135
1136 /// A list of unresolved materializations that were created by the current
1137 /// pattern.
1139
1140 /// A mapping for looking up metadata of unresolved materializations.
1143
1144 /// The current type converter, or nullptr if no type converter is currently
1145 /// active.
1147
1148 /// A mapping of regions to type converters that should be used when
1149 /// converting the arguments of blocks within that region.
1151
1152 /// Dialect conversion configuration.
1153 const ConversionConfig &config;
1154
1155 /// The operation converter to use for recursive legalization.
1157
1158 /// A set of erased operations. This set is utilized only if
1159 /// `allowPatternRollback` is set to "false". Conceptually, this set is
1160 /// similar to `replacedOps` (which is maintained when the flag is set to
1161 /// "true"). However, erasing from a DenseSet is more efficient than erasing
1162 /// from a SetVector.
1164
1165 /// A set of erased blocks. This set is utilized only if
1166 /// `allowPatternRollback` is set to "false".
1168
1169 /// A rewriter that notifies the listener (if any) about all IR
1170 /// modifications. This rewriter is utilized only if `allowPatternRollback`
1171 /// is set to "false". If the flag is set to "true", the listener is notified
1172 /// with a separate mechanism (e.g., in `IRRewrite::commit`).
1174
1175#ifndef NDEBUG
1176 /// A set of replaced values. This set is for debugging purposes only and it
1177 /// is maintained only if `allowPatternRollback` is set to "true".
1179
1180 /// A set of operations that have pending updates. This tracking isn't
1181 /// strictly necessary, and is thus only active during debug builds for extra
1182 /// verification.
1184
1185 /// A raw output stream used to prefix the debug log.
1186 llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
1187 llvm::dbgs()};
1188
1189 /// A logger used to emit diagnostics during the conversion process.
1190 llvm::ScopedPrinter logger{os};
1191 std::string logPrefix;
1192#endif
1193};
1194} // namespace detail
1195} // namespace mlir
1196
1197const ConversionConfig &IRRewrite::getConfig() const {
1198 return rewriterImpl.config;
1199}
1200
1201void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1202 // Inform the listener about all IR modifications that have already taken
1203 // place: References to the original block have been replaced with the new
1204 // block.
1205 if (auto *listener =
1206 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1207 for (Operation *op : getNewBlock()->getUsers())
1208 listener->notifyOperationModified(op);
1209}
1210
1211void BlockTypeConversionRewrite::rollback() {
1212 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1213}
1214
1215/// Replace all uses of `from` with `repl`.
///
/// If `functor` is provided, a use is replaced only if the functor returns
/// "true" for it. If `repl` is not a block argument, uses that are in the same
/// block as the op defining `repl`, but not after that op, are not replaced.
1216static void
1218 function_ref<bool(OpOperand &)> functor = nullptr) {
1219 if (isa<BlockArgument>(repl)) {
1220 // `repl` is a block argument. Directly replace all uses.
1221 if (functor) {
1222 rewriter.replaceUsesWithIf(from, repl, functor);
1223 } else {
1224 rewriter.replaceAllUsesWith(from, repl);
1225 }
1226 return;
1227 }
1228
1229 // If the replacement value is an operation, only replace those uses that:
1230 // - are in a different block than the replacement operation, or
1231 // - are in the same block but after the replacement operation.
1232 //
1233 // Example:
1234 // ^bb0(%arg0: i32):
1235 // %0 = "consumer"(%arg0) : (i32) -> (i32)
1236 // "another_consumer"(%arg0) : (i32) -> ()
1237 //
1238 // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
1239 // use in "another_consumer" but not the use in "consumer". When using the
1240 // normal RewriterBase API, this would typically be done with
1241 // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
1242 // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
1243 // it cannot be supported efficiently with `allowPatternRollback` set to
1244 // "true". Therefore, the conversion driver is trying to be smart and replaces
1245 // only those uses that do not lead to a dominance violation. E.g., the
1246 // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
1247 // behavior.
1248 //
1249 // TODO: As we move more and more towards `allowPatternRollback` set to
1250 // "false", we should remove this special handling, in order to align the
1251 // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
 // The block-argument case returned early above, so `repl` is expected to be
 // defined by an operation here.
1252 Operation *replOp = repl.getDefiningOp();
1253 Block *replBlock = replOp->getBlock();
1254 rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
1255 Operation *user = operand.getOwner();
1256 bool result =
1257 user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
 // The user-provided filter is consulted only for uses that passed the
 // position check.
1258 if (result && functor)
1259 result &= functor(operand);
1260 return result;
1261 });
1262}
1263
1264void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1265 Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
1266 if (!repl)
1267 return;
1268 performReplaceValue(rewriter, value, repl);
1269}
1270
1271void ReplaceValueRewrite::rollback() {
1272 rewriterImpl.mapping.erase({value});
1273#ifndef NDEBUG
1274 rewriterImpl.replacedValues.erase(value);
1275#endif // NDEBUG
1276}
1277
1278void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1279 auto *listener =
1280 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1281
1282 // Compute replacement values.
1283 SmallVector<Value> replacements =
1284 llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1285 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1286 });
1287
1288 // Notify the listener that the operation is about to be replaced.
1289 if (listener)
1290 listener->notifyOperationReplaced(op, replacements);
1291
1292 // Replace all uses with the new values.
1293 for (auto [result, newValue] :
1294 llvm::zip_equal(op->getResults(), replacements))
1295 if (newValue)
1296 rewriter.replaceAllUsesWith(result, newValue);
1297
1298 // The original op will be erased, so remove it from the set of unlegalized
1299 // ops.
1300 if (getConfig().unlegalizedOps)
1301 getConfig().unlegalizedOps->erase(op);
1302
1303 // Notify the listener that the operation and its contents are being erased.
1304 if (listener)
1305 notifyIRErased(listener, *op);
1306
1307 // Do not erase the operation yet. It may still be referenced in `mapping`.
1308 // Just unlink it for now and erase it during cleanup.
1309 op->getBlock()->getOperations().remove(op);
1310}
1311
1312void ReplaceOperationRewrite::rollback() {
1313 for (auto result : op->getResults())
1314 rewriterImpl.mapping.erase({result});
1315}
1316
1317void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1318 rewriter.eraseOp(op);
1319}
1320
1321void CreateOperationRewrite::rollback() {
1322 for (Region &region : op->getRegions()) {
1323 while (!region.getBlocks().empty())
1324 region.getBlocks().remove(region.getBlocks().begin());
1325 }
1326 op->dropAllUses();
1327 op->erase();
1328}
1329
1330void UnresolvedMaterializationRewrite::rollback() {
1331 if (!mappedValues.empty())
1332 rewriterImpl.mapping.erase(mappedValues);
1333 rewriterImpl.unresolvedMaterializations.erase(getOperation());
1334 op->erase();
1335}
1336
1338 // Commit all rewrites. Use a new rewriter, so the modifications are not
1339 // tracked for rollback purposes etc.
1340 IRRewriter irRewriter(rewriter.getContext(), config.listener);
1341 // Note: New rewrites may be added during the "commit" phase and the
1342 // `rewrites` vector may reallocate.
 // For that reason, the loop below indexes into the vector and re-reads its
 // size on every iteration.
1343 for (size_t i = 0; i < rewrites.size(); ++i)
1344 rewrites[i]->commit(irRewriter);
1345
1346 // Clean up all rewrites.
 // The callback keeps `unresolvedMaterializations` in sync: an
 // unrealized_conversion_cast op that is erased during cleanup is removed
 // from it.
1347 SingleEraseRewriter eraseRewriter(
1348 rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
1349 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1350 unresolvedMaterializations.erase(castOp);
1351 });
1352 for (auto &rewrite : rewrites)
1353 rewrite->cleanup(eraseRewriter);
1354}
1355
1356//===----------------------------------------------------------------------===//
1357// State Management
1358//===----------------------------------------------------------------------===//
1359
1361 Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
 // Overview: Starting at `from`, repeatedly look up the next mapped values
 // until no further mapping is found. Along the way, remember the most recent
 // values that have the desired types.
1362 // Helper function that looks up a single value.
1363 auto lookup = [&](const ValueVector &values) -> ValueVector {
1364 assert(!values.empty() && "expected non-empty value vector");
1365
1366 // If the pattern rollback is enabled, use the mapping to look up the
1367 // values.
1368 if (config.allowPatternRollback)
1369 return mapping.lookup(values);
1370
1371 // Otherwise, look up values by examining the IR. All replacements have
1372 // already been materialized in IR.
1373 Operation *op = getCommonDefiningOp(values);
1374 if (!op)
1375 return {};
1376 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1377 if (!castOp)
1378 return {};
 // Only casts that are registered as unresolved materializations count.
1379 if (!this->unresolvedMaterializations.contains(castOp))
1380 return {};
 // The results of the cast must be exactly `values`. If so, the inputs of
 // the cast are returned.
1381 if (castOp.getOutputs() != values)
1382 return {};
1383 return castOp.getInputs();
1384 };
1385
1386 // Helper function that looks up each value in `values` individually and then
1387 // composes the results. If that fails, it tries to look up the entire vector
1388 // at once.
1389 auto composedLookup = [&](const ValueVector &values) -> ValueVector {
1390 // If possible, replace each value with (one or multiple) mapped values.
1391 ValueVector next;
1392 for (Value v : values) {
1393 ValueVector r = lookup({v});
1394 if (!r.empty()) {
1395 llvm::append_range(next, r);
1396 } else {
 // No mapping for this value: keep the value itself.
1397 next.push_back(v);
1398 }
1399 }
1400 if (next != values) {
1401 // At least one value was replaced.
1402 return next;
1403 }
1404
1405 // Otherwise: Check if there is a mapping for the entire vector. Such
1406 // mappings are materializations. (N:M mapping are not supported for value
1407 // replacements.)
1408 //
1409 // Note: From a correctness point of view, materializations do not have to
1410 // be stored (and looked up) in the mapping. But for performance reasons,
1411 // we choose to reuse existing IR (when possible) instead of creating it
1412 // multiple times.
1413 ValueVector r = lookup(values);
1414 if (r.empty()) {
1415 // No mapping found: The lookup stops here.
1416 return {};
1417 }
1418 return r;
1419 };
1420
1421 // Try to find the deepest values that have the desired types. If there is no
1422 // such mapping, simply return the deepest values.
 // `desiredValue`: the most recent values that matched `desiredTypes`.
 // `current`: the values reached so far.
 // `lastNonMaterialization`: the most recent values that were not a pure type
 // conversion (updated only if `skipPureTypeConversions` is set).
1423 ValueVector desiredValue;
1424 ValueVector current{from};
1425 ValueVector lastNonMaterialization{from};
1426 do {
1427 // Store the current value if the types match.
1428 bool match = TypeRange(ValueRange(current)) == desiredTypes;
1429 if (skipPureTypeConversions) {
1430 // Skip pure type conversions, if requested.
1431 bool pureConversion = isPureTypeConversion(current);
1432 match &= !pureConversion;
1433 // Keep track of the last mapped value that was not a pure type
1434 // conversion.
1435 if (!pureConversion)
1436 lastNonMaterialization = current;
1437 }
1438 if (match)
1439 desiredValue = current;
1440
1441 // Lookup next value in the mapping.
1442 ValueVector next = composedLookup(current);
1443 if (next.empty())
1444 break;
1445 current = std::move(next);
1446 } while (true);
1447
1448 // If the desired values were found use them, otherwise default to the leaf
1449 // values. (Skip pure type conversions, if requested.)
 // Note: If `desiredTypes` is non-empty and no values with these types were
 // found, the returned `desiredValue` is empty.
1450 if (!desiredTypes.empty())
1451 return desiredValue;
1452 if (skipPureTypeConversions)
1453 return lastNonMaterialization;
1454 return current;
1455}
1456
1459 TypeRange desiredTypes) const {
1460 ValueVector result = lookupOrDefault(from, desiredTypes);
 // Return an empty vector if the lookup produced `from` itself, or if
 // specific types were requested and the result does not have these types.
1461 if (result == ValueVector{from} ||
1462 (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
1463 return {};
1464 return result;
1465}
1466
 // Capture the current number of rewrites, ignored ops and replaced ops.
1468 return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1469}
1470
1472 StringRef patternName) {
1473 // Undo any rewrites.
 // Only the first `state.numRewrites` rewrites are kept.
1474 undoRewrites(state.numRewrites, patternName);
1475
1476 // Pop all of the recorded ignored operations that are no longer valid.
1477 while (ignoredOps.size() != state.numIgnoredOperations)
1478 ignoredOps.pop_back();
1479
 // Likewise, pop the replaced operations that were recorded after the saved
 // state.
1480 while (replacedOps.size() != state.numReplacedOps)
1481 replacedOps.pop_back();
1482}
1483
1484void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1485 StringRef patternName) {
1486 for (auto &rewrite :
1487 llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1488 rewrite->rollback();
1489 rewrites.resize(numRewritesToKeep);
1490}
1491
1493 StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
1494 SmallVector<ValueVector> &remapped) {
 // One entry is appended to `remapped` for each value in `values`, in order.
 // The function returns failure as soon as a type cannot be converted.
1495 remapped.reserve(llvm::size(values));
1496
1497 for (const auto &it : llvm::enumerate(values)) {
1498 Value operand = it.value();
1499 Type origType = operand.getType();
 // Prefer the explicitly provided location over the location of the value.
1500 Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1501
1502 if (!currentTypeConverter) {
1503 // The current pattern does not have a type converter. Pass the most
1504 // recently mapped values, excluding materializations. Materializations
1505 // are intentionally excluded because their presence may depend on other
1506 // patterns. Including materializations would make the lookup fragile
1507 // and unpredictable.
1508 remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
1509 /*skipPureTypeConversions=*/true));
1510 continue;
1511 }
1512
1513 // If there is no legal conversion, fail to match this pattern.
1514 SmallVector<Type, 1> legalTypes;
1515 if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
1516 notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1517 diag << "unable to convert type for " << valueDiagTag << " #"
1518 << it.index() << ", type was " << origType;
1519 });
1520 return failure();
1521 }
1522 // If a type is converted to 0 types, there is nothing to do.
 // An empty vector is recorded for this value.
1523 if (legalTypes.empty()) {
1524 remapped.push_back({});
1525 continue;
1526 }
1527
1528 ValueVector repl = lookupOrDefault(operand, legalTypes);
1529 if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1530 // Mapped values have the correct type or there is an existing
1531 // materialization. Or the operand is not mapped at all and has the
1532 // correct type.
1533 remapped.push_back(std::move(repl));
1534 continue;
1535 }
1536
1537 // Create a materialization for the most recently mapped values.
 // Pure type conversions are skipped in this lookup.
1538 repl = lookupOrDefault(operand, /*desiredTypes=*/{},
1539 /*skipPureTypeConversions=*/true);
1541 MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1542 /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1543 /*originalType=*/origType, currentTypeConverter);
1544 remapped.push_back(castValues);
1545 }
1546 return success();
1547}
1548
1550 // Check to see if this operation is ignored or was replaced. An op counts
 // as ignored if it was replaced/erased or if it is contained in `ignoredOps`.
1551 return wasOpReplaced(op) || ignoredOps.count(op);
1552}
1553
1555 // Check to see if this operation was replaced. An op counts as replaced if
 // it is contained in `replacedOps` or in `erasedOps`.
1556 return replacedOps.count(op) || erasedOps.count(op);
1557}
1558
1559//===----------------------------------------------------------------------===//
1560// Type Conversion
1561//===----------------------------------------------------------------------===//
1562
1564 Region *region, const TypeConverter &converter,
1565 TypeConverter::SignatureConversion *entryConversion) {
 // Remember the converter for this region. This happens even if the region
 // is empty.
1566 regionToConverter[region] = &converter;
 // An empty region has no blocks to convert. A null block is returned.
1567 if (region->empty())
1568 return nullptr;
1569
1570 // Convert the arguments of each non-entry block within the region.
 // Early-increment iteration is used because the block list is modified
 // while iterating: applying a signature conversion may insert a new block.
1571 for (Block &block :
1572 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1573 // Compute the signature for the block with the provided converter.
1574 std::optional<TypeConverter::SignatureConversion> conversion =
1575 converter.convertBlockSignature(&block);
1576 if (!conversion)
1577 return failure();
1578 // Convert the block with the computed signature.
1579 applySignatureConversion(&block, &converter, *conversion);
1580 }
1581
1582 // Convert the entry block. If an entry signature conversion was provided,
1583 // use that one. Otherwise, compute the signature with the type converter.
1584 if (entryConversion)
1585 return applySignatureConversion(&region->front(), &converter,
1586 *entryConversion);
1587 std::optional<TypeConverter::SignatureConversion> conversion =
1588 converter.convertBlockSignature(&region->front());
1589 if (!conversion)
1590 return failure();
 // The returned block is the result of converting the entry block.
1591 return applySignatureConversion(&region->front(), &converter, *conversion);
1592}
1593
1595 Block *block, const TypeConverter *converter,
1596 TypeConverter::SignatureConversion &signatureConversion) {
 // Overview: Create a new block with the converted argument types, move all
 // ops of `block` into it, redirect the uses of `block` and of its arguments,
 // and return the new block. `block` itself is returned if the argument types
 // are already equal to the converted types.
1597#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1598 // A block cannot be converted multiple times.
1599 if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1600 llvm::reportFatalInternalError("block was already converted");
1601#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1602
1604
1605 // If no arguments are being changed or added, there is nothing to do.
1606 unsigned origArgCount = block->getNumArguments();
1607 auto convertedTypes = signatureConversion.getConvertedTypes();
1608 if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1609 return block;
1610
1611 // Compute the locations of all block arguments in the new block.
 // All locations default to an unknown location. New arguments that an
 // original argument maps to inherit the location of that original argument.
1612 SmallVector<Location> newLocs(convertedTypes.size(),
1613 rewriter.getUnknownLoc());
1614 for (unsigned i = 0; i < origArgCount; ++i) {
1615 auto inputMap = signatureConversion.getInputMapping(i);
1616 if (!inputMap || inputMap->replacedWithValues())
1617 continue;
1618 Location origLoc = block->getArgument(i).getLoc();
1619 for (unsigned j = 0; j < inputMap->size; ++j)
1620 newLocs[inputMap->inputNo + j] = origLoc;
1621 }
1622
1623 // Insert a new block with the converted block argument types and move all ops
1624 // from the old block to the new block.
 // The new block is inserted directly after the old block.
1625 Block *newBlock =
1626 rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1627 convertedTypes, newLocs);
1628
1629 // If a listener is attached to the dialect conversion, ops cannot be moved
1630 // to the destination block in bulk ("fast path"). This is because at the time
1631 // the notifications are sent, it is unknown which ops were moved. Instead,
1632 // ops should be moved one-by-one ("slow path"), so that a separate
1633 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1634 // a bit more efficient, so we try to do that when possible.
1635 bool fastPath = !config.listener;
1636 if (fastPath) {
 // The bulk move is recorded as a rewrite only if rollback is enabled.
1637 if (config.allowPatternRollback)
1638 appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1639 newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1640 } else {
1641 while (!block->empty())
1642 rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1643 }
1644
1645 // Replace all uses of the old block with the new block.
1646 block->replaceAllUsesWith(newBlock);
1647
 // Replace the uses of each original block argument according to its input
 // mapping in the signature conversion.
1648 for (unsigned i = 0; i != origArgCount; ++i) {
1649 BlockArgument origArg = block->getArgument(i);
1650 Type origArgType = origArg.getType();
1651
1652 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1653 signatureConversion.getInputMapping(i);
1654 if (!inputMap) {
1655 // This block argument was dropped and no replacement value was provided.
1656 // Materialize a replacement value "out of thin air".
1657 // Note: Materialization must be built here because we cannot find a
1658 // valid insertion point in the new block. (Will point to the old block.)
1659 Value mat =
1661 MaterializationKind::Source,
1662 OpBuilder::InsertPoint(newBlock, newBlock->begin()),
1663 origArg.getLoc(),
1664 /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1665 /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
1666 /*isPureTypeConversion=*/false)
1667 .front();
1668 replaceValueUses(origArg, mat, converter);
1669 continue;
1670 }
1671
1672 if (inputMap->replacedWithValues()) {
1673 // This block argument was dropped and replacement values were provided.
1674 assert(inputMap->size == 0 &&
1675 "invalid to provide a replacement value when the argument isn't "
1676 "dropped");
1677 replaceValueUses(origArg, inputMap->replacementValues, converter);
1678 continue;
1679 }
1680
1681 // This is a 1->1+ mapping.
 // The replacements are the `inputMap->size` arguments of the new block,
 // starting at position `inputMap->inputNo`.
1682 auto replArgs =
1683 newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1684 replaceValueUses(origArg, replArgs, converter);
1685 }
1686
 // The block type conversion is recorded as a rewrite only if rollback is
 // enabled.
1687 if (config.allowPatternRollback)
1688 appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1689
1690 // Erase the old block. (It is just unlinked for now and will be erased during
1691 // cleanup.)
1692 rewriter.eraseBlock(block);
1693
1694 return newBlock;
1695}
1696
1697//===----------------------------------------------------------------------===//
1698// Materializations
1699//===----------------------------------------------------------------------===//
1700
1701/// Build an unresolved materialization operation given an output type and set
1702/// of input operands.
///
/// The materialization is an unrealized_conversion_cast op from `inputs` to
/// `outputTypes` that is created at the insertion point `ip`. The results of
/// that op are returned.
1704 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1705 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1706 Type originalType, const TypeConverter *converter,
1707 bool isPureTypeConversion) {
1708 assert((!originalType || kind == MaterializationKind::Target) &&
1709 "original type is valid only for target materializations");
 // The caller is expected to request a materialization only if the input
 // types differ from the output types.
1710 assert(TypeRange(inputs) != outputTypes &&
1711 "materialization is not necessary");
1712
1713 // Create an unresolved materialization. We use a new OpBuilder to avoid
1714 // tracking the materialization like we do for other operations.
1715 OpBuilder builder(outputTypes.front().getContext());
1716 builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1717 UnrealizedConversionCastOp convertOp =
1718 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
 // If requested in the configuration, attach the materialization kind
 // ("source" or "target") to the op as a "__kind__" string attribute.
1719 if (config.attachDebugMaterializationKind) {
1720 StringRef kindStr =
1721 kind == MaterializationKind::Source ? "source" : "target";
1722 convertOp->setAttr("__kind__", builder.getStringAttr(kindStr));
1723 }
1725 convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
1726
1727 // Register the materialization.
1728 unresolvedMaterializations[convertOp] =
1729 UnresolvedMaterializationInfo(converter, kind, originalType);
1730 if (config.allowPatternRollback) {
 // In rollback mode, `valuesToMap` (if non-empty) is mapped to the results
 // of the cast op.
1731 if (!valuesToMap.empty())
1732 mapping.map(valuesToMap, convertOp.getResults());
1734 std::move(valuesToMap));
1735 } else {
 // Without rollback, the cast op is recorded in `patternMaterializations`
 // instead.
1736 patternMaterializations.insert(convertOp);
1737 }
1738 return convertOp.getResults();
1739}
1740
1742 Value value, const TypeConverter *converter) {
1743 assert(config.allowPatternRollback &&
1744 "this code path is valid only in rollback mode");
1745
1746 // Try to find a replacement value with the same type in the conversion value
1747 // mapping. This includes cached materializations. We try to reuse those
1748 // instead of generating duplicate IR.
1749 ValueVector repl = lookupOrNull(value, value.getType());
1750 if (!repl.empty())
1751 return repl.front();
1752
1753 // Check if the value is dead. No replacement value is needed in that case.
1754 // This is an approximate check that may have false negatives but does not
1755 // require computing and traversing an inverse mapping. (We may end up
1756 // building source materializations that are never used and that fold away.)
1757 if (llvm::all_of(value.getUsers(),
1758 [&](Operation *op) { return replacedOps.contains(op); }) &&
1759 !mapping.isMappedTo(value))
1760 return Value();
1761
1762 // No replacement value was found. Get the latest replacement value
1763 // (regardless of the type) and build a source materialization to the
1764 // original type.
1765 repl = lookupOrNull(value);
1766
1767 // Compute the insertion point of the materialization.
1769 if (repl.empty()) {
1770 // The source materialization has no inputs. Insert it right before the
1771 // value that it is replacing.
1772 ip = computeInsertPoint(value);
1773 } else {
1774 // Compute the "earliest" insertion point at which all values in `repl` are
1775 // defined. It is important to emit the materialization at that location
1776 // because the same materialization may be reused in a different context.
1777 // (That's because materializations are cached in the conversion value
1778 // mapping.) The insertion point of the materialization must be valid for
1779 // all future users that may be created later in the conversion process.
1780 ip = computeInsertPoint(repl);
1781 }
1783 MaterializationKind::Source, ip, value.getLoc(),
1784 /*valuesToMap=*/repl, /*inputs=*/repl,
1785 /*outputTypes=*/value.getType(),
1786 /*originalType=*/Type(), converter,
1787 /*isPureTypeConversion=*/!repl.empty())
1788 .front();
1789 return castValue;
1790}
1791
1792//===----------------------------------------------------------------------===//
1793// Rewriter Notification Hooks
1794//===----------------------------------------------------------------------===//
1795
1797 Operation *op, OpBuilder::InsertPoint previous) {
1798 // If no previous insertion point is provided, the op used to be detached.
1799 bool wasDetached = !previous.isSet();
1800 LLVM_DEBUG({
1801 logger.startLine() << "** Insert : '" << op->getName() << "' (" << op
1802 << ")";
1803 if (wasDetached)
1804 logger.getOStream() << " (was detached)";
1805 logger.getOStream() << "\n";
1806 });
1807
1808 // In rollback mode, it is easier to misuse the API, so perform extra error
1809 // checking.
1810 assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
1811 "attempting to insert into a block within a replaced/erased op");
1812
1813 // In "no rollback" mode, the listener is always notified immediately.
1814 if (!config.allowPatternRollback && config.listener)
1815 config.listener->notifyOperationInserted(op, previous);
1816
1817 if (wasDetached) {
1818 // If the op was detached, it is most likely a newly created op. Add it the
1819 // set of newly created ops, so that it will be legalized. If this op is
1820 // not a newly created op, it will be legalized a second time, which is
1821 // inefficient but harmless.
1822 patternNewOps.insert(op);
1823
1824 if (config.allowPatternRollback) {
1825 // TODO: If the same op is inserted multiple times from a detached
1826 // state, the rollback mechanism may erase the same op multiple times.
1827 // This is a bug in the rollback-based dialect conversion driver.
1829 } else {
1830 // In "no rollback" mode, there is an extra data structure for tracking
1831 // erased operations that must be kept up to date.
1832 erasedOps.erase(op);
1833 }
1834 return;
1835 }
1836
1837 // The op was moved from one place to another.
1838 if (config.allowPatternRollback)
1840}
1841
1842/// Given that `fromRange` is about to be replaced with `toRange`, compute
1843/// replacement values with the types of `fromRange`.
1844static SmallVector<Value>
1846 const SmallVector<SmallVector<Value>> &toRange,
1847 const TypeConverter *converter) {
1848 assert(!impl.config.allowPatternRollback &&
1849 "this code path is valid only in 'no rollback' mode");
1850 SmallVector<Value> repls;
1851 for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1852 if (from.use_empty()) {
1853 // The replaced value is dead. No replacement value is needed.
1854 repls.push_back(Value());
1855 continue;
1856 }
1857
1858 if (to.empty()) {
1859 // The replaced value is dropped. Materialize a replacement value "out of
1860 // thin air".
1861 Value srcMat = impl.buildUnresolvedMaterialization(
1862 MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
1863 /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1864 /*outputTypes=*/from.getType(), /*originalType=*/Type(),
1865 converter)[0];
1866 repls.push_back(srcMat);
1867 continue;
1868 }
1869
1870 if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
1871 // The replacement value already has the correct type. Use it directly.
1872 repls.push_back(to[0]);
1873 continue;
1874 }
1875
1876 // The replacement value has the wrong type. Build a source materialization
1877 // to the original type.
1878 // TODO: This is a bit inefficient. We should try to reuse existing
1879 // materializations if possible. This would require an extension of the
1880 // `lookupOrDefault` API.
1881 Value srcMat = impl.buildUnresolvedMaterialization(
1882 MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
1883 /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
1884 /*originalType=*/Type(), converter)[0];
1885 repls.push_back(srcMat);
1886 }
1887
1888 return repls;
1889}
1890
1892 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1893 assert(newValues.size() == op->getNumResults() &&
1894 "incorrect number of replacement values");
1895 LLVM_DEBUG({
1896 logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
1897 << ")\n";
1899 // If the user-provided replacement types are different from the
1900 // legalized types, as per the current type converter, print a note.
1901 // In most cases, the replacement types are expected to match the types
1902 // produced by the type converter, so this could indicate a bug in the
1903 // user code.
1904 for (auto [result, repls] :
1905 llvm::zip_equal(op->getResults(), newValues)) {
1906 Type resultType = result.getType();
1907 auto logProlog = [&, repls = repls]() {
1908 logger.startLine() << " Note: Replacing op result of type "
1909 << resultType << " with value(s) of type (";
1910 llvm::interleaveComma(repls, logger.getOStream(), [&](Value v) {
1911 logger.getOStream() << v.getType();
1912 });
1913 logger.getOStream() << ")";
1914 };
1915 SmallVector<Type> convertedTypes;
1916 if (failed(currentTypeConverter->convertTypes(resultType,
1917 convertedTypes))) {
1918 logProlog();
1919 logger.getOStream() << ", but the type converter failed to legalize "
1920 "the original type.\n";
1921 continue;
1922 }
1923 if (TypeRange(convertedTypes) != TypeRange(ValueRange(repls))) {
1924 logProlog();
1925 logger.getOStream() << ", but the legalized type(s) is/are (";
1926 llvm::interleaveComma(convertedTypes, logger.getOStream(),
1927 [&](Type t) { logger.getOStream() << t; });
1928 logger.getOStream() << ")\n";
1929 }
1930 }
1931 }
1932 });
1933
1934 if (!config.allowPatternRollback) {
1935 // Pattern rollback is not allowed: materialize all IR changes immediately.
1937 *this, op->getResults(), newValues, currentTypeConverter);
1938 // Update internal data structures, so that there are no dangling pointers
1939 // to erased IR.
1940 op->walk([&](Operation *op) {
1941 erasedOps.insert(op);
1942 ignoredOps.remove(op);
1943 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1944 unresolvedMaterializations.erase(castOp);
1945 patternMaterializations.erase(castOp);
1946 }
1947 // The original op will be erased, so remove it from the set of
1948 // unlegalized ops.
1949 if (config.unlegalizedOps)
1950 config.unlegalizedOps->erase(op);
1951 });
1952 op->walk([&](Block *block) { erasedBlocks.insert(block); });
1953 // Replace the op with the replacement values and notify the listener.
1954 notifyingRewriter.replaceOp(op, repls);
1955 return;
1956 }
1957
1958 assert(!ignoredOps.contains(op) && "operation was already replaced");
1959#ifndef NDEBUG
1960 for (Value v : op->getResults())
1961 assert(!replacedValues.contains(v) &&
1962 "attempting to replace a value that was already replaced");
1963#endif // NDEBUG
1964
1965 // Check if replaced op is an unresolved materialization, i.e., an
1966 // unrealized_conversion_cast op that was created by the conversion driver.
1967 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1968 // Make sure that the user does not mess with unresolved materializations
1969 // that were inserted by the conversion driver. We keep track of these
1970 // ops in internal data structures.
1971 assert(!unresolvedMaterializations.contains(castOp) &&
1972 "attempting to replace/erase an unresolved materialization");
1973 }
1974
1975 // Create mappings for each of the new result values.
1976 for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults()))
1977 mapping.map(static_cast<Value>(result), std::move(repl));
1978
1980 // Mark this operation and all nested ops as replaced.
1981 op->walk([&](Operation *op) { replacedOps.insert(op); });
1982}
1983
1985 Value from, ValueRange to, const TypeConverter *converter,
1986 function_ref<bool(OpOperand &)> functor) {
1987 LLVM_DEBUG({
1988 logger.startLine() << "** Replace Value : '" << from << "'";
1989 if (auto blockArg = dyn_cast<BlockArgument>(from)) {
1990 if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
1991 logger.getOStream() << " (in region of '" << parentOp->getName()
1992 << "' (" << parentOp << ")";
1993 } else {
1994 logger.getOStream() << " (unlinked block)";
1995 }
1996 }
1997 if (functor) {
1998 logger.getOStream() << ", conditional replacement";
1999 }
2000 });
2001
2002 if (!config.allowPatternRollback) {
2003 SmallVector<Value> toConv = llvm::to_vector(to);
2004 SmallVector<Value> repls =
2005 getReplacementValues(*this, from, {toConv}, converter);
2006 IRRewriter r(from.getContext());
2007 Value repl = repls.front();
2008 if (!repl)
2009 return;
2010
2011 performReplaceValue(r, from, repl, functor);
2012 return;
2013 }
2014
2015#ifndef NDEBUG
2016 // Make sure that a value is not replaced multiple times. In rollback mode,
2017 // `replaceAllUsesWith` replaces not only all current uses of the given value,
2018 // but also all future uses that may be introduced by future pattern
2019 // applications. Therefore, it does not make sense to call
2020 // `replaceAllUsesWith` multiple times with the same value. Doing so would
2021 // overwrite the mapping and mess with the internal state of the dialect
2022 // conversion driver.
2023 assert(!replacedValues.contains(from) &&
2024 "attempting to replace a value that was already replaced");
2025 assert(!wasOpReplaced(from.getDefiningOp()) &&
2026 "attempting to replace a op result that was already replaced");
2027 replacedValues.insert(from);
2028#endif // NDEBUG
2029
2030 if (functor)
2031 llvm::reportFatalInternalError(
2032 "conditional value replacement is not supported in rollback mode");
2033 mapping.map(from, to);
2034 appendRewrite<ReplaceValueRewrite>(from, converter);
2035}
2036
2038 if (!config.allowPatternRollback) {
2039 // Pattern rollback is not allowed: materialize all IR changes immediately.
2040 // Update internal data structures, so that there are no dangling pointers
2041 // to erased IR.
2042 block->walk([&](Operation *op) {
2043 erasedOps.insert(op);
2044 ignoredOps.remove(op);
2045 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
2046 unresolvedMaterializations.erase(castOp);
2047 patternMaterializations.erase(castOp);
2048 }
2049 // The original op will be erased, so remove it from the set of
2050 // unlegalized ops.
2051 if (config.unlegalizedOps)
2052 config.unlegalizedOps->erase(op);
2053 });
2054 block->walk([&](Block *block) { erasedBlocks.insert(block); });
2055 // Erase the block and notify the listener.
2056 notifyingRewriter.eraseBlock(block);
2057 return;
2058 }
2059
2060 assert(!wasOpReplaced(block->getParentOp()) &&
2061 "attempting to erase a block within a replaced/erased op");
2063
2064 // Unlink the block from its parent region. The block is kept in the rewrite
2065 // object and will be actually destroyed when rewrites are applied. This
2066 // allows us to keep the operations in the block live and undo the removal by
2067 // re-inserting the block.
2068 block->getParent()->getBlocks().remove(block);
2069
2070 // Mark all nested ops as erased.
2071 block->walk([&](Operation *op) { replacedOps.insert(op); });
2072}
2073
2075 Block *block, Region *previous, Region::iterator previousIt) {
2076 // If no previous insertion point is provided, the block used to be detached.
2077 bool wasDetached = !previous;
2078 Operation *newParentOp = block->getParentOp();
2079 LLVM_DEBUG(
2080 {
2081 Operation *parent = newParentOp;
2082 if (parent) {
2083 logger.startLine() << "** Insert Block into : '" << parent->getName()
2084 << "' (" << parent << ")";
2085 } else {
2086 logger.startLine()
2087 << "** Insert Block into detached Region (nullptr parent op)";
2088 }
2089 if (wasDetached)
2090 logger.getOStream() << " (was detached)";
2091 logger.getOStream() << "\n";
2092 });
2093
2094 // In rollback mode, it is easier to misuse the API, so perform extra error
2095 // checking.
2096 assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) &&
2097 "attempting to insert into a region within a replaced/erased op");
2098 (void)newParentOp;
2099
2100 // In "no rollback" mode, the listener is always notified immediately.
2101 if (!config.allowPatternRollback && config.listener)
2102 config.listener->notifyBlockInserted(block, previous, previousIt);
2103
2104 if (wasDetached) {
2105 // If the block was detached, it is most likely a newly created block.
2106 if (config.allowPatternRollback) {
2107 // TODO: If the same block is inserted multiple times from a detached
2108 // state, the rollback mechanism may erase the same block multiple times.
2109 // This is a bug in the rollback-based dialect conversion driver.
2111 } else {
2112 // In "no rollback" mode, there is an extra data structure for tracking
2113 // erased blocks that must be kept up to date.
2114 erasedBlocks.erase(block);
2115 }
2116 return;
2117 }
2118
2119 // The block was moved from one place to another.
2120 if (config.allowPatternRollback)
2121 appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
2122}
2123
2125 Block *dest,
2126 Block::iterator before) {
2127 appendRewrite<InlineBlockRewrite>(dest, source, before);
2128}
2129
2131 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
2132 LLVM_DEBUG({
2134 reasonCallback(diag);
2135 logger.startLine() << "** Failure : " << diag.str() << "\n";
2136 if (config.notifyCallback)
2137 config.notifyCallback(diag);
2138 });
2139}
2140
2141//===----------------------------------------------------------------------===//
2142// ConversionPatternRewriter
2143//===----------------------------------------------------------------------===//
2144
2145ConversionPatternRewriter::ConversionPatternRewriter(
2146 MLIRContext *ctx, const ConversionConfig &config,
2147 OperationConverter &opConverter)
2149 *this, config, opConverter)) {
2150 setListener(impl.get());
2151}
2152
2153ConversionPatternRewriter::~ConversionPatternRewriter() = default;
2154
2155const ConversionConfig &ConversionPatternRewriter::getConfig() const {
2156 return impl->config;
2157}
2158
2159void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
2160 assert(op && newOp && "expected non-null op");
2161 replaceOp(op, newOp->getResults());
2162}
2163
2164void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
2165 assert(op->getNumResults() == newValues.size() &&
2166 "incorrect # of replacement values");
2167
2168 // If the current insertion point is before the erased operation, we adjust
2169 // the insertion point to be after the operation.
2170 if (getInsertionPoint() == op->getIterator())
2172
2173 SmallVector<SmallVector<Value>> newVals =
2174 llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2175 return v ? SmallVector<Value>{v} : SmallVector<Value>();
2176 });
2177 impl->replaceOp(op, std::move(newVals));
2178}
2179
2180void ConversionPatternRewriter::replaceOpWithMultiple(
2181 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2182 assert(op->getNumResults() == newValues.size() &&
2183 "incorrect # of replacement values");
2184
2185 // If the current insertion point is before the erased operation, we adjust
2186 // the insertion point to be after the operation.
2187 if (getInsertionPoint() == op->getIterator())
2189
2190 impl->replaceOp(op, std::move(newValues));
2191}
2192
2193void ConversionPatternRewriter::eraseOp(Operation *op) {
2194 LLVM_DEBUG({
2195 impl->logger.startLine()
2196 << "** Erase : '" << op->getName() << "'(" << op << ")\n";
2197 });
2198
2199 // If the current insertion point is before the erased operation, we adjust
2200 // the insertion point to be after the operation.
2201 if (getInsertionPoint() == op->getIterator())
2203
2204 SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
2205 impl->replaceOp(op, std::move(nullRepls));
2206}
2207
2208void ConversionPatternRewriter::eraseBlock(Block *block) {
2209 impl->eraseBlock(block);
2210}
2211
2212Block *ConversionPatternRewriter::applySignatureConversion(
2213 Block *block, TypeConverter::SignatureConversion &conversion,
2214 const TypeConverter *converter) {
2215 assert(!impl->wasOpReplaced(block->getParentOp()) &&
2216 "attempting to apply a signature conversion to a block within a "
2217 "replaced/erased op");
2218 return impl->applySignatureConversion(block, converter, conversion);
2219}
2220
2221FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
2222 Region *region, const TypeConverter &converter,
2223 TypeConverter::SignatureConversion *entryConversion) {
2224 assert(!impl->wasOpReplaced(region->getParentOp()) &&
2225 "attempting to apply a signature conversion to a block within a "
2226 "replaced/erased op");
2227 return impl->convertRegionTypes(region, converter, entryConversion);
2228}
2229
2230void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
2231 impl->replaceValueUses(from, to, impl->currentTypeConverter);
2232}
2233
2234void ConversionPatternRewriter::replaceUsesWithIf(
2235 Value from, ValueRange to, function_ref<bool(OpOperand &)> functor,
2236 bool *allUsesReplaced) {
2237 assert(!allUsesReplaced &&
2238 "allUsesReplaced is not supported in a dialect conversion");
2239 impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
2240}
2241
2242Value ConversionPatternRewriter::getRemappedValue(Value key) {
2243 SmallVector<ValueVector> remappedValues;
2244 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
2245 remappedValues)))
2246 return nullptr;
2247 assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
2248 return remappedValues.front().front();
2249}
2250
2251LogicalResult
2252ConversionPatternRewriter::getRemappedValues(ValueRange keys,
2253 SmallVectorImpl<Value> &results) {
2254 if (keys.empty())
2255 return success();
2256 SmallVector<ValueVector> remapped;
2257 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
2258 remapped)))
2259 return failure();
2260 for (const auto &values : remapped) {
2261 assert(values.size() == 1 && "1:N conversion not supported");
2262 results.push_back(values.front());
2263 }
2264 return success();
2265}
2266
2267void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
2268 Block::iterator before,
2269 ValueRange argValues) {
2270#ifndef NDEBUG
2271 assert(argValues.size() == source->getNumArguments() &&
2272 "incorrect # of argument replacement values");
2273 assert(!impl->wasOpReplaced(source->getParentOp()) &&
2274 "attempting to inline a block from a replaced/erased op");
2275 assert(!impl->wasOpReplaced(dest->getParentOp()) &&
2276 "attempting to inline a block into a replaced/erased op");
2277 auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
2278 // The source block will be deleted, so it should not have any users (i.e.,
2279 // there should be no predecessors).
2280 assert(llvm::all_of(source->getUsers(), opIgnored) &&
2281 "expected 'source' to have no predecessors");
2282#endif // NDEBUG
2283
2284 // If a listener is attached to the dialect conversion, ops cannot be moved
2285 // to the destination block in bulk ("fast path"). This is because at the time
2286 // the notifications are sent, it is unknown which ops were moved. Instead,
2287 // ops should be moved one-by-one ("slow path"), so that a separate
2288 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
2289 // a bit more efficient, so we try to do that when possible.
2290 bool fastPath = !getConfig().listener;
2291
2292 if (fastPath && impl->config.allowPatternRollback)
2293 impl->inlineBlockBefore(source, dest, before);
2294
2295 // Replace all uses of block arguments.
2296 for (auto it : llvm::zip(source->getArguments(), argValues))
2297 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2298
2299 if (fastPath) {
2300 // Move all ops at once.
2301 dest->getOperations().splice(before, source->getOperations());
2302 } else {
2303 // Move op by op.
2304 while (!source->empty())
2305 moveOpBefore(&source->front(), dest, before);
2306 }
2307
2308 // If the current insertion point is within the source block, adjust the
2309 // insertion point to the destination block.
2310 if (getInsertionBlock() == source)
2311 setInsertionPoint(dest, getInsertionPoint());
2312
2313 // Erase the source block.
2314 eraseBlock(source);
2315}
2316
2317void ConversionPatternRewriter::startOpModification(Operation *op) {
2318 if (!impl->config.allowPatternRollback) {
2319 // Pattern rollback is not allowed: no extra bookkeeping is needed.
2321 return;
2322 }
2323 assert(!impl->wasOpReplaced(op) &&
2324 "attempting to modify a replaced/erased op");
2325#ifndef NDEBUG
2326 impl->pendingRootUpdates.insert(op);
2327#endif
2328 impl->appendRewrite<ModifyOperationRewrite>(op);
2329}
2330
2331void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
2332 impl->patternModifiedOps.insert(op);
2333 if (!impl->config.allowPatternRollback) {
2335 if (getConfig().listener)
2336 getConfig().listener->notifyOperationModified(op);
2337 return;
2338 }
2339
2340 // There is nothing to do here, we only need to track the operation at the
2341 // start of the update.
2342#ifndef NDEBUG
2343 assert(!impl->wasOpReplaced(op) &&
2344 "attempting to modify a replaced/erased op");
2345 assert(impl->pendingRootUpdates.erase(op) &&
2346 "operation did not have a pending in-place update");
2347#endif
2348}
2349
2350void ConversionPatternRewriter::cancelOpModification(Operation *op) {
2351 if (!impl->config.allowPatternRollback) {
2353 return;
2354 }
2355#ifndef NDEBUG
2356 assert(impl->pendingRootUpdates.erase(op) &&
2357 "operation did not have a pending in-place update");
2358#endif
2359 // Erase the last update for this operation.
2360 auto it = llvm::find_if(
2361 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
2362 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2363 return modifyRewrite && modifyRewrite->getOperation() == op;
2364 });
2365 assert(it != impl->rewrites.rend() && "no root update started on op");
2366 (*it)->rollback();
2367 int updateIdx = std::prev(impl->rewrites.rend()) - it;
2368 impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2369}
2370
2371detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
2372 return *impl;
2373}
2374
2375//===----------------------------------------------------------------------===//
2376// ConversionPattern
2377//===----------------------------------------------------------------------===//
2378
2379FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
2380 ArrayRef<ValueRange> operands) const {
2381 SmallVector<Value> oneToOneOperands;
2382 oneToOneOperands.reserve(operands.size());
2383 for (ValueRange operand : operands) {
2384 if (operand.size() != 1)
2385 return failure();
2386
2387 oneToOneOperands.push_back(operand.front());
2388 }
2389 return std::move(oneToOneOperands);
2390}
2391
2392LogicalResult
2393ConversionPattern::matchAndRewrite(Operation *op,
2394 PatternRewriter &rewriter) const {
2395 auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
2396 auto &rewriterImpl = dialectRewriter.getImpl();
2397
2398 // Track the current conversion pattern type converter in the rewriter.
2399 llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
2400 getTypeConverter());
2401
2402 // Remap the operands of the operation.
2403 SmallVector<ValueVector> remapped;
2404 if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
2405 op->getOperands(), remapped))) {
2406 return failure();
2407 }
2408 SmallVector<ValueRange> remappedAsRange =
2409 llvm::to_vector_of<ValueRange>(remapped);
2410 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2411}
2412
2413//===----------------------------------------------------------------------===//
2414// OperationLegalizer
2415//===----------------------------------------------------------------------===//
2416
2417namespace {
2418/// A set of rewrite patterns that can be used to legalize a given operation.
2419using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2420
2421/// This class defines a recursive operation legalizer.
2422class OperationLegalizer {
2423public:
2424 using LegalizationAction = ConversionTarget::LegalizationAction;
2425
2426 OperationLegalizer(ConversionPatternRewriter &rewriter,
2427 const ConversionTarget &targetInfo,
2428 const FrozenRewritePatternSet &patterns);
2429
2430 /// Returns true if the given operation is known to be illegal on the target.
2431 bool isIllegal(Operation *op) const;
2432
2433 /// Attempt to legalize the given operation. Returns success if the operation
2434 /// was legalized, failure otherwise.
2435 LogicalResult legalize(Operation *op);
2436
2437 /// Returns the conversion target in use by the legalizer.
2438 const ConversionTarget &getTarget() { return target; }
2439
2440private:
2441 /// Attempt to legalize the given operation by folding it.
2442 LogicalResult legalizeWithFold(Operation *op);
2443
2444 /// Attempt to legalize the given operation by applying a pattern. Returns
2445 /// success if the operation was legalized, failure otherwise.
2446 LogicalResult legalizeWithPattern(Operation *op);
2447
2448 /// Return true if the given pattern may be applied to the given operation,
2449 /// false otherwise.
2450 bool canApplyPattern(Operation *op, const Pattern &pattern);
2451
2452 /// Legalize the resultant IR after successfully applying the given pattern.
2453 LogicalResult
2454 legalizePatternResult(Operation *op, const Pattern &pattern,
2455 const RewriterState &curState,
2456 const SetVector<Operation *> &newOps,
2457 const SetVector<Operation *> &modifiedOps);
2458
2459 LogicalResult
2460 legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
2461 LogicalResult
2462 legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
2463
2464 //===--------------------------------------------------------------------===//
2465 // Cost Model
2466 //===--------------------------------------------------------------------===//
2467
2468 /// Build an optimistic legalization graph given the provided patterns. This
2469 /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
2470 /// patterns for operations that are not directly legal, but may be
2471 /// transitively legal for the current target given the provided patterns.
2472 void buildLegalizationGraph(
2473 LegalizationPatterns &anyOpLegalizerPatterns,
2475
2476 /// Compute the benefit of each node within the computed legalization graph.
2477 /// This orders the patterns within 'legalizerPatterns' based upon two
2478 /// criteria:
2479 /// 1) Prefer patterns that have the lowest legalization depth, i.e.
2480 /// represent the more direct mapping to the target.
2481 /// 2) When comparing patterns with the same legalization depth, prefer the
2482 /// pattern with the highest PatternBenefit. This allows for users to
2483 /// prefer specific legalizations over others.
2484 void computeLegalizationGraphBenefit(
2485 LegalizationPatterns &anyOpLegalizerPatterns,
2487
2488 /// Compute the legalization depth when legalizing an operation of the given
2489 /// type.
2490 unsigned computeOpLegalizationDepth(
2491 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2493
2494 /// Apply the conversion cost model to the given set of patterns, and return
2495 /// the smallest legalization depth of any of the patterns. See
2496 /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
2497 unsigned applyCostModelToPatterns(
2498 LegalizationPatterns &patterns,
2499 DenseMap<OperationName, unsigned> &minOpPatternDepth,
2501
2502 /// The current set of patterns that have been applied.
2503 SmallPtrSet<const Pattern *, 8> appliedPatterns;
2504
2505 /// The rewriter to use when converting operations.
2506 ConversionPatternRewriter &rewriter;
2507
2508 /// The legalization information provided by the target.
2509 const ConversionTarget &target;
2510
2511 /// The pattern applicator to use for conversions.
2512 PatternApplicator applicator;
2513};
2514} // namespace
2515
2516OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2517 const ConversionTarget &targetInfo,
2518 const FrozenRewritePatternSet &patterns)
2519 : rewriter(rewriter), target(targetInfo), applicator(patterns) {
2520 // The set of patterns that can be applied to illegal operations to transform
2521 // them into legal ones.
2523 LegalizationPatterns anyOpLegalizerPatterns;
2524
2525 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2526 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2527}
2528
2529bool OperationLegalizer::isIllegal(Operation *op) const {
2530 return target.isIllegal(op);
2531}
2532
2533LogicalResult OperationLegalizer::legalize(Operation *op) {
2534#ifndef NDEBUG
2535 const char *logLineComment =
2536 "//===-------------------------------------------===//\n";
2537
2538 auto &logger = rewriter.getImpl().logger;
2539#endif
2540
2541 // Check to see if the operation is ignored and doesn't need to be converted.
2542 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2543
2544 LLVM_DEBUG({
2545 logger.getOStream() << "\n";
2546 logger.startLine() << logLineComment;
2547 logger.startLine() << "Legalizing operation : ";
2548 // Do not print the operation name if the operation is ignored. Ignored ops
2549 // may have been erased and should not be accessed. The pointer can be
2550 // printed safely.
2551 if (!isIgnored)
2552 logger.getOStream() << "'" << op->getName() << "' ";
2553 logger.getOStream() << "(" << op << ") {\n";
2554 logger.indent();
2555
2556 // If the operation has no regions, just print it here.
2557 if (!isIgnored && op->getNumRegions() == 0) {
2558 logger.startLine() << OpWithFlags(op,
2559 OpPrintingFlags().printGenericOpForm())
2560 << "\n";
2561 }
2562 });
2563
2564 if (isIgnored) {
2565 LLVM_DEBUG({
2566 logSuccess(logger, "operation marked 'ignored' during conversion");
2567 logger.startLine() << logLineComment;
2568 });
2569 return success();
2570 }
2571
2572 // Check if this operation is legal on the target.
2573 if (auto legalityInfo = target.isLegal(op)) {
2574 LLVM_DEBUG({
2575 logSuccess(
2576 logger, "operation marked legal by the target{0}",
2577 legalityInfo->isRecursivelyLegal
2578 ? "; NOTE: operation is recursively legal; skipping internals"
2579 : "");
2580 logger.startLine() << logLineComment;
2581 });
2582
2583 // If this operation is recursively legal, mark its children as ignored so
2584 // that we don't consider them for legalization.
2585 if (legalityInfo->isRecursivelyLegal) {
2586 op->walk([&](Operation *nested) {
2587 if (op != nested)
2588 rewriter.getImpl().ignoredOps.insert(nested);
2589 });
2590 }
2591
2592 return success();
2593 }
2594
2595 // If the operation is not legal, try to fold it in-place if the folding mode
2596 // is 'BeforePatterns'. 'Never' will skip this.
2597 const ConversionConfig &config = rewriter.getConfig();
2598 if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2599 if (succeeded(legalizeWithFold(op))) {
2600 LLVM_DEBUG({
2601 logSuccess(logger, "operation was folded");
2602 logger.startLine() << logLineComment;
2603 });
2604 return success();
2605 }
2606 }
2607
2608 // Otherwise, we need to apply a legalization pattern to this operation.
2609 if (succeeded(legalizeWithPattern(op))) {
2610 LLVM_DEBUG({
2611 logSuccess(logger, "");
2612 logger.startLine() << logLineComment;
2613 });
2614 return success();
2615 }
2616
2617 // If the operation can't be legalized via patterns, try to fold it in-place
2618 // if the folding mode is 'AfterPatterns'.
2619 if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2620 if (succeeded(legalizeWithFold(op))) {
2621 LLVM_DEBUG({
2622 logSuccess(logger, "operation was folded");
2623 logger.startLine() << logLineComment;
2624 });
2625 return success();
2626 }
2627 }
2628
2629 LLVM_DEBUG({
2630 logFailure(logger, "no matched legalization pattern");
2631 logger.startLine() << logLineComment;
2632 });
2633 return failure();
2634}
2635
/// Moves the contents out of `obj` and returns them. Afterwards `obj` is
/// assigned a value-initialized `T`, so it is left in a valid, empty state
/// rather than in a moved-from state.
template <typename T>
static T moveAndReset(T &obj) {
  T taken(std::move(obj));
  obj = T();
  return taken;
}
2644
2645LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2646 auto &rewriterImpl = rewriter.getImpl();
2647 LLVM_DEBUG({
2648 rewriterImpl.logger.startLine() << "* Fold {\n";
2649 rewriterImpl.logger.indent();
2650 });
2651
2652 // Clear pattern state, so that the next pattern application starts with a
2653 // clean slate. (The op/block sets are populated by listener notifications.)
2654 llvm::scope_exit cleanup([&]() {
2655 rewriterImpl.patternNewOps.clear();
2656 rewriterImpl.patternModifiedOps.clear();
2657 });
2658
2659 // Upon failure, undo all changes made by the folder.
2660 RewriterState curState = rewriterImpl.getCurrentState();
2661
2662 // Try to fold the operation.
2663 StringRef opName = op->getName().getStringRef();
2664 SmallVector<Value, 2> replacementValues;
2665 SmallVector<Operation *, 2> newOps;
2666 rewriter.setInsertionPoint(op);
2667 rewriter.startOpModification(op);
2668 if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
2669 LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2670 rewriter.cancelOpModification(op);
2671 return failure();
2672 }
2673 rewriter.finalizeOpModification(op);
2674
2675 // An empty list of replacement values indicates that the fold was in-place.
2676 // As the operation changed, a new legalization needs to be attempted.
2677 if (replacementValues.empty())
2678 return legalize(op);
2679
2680 // Insert a replacement for 'op' with the folded replacement values.
2681 rewriter.replaceOp(op, replacementValues);
2682
2683 // Recursively legalize any new constant operations.
2684 for (Operation *newOp : newOps) {
2685 if (failed(legalize(newOp))) {
2686 LLVM_DEBUG(logFailure(rewriterImpl.logger,
2687 "failed to legalize generated constant '{0}'",
2688 newOp->getName()));
2689 if (!rewriter.getConfig().allowPatternRollback) {
2690 // Rolling back a folder is like rolling back a pattern.
2691 llvm::reportFatalInternalError(
2692 "op '" + opName +
2693 "' folder rollback of IR modifications requested");
2694 }
2695 rewriterImpl.resetState(
2696 curState, std::string(op->getName().getStringRef()) + " folder");
2697 return failure();
2698 }
2699 }
2700
2701 LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2702 return success();
2703}
2704
2705/// Report a fatal error indicating that newly produced or modified IR could
2706/// not be legalized.
2707static void
2709 const SetVector<Operation *> &newOps,
2710 const SetVector<Operation *> &modifiedOps) {
2711 auto newOpNames = llvm::map_range(
2712 newOps, [](Operation *op) { return op->getName().getStringRef(); });
2713 auto modifiedOpNames = llvm::map_range(
2714 modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
2715 llvm::reportFatalInternalError("pattern '" + pattern.getDebugName() +
2716 "' produced IR that could not be legalized. " +
2717 "new ops: {" + llvm::join(newOpNames, ", ") +
2718 "}, " + "modified ops: {" +
2719 llvm::join(modifiedOpNames, ", ") + "}");
2720}
2721
2722LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2723 auto &rewriterImpl = rewriter.getImpl();
2724 const ConversionConfig &config = rewriter.getConfig();
2725
2726#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2727 Operation *checkOp;
2728 std::optional<OperationFingerPrint> topLevelFingerPrint;
2729 if (!rewriterImpl.config.allowPatternRollback) {
2730 // The op may be getting erased, so we have to check the parent op.
2731 // (In rare cases, a pattern may even erase the parent op, which will cause
2732 // a crash here. Expensive checks are "best effort".) Skip the check if the
2733 // op does not have a parent op.
2734 if ((checkOp = op->getParentOp())) {
2735 if (!op->getContext()->isMultithreadingEnabled()) {
2736 topLevelFingerPrint = OperationFingerPrint(checkOp);
2737 } else {
2738 // Another thread may be modifying a sibling operation. Therefore, the
2739 // fingerprinting mechanism of the parent op works only in
2740 // single-threaded mode.
2741 LLVM_DEBUG({
2742 rewriterImpl.logger.startLine()
2743 << "WARNING: Multi-threadeding is enabled. Some dialect "
2744 "conversion expensive checks are skipped in multithreading "
2745 "mode!\n";
2746 });
2747 }
2748 }
2749 }
2750#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2751
2752 // Functor that returns if the given pattern may be applied.
2753 auto canApply = [&](const Pattern &pattern) {
2754 bool canApply = canApplyPattern(op, pattern);
2755 if (canApply && config.listener)
2756 config.listener->notifyPatternBegin(pattern, op);
2757 return canApply;
2758 };
2759
2760 // Functor that cleans up the rewriter state after a pattern failed to match.
2761 RewriterState curState = rewriterImpl.getCurrentState();
2762 auto onFailure = [&](const Pattern &pattern) {
2763 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2764 if (!rewriterImpl.config.allowPatternRollback) {
2765 // Erase all unresolved materializations.
2766 for (auto op : rewriterImpl.patternMaterializations) {
2767 rewriterImpl.unresolvedMaterializations.erase(op);
2768 op.erase();
2769 }
2770 rewriterImpl.patternMaterializations.clear();
2771#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2772 // Expensive pattern check that can detect API violations.
2773 if (checkOp && topLevelFingerPrint) {
2774 OperationFingerPrint fingerPrintAfterPattern(checkOp);
2775 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2776 llvm::reportFatalInternalError(
2777 "pattern '" + pattern.getDebugName() +
2778 "' returned failure but IR did change");
2779 }
2780#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2781 }
2782 rewriterImpl.patternNewOps.clear();
2783 rewriterImpl.patternModifiedOps.clear();
2784 LLVM_DEBUG({
2785 logFailure(rewriterImpl.logger, "pattern failed to match");
2786 if (rewriterImpl.config.notifyCallback) {
2787 Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
2788 diag << "Failed to apply pattern \"" << pattern.getDebugName()
2789 << "\" on op:\n"
2790 << *op;
2791 rewriterImpl.config.notifyCallback(diag);
2792 }
2793 });
2794 if (config.listener)
2795 config.listener->notifyPatternEnd(pattern, failure());
2796 rewriterImpl.resetState(curState, pattern.getDebugName());
2797 appliedPatterns.erase(&pattern);
2798 };
2799
2800 // Functor that performs additional legalization when a pattern is
2801 // successfully applied.
2802 auto onSuccess = [&](const Pattern &pattern) {
2803 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2804 if (!rewriterImpl.config.allowPatternRollback) {
2805 // Eagerly erase unused materializations.
2806 for (auto op : rewriterImpl.patternMaterializations) {
2807 if (op->use_empty()) {
2808 rewriterImpl.unresolvedMaterializations.erase(op);
2809 op.erase();
2810 }
2811 }
2812 rewriterImpl.patternMaterializations.clear();
2813 }
2814 SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
2815 SetVector<Operation *> modifiedOps =
2816 moveAndReset(rewriterImpl.patternModifiedOps);
2817 auto result =
2818 legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
2819 appliedPatterns.erase(&pattern);
2820 if (failed(result)) {
2821 if (!rewriterImpl.config.allowPatternRollback)
2822 reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);
2823 rewriterImpl.resetState(curState, pattern.getDebugName());
2824 }
2825 if (config.listener)
2826 config.listener->notifyPatternEnd(pattern, result);
2827 return result;
2828 };
2829
2830 // Try to match and rewrite a pattern on this operation.
2831 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2832 onSuccess);
2833}
2834
2835bool OperationLegalizer::canApplyPattern(Operation *op,
2836 const Pattern &pattern) {
2837 LLVM_DEBUG({
2838 auto &os = rewriter.getImpl().logger;
2839 os.getOStream() << "\n";
2840 os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2841 llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2842 os.getOStream() << ")' {\n";
2843 os.indent();
2844 });
2845
2846 // Ensure that we don't cycle by not allowing the same pattern to be
2847 // applied twice in the same recursion stack if it is not known to be safe.
2848 if (!pattern.hasBoundedRewriteRecursion() &&
2849 !appliedPatterns.insert(&pattern).second) {
2850 LLVM_DEBUG(
2851 logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2852 return false;
2853 }
2854 return true;
2855}
2856
2857LogicalResult OperationLegalizer::legalizePatternResult(
2858 Operation *op, const Pattern &pattern, const RewriterState &curState,
2859 const SetVector<Operation *> &newOps,
2860 const SetVector<Operation *> &modifiedOps) {
2861 [[maybe_unused]] auto &impl = rewriter.getImpl();
2862 assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2863
2864#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2865 if (impl.config.allowPatternRollback) {
2866 // Check that the root was either replaced or updated in place.
2867 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2868 auto replacedRoot = [&] {
2869 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2870 };
2871 auto updatedRootInPlace = [&] {
2872 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2873 };
2874 if (!replacedRoot() && !updatedRootInPlace())
2875 llvm::reportFatalInternalError(
2876 "expected pattern to replace the root operation "
2877 "or modify it in place");
2878 }
2879#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2880
2881 // Legalize each of the actions registered during application.
2882 if (failed(legalizePatternRootUpdates(modifiedOps)) ||
2883 failed(legalizePatternCreatedOperations(newOps))) {
2884 return failure();
2885 }
2886
2887 LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2888 return success();
2889}
2890
2891LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2892 const SetVector<Operation *> &newOps) {
2893 for (Operation *op : newOps) {
2894 if (failed(legalize(op))) {
2895 LLVM_DEBUG(logFailure(rewriter.getImpl().logger,
2896 "failed to legalize generated operation '{0}'({1})",
2897 op->getName(), op));
2898 return failure();
2899 }
2900 }
2901 return success();
2902}
2903
2904LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2905 const SetVector<Operation *> &modifiedOps) {
2906 for (Operation *op : modifiedOps) {
2907 if (failed(legalize(op))) {
2908 LLVM_DEBUG(
2909 logFailure(rewriter.getImpl().logger,
2910 "failed to legalize operation updated in-place '{0}'",
2911 op->getName()));
2912 return failure();
2913 }
2914 }
2915 return success();
2916}
2917
2918//===----------------------------------------------------------------------===//
2919// Cost Model
2920//===----------------------------------------------------------------------===//
2921
2922void OperationLegalizer::buildLegalizationGraph(
2923 LegalizationPatterns &anyOpLegalizerPatterns,
2925 // A mapping between an operation and a set of operations that can be used to
2926 // generate it.
2928 // A mapping between an operation and any currently invalid patterns it has.
2930 // A worklist of patterns to consider for legality.
2931 SetVector<const Pattern *> patternWorklist;
2932
2933 // Build the mapping from operations to the parent ops that may generate them.
2934 applicator.walkAllPatterns([&](const Pattern &pattern) {
2935 std::optional<OperationName> root = pattern.getRootKind();
2936
2937 // If the pattern has no specific root, we can't analyze the relationship
2938 // between the root op and generated operations. Given that, add all such
2939 // patterns to the legalization set.
2940 if (!root) {
2941 anyOpLegalizerPatterns.push_back(&pattern);
2942 return;
2943 }
2944
2945 // Skip operations that are always known to be legal.
2946 if (target.getOpAction(*root) == LegalizationAction::Legal)
2947 return;
2948
2949 // Add this pattern to the invalid set for the root op and record this root
2950 // as a parent for any generated operations.
2951 invalidPatterns[*root].insert(&pattern);
2952 for (auto op : pattern.getGeneratedOps())
2953 parentOps[op].insert(*root);
2954
2955 // Add this pattern to the worklist.
2956 patternWorklist.insert(&pattern);
2957 });
2958
2959 // If there are any patterns that don't have a specific root kind, we can't
2960 // make direct assumptions about what operations will never be legalized.
2961 // Note: Technically we could, but it would require an analysis that may
2962 // recurse into itself. It would be better to perform this kind of filtering
2963 // at a higher level than here anyways.
2964 if (!anyOpLegalizerPatterns.empty()) {
2965 for (const Pattern *pattern : patternWorklist)
2966 legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2967 return;
2968 }
2969
2970 while (!patternWorklist.empty()) {
2971 auto *pattern = patternWorklist.pop_back_val();
2972
2973 // Check to see if any of the generated operations are invalid.
2974 if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2975 std::optional<LegalizationAction> action = target.getOpAction(op);
2976 return !legalizerPatterns.count(op) &&
2977 (!action || action == LegalizationAction::Illegal);
2978 }))
2979 continue;
2980
2981 // Otherwise, if all of the generated operation are valid, this op is now
2982 // legal so add all of the child patterns to the worklist.
2983 legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2984 invalidPatterns[*pattern->getRootKind()].erase(pattern);
2985
2986 // Add any invalid patterns of the parent operations to see if they have now
2987 // become legal.
2988 for (auto op : parentOps[*pattern->getRootKind()])
2989 patternWorklist.set_union(invalidPatterns[op]);
2990 }
2991}
2992
2993void OperationLegalizer::computeLegalizationGraphBenefit(
2994 LegalizationPatterns &anyOpLegalizerPatterns,
2996 // The smallest pattern depth, when legalizing an operation.
2997 DenseMap<OperationName, unsigned> minOpPatternDepth;
2998
2999 // For each operation that is transitively legal, compute a cost for it.
3000 for (auto &opIt : legalizerPatterns)
3001 if (!minOpPatternDepth.count(opIt.first))
3002 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
3003 legalizerPatterns);
3004
3005 // Apply the cost model to the patterns that can match any operation. Those
3006 // with a specific operation type are already resolved when computing the op
3007 // legalization depth.
3008 if (!anyOpLegalizerPatterns.empty())
3009 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
3010 legalizerPatterns);
3011
3012 // Apply a cost model to the pattern applicator. We order patterns first by
3013 // depth then benefit. `legalizerPatterns` contains per-op patterns by
3014 // decreasing benefit.
3015 applicator.applyCostModel([&](const Pattern &pattern) {
3016 ArrayRef<const Pattern *> orderedPatternList;
3017 if (std::optional<OperationName> rootName = pattern.getRootKind())
3018 orderedPatternList = legalizerPatterns[*rootName];
3019 else
3020 orderedPatternList = anyOpLegalizerPatterns;
3021
3022 // If the pattern is not found, then it was removed and cannot be matched.
3023 auto *it = llvm::find(orderedPatternList, &pattern);
3024 if (it == orderedPatternList.end())
3026
3027 // Patterns found earlier in the list have higher benefit.
3028 return PatternBenefit(std::distance(it, orderedPatternList.end()));
3029 });
3030}
3031
3032unsigned OperationLegalizer::computeOpLegalizationDepth(
3033 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
3035 // Check for existing depth.
3036 auto depthIt = minOpPatternDepth.find(op);
3037 if (depthIt != minOpPatternDepth.end())
3038 return depthIt->second;
3039
3040 // If a mapping for this operation does not exist, then this operation
3041 // is always legal. Return 0 as the depth for a directly legal operation.
3042 auto opPatternsIt = legalizerPatterns.find(op);
3043 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3044 return 0u;
3045
3046 // Record this initial depth in case we encounter this op again when
3047 // recursively computing the depth.
3048 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
3049
3050 // Apply the cost model to the operation patterns, and update the minimum
3051 // depth.
3052 unsigned minDepth = applyCostModelToPatterns(
3053 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3054 minOpPatternDepth[op] = minDepth;
3055 return minDepth;
3056}
3057
3058unsigned OperationLegalizer::applyCostModelToPatterns(
3059 LegalizationPatterns &patterns,
3060 DenseMap<OperationName, unsigned> &minOpPatternDepth,
3062 unsigned minDepth = std::numeric_limits<unsigned>::max();
3063
3064 // Compute the depth for each pattern within the set.
3065 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3066 patternsByDepth.reserve(patterns.size());
3067 for (const Pattern *pattern : patterns) {
3068 unsigned depth = 1;
3069 for (auto generatedOp : pattern->getGeneratedOps()) {
3070 unsigned generatedOpDepth = computeOpLegalizationDepth(
3071 generatedOp, minOpPatternDepth, legalizerPatterns);
3072 depth = std::max(depth, generatedOpDepth + 1);
3073 }
3074 patternsByDepth.emplace_back(pattern, depth);
3075
3076 // Update the minimum depth of the pattern list.
3077 minDepth = std::min(minDepth, depth);
3078 }
3079
3080 // If the operation only has one legalization pattern, there is no need to
3081 // sort them.
3082 if (patternsByDepth.size() == 1)
3083 return minDepth;
3084
3085 // Sort the patterns by those likely to be the most beneficial.
3086 llvm::stable_sort(patternsByDepth,
3087 [](const std::pair<const Pattern *, unsigned> &lhs,
3088 const std::pair<const Pattern *, unsigned> &rhs) {
3089 // First sort by the smaller pattern legalization
3090 // depth.
3091 if (lhs.second != rhs.second)
3092 return lhs.second < rhs.second;
3093
3094 // Then sort by the larger pattern benefit.
3095 auto lhsBenefit = lhs.first->getBenefit();
3096 auto rhsBenefit = rhs.first->getBenefit();
3097 return lhsBenefit > rhsBenefit;
3098 });
3099
3100 // Update the legalization pattern to use the new sorted list.
3101 patterns.clear();
3102 for (auto &patternIt : patternsByDepth)
3103 patterns.push_back(patternIt.first);
3104 return minDepth;
3105}
3106
3107//===----------------------------------------------------------------------===//
3108// Reconcile Unrealized Casts
3109//===----------------------------------------------------------------------===//
3110
3111/// Try to reconcile all given UnrealizedConversionCastOps and store the
3112/// left-over ops in `remainingCastOps` (if provided). See documentation in
3113/// DialectConversion.h for more details.
3114/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the
3115/// algorithm may visit an operand (or user) which is a cast op, but will not
3116/// try to reconcile it if not in the filtered set.
3117template <typename RangeT>
3119 RangeT castOps,
3120 function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3122 // A worklist of cast ops to process.
3123 SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
3124
3125 // Helper function that return the unrealized_conversion_cast op that
3126 // defines all inputs of the given op (in the same order). Return "nullptr"
3127 // if there is no such op.
3128 auto getInputCast =
3129 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3130 if (castOp.getInputs().empty())
3131 return {};
3132 auto inputCastOp =
3133 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3134 if (!inputCastOp)
3135 return {};
3136 if (inputCastOp.getOutputs() != castOp.getInputs())
3137 return {};
3138 return inputCastOp;
3139 };
3140
3141 // Process ops in the worklist bottom-to-top.
3142 while (!worklist.empty()) {
3143 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3144
3145 // Traverse the chain of input cast ops to see if an op with the same
3146 // input types can be found.
3147 UnrealizedConversionCastOp nextCast = castOp;
3148 while (nextCast) {
3149 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3150 if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
3151 return v.getDefiningOp() == castOp;
3152 })) {
3153 // Ran into a cycle.
3154 break;
3155 }
3156
3157 // Found a cast where the input types match the output types of the
3158 // matched op. We can directly use those inputs.
3159 castOp.replaceAllUsesWith(nextCast.getInputs());
3160 break;
3161 }
3162 nextCast = getInputCast(nextCast);
3163 }
3164 }
3165
3166 // A set of all alive cast ops. I.e., ops whose results are (transitively)
3167 // used by an op that is not a cast op.
3168 DenseSet<Operation *> liveOps;
3169
3170 // Helper function that marks the given op and transitively reachable input
3171 // cast ops as alive.
3172 auto markOpLive = [&](Operation *rootOp) {
3173 SmallVector<Operation *> worklist;
3174 worklist.push_back(rootOp);
3175 while (!worklist.empty()) {
3176 Operation *op = worklist.pop_back_val();
3177 if (liveOps.insert(op).second) {
3178 // Successfully inserted: process reachable input cast ops.
3179 for (Value v : op->getOperands())
3180 if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
3181 if (isCastOpOfInterestFn(castOp))
3182 worklist.push_back(castOp);
3183 }
3184 }
3185 };
3186
3187 // Find all alive cast ops.
3188 for (UnrealizedConversionCastOp op : castOps) {
3189 // The op may have been marked live already as being an operand of another
3190 // live cast op.
3191 if (liveOps.contains(op.getOperation()))
3192 continue;
3193 // If any of the users is not a cast op, mark the current op (and its
3194 // input ops) as live.
3195 if (llvm::any_of(op->getUsers(), [&](Operation *user) {
3196 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3197 return !castOp || !isCastOpOfInterestFn(castOp);
3198 }))
3199 markOpLive(op);
3200 }
3201
3202 // Erase all dead cast ops.
3203 for (UnrealizedConversionCastOp op : castOps) {
3204 if (liveOps.contains(op)) {
3205 // Op is alive and was not erased. Add it to the remaining cast ops.
3206 if (remainingCastOps)
3207 remainingCastOps->push_back(op);
3208 continue;
3209 }
3210
3211 // Op is dead. Erase it.
3212 op->dropAllUses();
3213 op->erase();
3214 }
3215}
3216
3218 ArrayRef<UnrealizedConversionCastOp> castOps,
3219 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3220 // Set of all cast ops for faster lookups.
3221 DenseSet<UnrealizedConversionCastOp> castOpSet;
3222 for (UnrealizedConversionCastOp op : castOps)
3223 castOpSet.insert(op);
3224 reconcileUnrealizedCasts(castOpSet, remainingCastOps);
3225}
3226
3228 const DenseSet<UnrealizedConversionCastOp> &castOps,
3229 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3231 llvm::make_range(castOps.begin(), castOps.end()),
3232 [&](UnrealizedConversionCastOp castOp) {
3233 return castOps.contains(castOp);
3234 },
3235 remainingCastOps);
3236}
3237
3238namespace mlir {
3241 &castOps,
3244 castOps.keys(),
3245 [&](UnrealizedConversionCastOp castOp) {
3246 return castOps.contains(castOp);
3247 },
3248 remainingCastOps);
3249}
3250} // namespace mlir
3251
3252//===----------------------------------------------------------------------===//
3253// OperationConverter
3254//===----------------------------------------------------------------------===//
3255
3256namespace mlir {
3257// This class converts operations to a given conversion target via a set of
3258// rewrite patterns. The conversion behaves differently depending on the
3259// conversion mode.
3262 const FrozenRewritePatternSet &patterns,
3263 const ConversionConfig &config,
3264 OpConversionMode mode)
3265 : rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
3266 mode(mode) {}
3267
3268 /// Applies the conversion to the given operations (and their nested
3269 /// operations).
3270 LogicalResult applyConversion(ArrayRef<Operation *> ops);
3271
3272 /// Legalizes the given operations (and their nested operations) to the
3273 /// conversion target.
3274 template <typename Fn>
3275 LogicalResult legalizeOperations(ArrayRef<Operation *> ops, Fn onFailure,
3276 bool isRecursiveLegalization = false);
3278 bool isRecursiveLegalization = false) {
3279 return legalizeOperations(
3280 ops, /*onFailure=*/[&]() {}, isRecursiveLegalization);
3281 }
3282
3283 /// Converts a single operation. If `isRecursiveLegalization` is "true", the
3284 /// conversion is a recursive legalization request, triggered from within a
3285 /// pattern. In that case, do not emit errors because there will be another
3286 /// attempt at legalizing the operation later (via the regular pre-order
3287 /// legalization mechanism).
3288 LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);
3289
3290 const ConversionTarget &getTarget() { return opLegalizer.getTarget(); }
3291
3292private:
3293 /// The rewriter to use when converting operations.
3294 ConversionPatternRewriter rewriter;
3295
3296 /// The legalizer to use when converting operations.
3297 OperationLegalizer opLegalizer;
3298
3299 /// The conversion mode to use when legalizing operations.
3300 OpConversionMode mode;
3301};
3302} // namespace mlir
3303
3305 bool isRecursiveLegalization) {
3306 const ConversionConfig &config = rewriter.getConfig();
3307 auto emitFailedToLegalizeDiag = [&](bool wasExplicitlyIllegal) {
3309 << "failed to legalize operation '"
3310 << op->getName() << "'";
3311 if (wasExplicitlyIllegal)
3312 diag << " that was explicitly marked illegal";
3313 diag << ": " << OpWithFlags(op, OpPrintingFlags().skipRegions());
3314 };
3315
3316 // Legalize the given operation.
3317 if (failed(opLegalizer.legalize(op))) {
3318 // Handle the case of a failed conversion for each of the different modes.
3319 // Full conversions expect all operations to be converted.
3320 if (mode == OpConversionMode::Full) {
3321 if (!isRecursiveLegalization)
3322 emitFailedToLegalizeDiag(/*wasExplicitlyIllegal=*/false);
3323 return failure();
3324 }
3325 // Partial conversions allow conversions to fail iff the operation was not
3326 // explicitly marked as illegal. If the user provided a `unlegalizedOps`
3327 // set, non-legalizable ops are added to that set.
3328 if (mode == OpConversionMode::Partial) {
3329 if (opLegalizer.isIllegal(op)) {
3330 if (!isRecursiveLegalization)
3331 emitFailedToLegalizeDiag(/*wasExplicitlyIllegal=*/true);
3332 return failure();
3333 }
3334 if (config.unlegalizedOps && !isRecursiveLegalization)
3335 config.unlegalizedOps->insert(op);
3336 }
3337 } else if (mode == OpConversionMode::Analysis) {
3338 // Analysis conversions don't fail if any operations fail to legalize,
3339 // they are only interested in the operations that were successfully
3340 // legalized.
3341 if (config.legalizableOps && !isRecursiveLegalization)
3342 config.legalizableOps->insert(op);
3343 }
3344 return success();
3345}
3346
3347static LogicalResult
3349 UnrealizedConversionCastOp op,
3350 const UnresolvedMaterializationInfo &info) {
3351 assert(!op.use_empty() &&
3352 "expected that dead materializations have already been DCE'd");
3353 Operation::operand_range inputOperands = op.getOperands();
3354
3355 // Try to materialize the conversion.
3356 if (const TypeConverter *converter = info.getConverter()) {
3357 rewriter.setInsertionPoint(op);
3358 SmallVector<Value> newMaterialization;
3359 switch (info.getMaterializationKind()) {
3360 case MaterializationKind::Target:
3361 newMaterialization = converter->materializeTargetConversion(
3362 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3363 info.getOriginalType());
3364 break;
3365 case MaterializationKind::Source:
3366 assert(op->getNumResults() == 1 && "expected single result");
3367 Value sourceMat = converter->materializeSourceConversion(
3368 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3369 if (sourceMat)
3370 newMaterialization.push_back(sourceMat);
3371 break;
3372 }
3373 if (!newMaterialization.empty()) {
3374#ifndef NDEBUG
3375 ValueRange newMaterializationRange(newMaterialization);
3376 assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
3377 "materialization callback produced value of incorrect type");
3378#endif // NDEBUG
3379 rewriter.replaceOp(op, newMaterialization);
3380 return success();
3381 }
3382 }
3383
3384 InFlightDiagnostic diag = op->emitError()
3385 << "failed to legalize unresolved materialization "
3386 "from ("
3387 << inputOperands.getTypes() << ") to ("
3388 << op.getResultTypes()
3389 << ") that remained live after conversion";
3390 diag.attachNote(op->getUsers().begin()->getLoc())
3391 << "see existing live user here: " << *op->getUsers().begin();
3392 return failure();
3393}
3394
3395template <typename Fn>
3396LogicalResult
3398 bool isRecursiveLegalization) {
3399 const ConversionTarget &target = opLegalizer.getTarget();
3400
3401 // Compute the set of operations and blocks to convert.
3402 SmallVector<Operation *> toConvert;
3403 for (Operation *op : ops) {
3405 [&](Operation *op) {
3406 toConvert.push_back(op);
3407 // Don't check this operation's children for conversion if the
3408 // operation is recursively legal.
3409 auto legalityInfo = target.isLegal(op);
3410 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3411 return WalkResult::skip();
3412 return WalkResult::advance();
3413 });
3414 }
3415 for (Operation *op : toConvert) {
3416 if (failed(convert(op, isRecursiveLegalization))) {
3417 // Failed to convert an operation.
3418 onFailure();
3419 return failure();
3420 }
3421 }
3422 return success();
3423}
3424
3425LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
3426 return impl->opConverter.legalizeOperations(op,
3427 /*isRecursiveLegalization=*/true);
3428}
3429
3430LogicalResult ConversionPatternRewriter::legalize(Region *r) {
3431 // Fast path: If the region is empty, there is nothing to legalize.
3432 if (r->empty())
3433 return success();
3434
3435 // Gather a list of all operations to legalize. This is done before
3436 // converting the entry block signature because unrealized_conversion_cast
3437 // ops should not be included.
3439 for (Block &b : *r)
3440 for (Operation &op : b)
3441 ops.push_back(&op);
3442
3443 // If the current pattern runs with a type converter, convert the entry block
3444 // signature.
3445 if (const TypeConverter *converter = impl->currentTypeConverter) {
3446 std::optional<TypeConverter::SignatureConversion> conversion =
3447 converter->convertBlockSignature(&r->front());
3448 if (!conversion)
3449 return failure();
3450 applySignatureConversion(&r->front(), *conversion, converter);
3451 }
3452
3453 // Legalize all operations in the region. This includes all nested
3454 // operations.
3455 return impl->opConverter.legalizeOperations(ops,
3456 /*isRecursiveLegalization=*/true);
3457}
3458
3460 // Convert each operation and discard rewrites on failure.
3461 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
3462 LogicalResult status = legalizeOperations(ops, /*onFailure=*/[&]() {
3463 // Dialect conversion failed.
3464 if (rewriterImpl.config.allowPatternRollback) {
3465 // Rollback is allowed: restore the original IR.
3466 rewriterImpl.undoRewrites();
3467 } else {
3468 // Rollback is not allowed: apply all modifications that have been
3469 // performed so far.
3470 rewriterImpl.applyRewrites();
3471 }
3472 });
3473 if (failed(status))
3474 return failure();
3475
3476 // After a successful conversion, apply rewrites.
3477 rewriterImpl.applyRewrites();
3478
3479 // Reconcile all UnrealizedConversionCastOps that were inserted by the
3480 // dialect conversion frameworks. (Not the ones that were inserted by
3481 // patterns.)
3483 &materializations = rewriterImpl.unresolvedMaterializations;
3485 reconcileUnrealizedCasts(materializations, &remainingCastOps);
3486
3487 // Drop markers.
3488 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3489 castOp->removeAttr(kPureTypeConversionMarker);
3490
3491 // Try to legalize all unresolved materializations.
3492 if (rewriter.getConfig().buildMaterializations) {
3493 // Use a new rewriter, so the modifications are not tracked for rollback
3494 // purposes etc.
3495 IRRewriter irRewriter(rewriterImpl.rewriter.getContext(),
3496 rewriter.getConfig().listener);
3497 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3498 auto it = materializations.find(castOp);
3499 assert(it != materializations.end() && "inconsistent state");
3500 if (failed(legalizeUnresolvedMaterialization(irRewriter, castOp,
3501 it->second)))
3502 return failure();
3503 }
3504 }
3505
3506 return success();
3507}
3508
3509//===----------------------------------------------------------------------===//
3510// Type Conversion
3511//===----------------------------------------------------------------------===//
3512
3513void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
3514 ArrayRef<Type> types) {
3515 assert(!types.empty() && "expected valid types");
3516 remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
3517 addInputs(types);
3518}
3519
3520void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
3521 assert(!types.empty() &&
3522 "1->0 type remappings don't need to be added explicitly");
3523 argTypes.append(types.begin(), types.end());
3524}
3525
3526void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
3527 unsigned newInputNo,
3528 unsigned newInputCount) {
3529 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3530 assert(newInputCount != 0 && "expected valid input count");
3531 remappedInputs[origInputNo] =
3532 InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
3533}
3534
3535void TypeConverter::SignatureConversion::remapInput(
3536 unsigned origInputNo, ArrayRef<Value> replacements) {
3537 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3538 remappedInputs[origInputNo] = InputMapping{
3539 origInputNo, /*size=*/0,
3540 SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3541}
3542
3543/// Internal implementation of the type conversion.
3544/// This is used with either a Type or a Value as the first argument.
3545/// - we can cache the context-free conversions until the last registered
3546/// context-aware conversion.
3547/// - we can't cache the result of type conversion happening after context-aware
3548/// conversions, because the type converter may return different results for the
3549/// same input type.
3550LogicalResult
3551TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3552 SmallVectorImpl<Type> &results) const {
3553 assert(typeOrValue && "expected non-null type");
3554 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3555 : cast<Type>(typeOrValue);
3556 {
3557 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3558 std::defer_lock);
3560 cacheReadLock.lock();
3561 auto existingIt = cachedDirectConversions.find(t);
3562 if (existingIt != cachedDirectConversions.end()) {
3563 if (existingIt->second)
3564 results.push_back(existingIt->second);
3565 return success(existingIt->second != nullptr);
3566 }
3567 auto multiIt = cachedMultiConversions.find(t);
3568 if (multiIt != cachedMultiConversions.end()) {
3569 results.append(multiIt->second.begin(), multiIt->second.end());
3570 return success();
3571 }
3572 }
3573 // Walk the added converters in reverse order to apply the most recently
3574 // registered first.
3575 size_t currentCount = results.size();
3576
3577 // We can cache the context-free conversions until the last registered
3578 // context-aware conversion. But only if we're processing a Value right now.
3579 auto isCacheable = [&](int index) {
3580 int numberOfConversionsUntilContextAware =
3581 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3582 return index < numberOfConversionsUntilContextAware;
3583 };
3584
3585 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3586 std::defer_lock);
3587
3588 for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3589 const ConversionCallbackFn &converter = indexedConverter.value();
3590 std::optional<LogicalResult> result = converter(typeOrValue, results);
3591 if (!result) {
3592 assert(results.size() == currentCount &&
3593 "failed type conversion should not change results");
3594 continue;
3595 }
3596 if (!isCacheable(indexedConverter.index()))
3597 return success();
3599 cacheWriteLock.lock();
3600 if (!succeeded(*result)) {
3601 assert(results.size() == currentCount &&
3602 "failed type conversion should not change results");
3603 cachedDirectConversions.try_emplace(t, nullptr);
3604 return failure();
3605 }
3606 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3607 if (newTypes.size() == 1)
3608 cachedDirectConversions.try_emplace(t, newTypes.front());
3609 else
3610 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3611 return success();
3612 }
3613 return failure();
3614}
3615
3616LogicalResult TypeConverter::convertType(Type t,
3617 SmallVectorImpl<Type> &results) const {
3618 return convertTypeImpl(t, results);
3619}
3620
3621LogicalResult TypeConverter::convertType(Value v,
3622 SmallVectorImpl<Type> &results) const {
3623 return convertTypeImpl(v, results);
3624}
3625
3626Type TypeConverter::convertType(Type t) const {
3627 // Use the multi-type result version to convert the type.
3628 SmallVector<Type, 1> results;
3629 if (failed(convertType(t, results)))
3630 return nullptr;
3631
3632 // Check to ensure that only one type was produced.
3633 return results.size() == 1 ? results.front() : nullptr;
3634}
3635
3636Type TypeConverter::convertType(Value v) const {
3637 // Use the multi-type result version to convert the type.
3638 SmallVector<Type, 1> results;
3639 if (failed(convertType(v, results)))
3640 return nullptr;
3641
3642 // Check to ensure that only one type was produced.
3643 return results.size() == 1 ? results.front() : nullptr;
3644}
3645
3646LogicalResult
3647TypeConverter::convertTypes(TypeRange types,
3648 SmallVectorImpl<Type> &results) const {
3649 for (Type type : types)
3650 if (failed(convertType(type, results)))
3651 return failure();
3652 return success();
3653}
3654
3655LogicalResult
3656TypeConverter::convertTypes(ValueRange values,
3657 SmallVectorImpl<Type> &results) const {
3658 for (Value value : values)
3659 if (failed(convertType(value, results)))
3660 return failure();
3661 return success();
3662}
3663
3664bool TypeConverter::isLegal(Type type) const {
3665 return convertType(type) == type;
3666}
3667
3668bool TypeConverter::isLegal(Value value) const {
3669 return convertType(value) == value.getType();
3670}
3671
3672bool TypeConverter::isLegal(Operation *op) const {
3673 return isLegal(op->getOperands()) && isLegal(op->getResults());
3674}
3675
3676bool TypeConverter::isLegal(Region *region) const {
3677 return llvm::all_of(
3678 *region, [this](Block &block) { return isLegal(block.getArguments()); });
3679}
3680
3681bool TypeConverter::isSignatureLegal(FunctionType ty) const {
3682 if (!isLegal(ty.getInputs()))
3683 return false;
3684 if (!isLegal(ty.getResults()))
3685 return false;
3686 return true;
3687}
3688
3689LogicalResult
3690TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
3691 SignatureConversion &result) const {
3692 // Try to convert the given input type.
3693 SmallVector<Type, 1> convertedTypes;
3694 if (failed(convertType(type, convertedTypes)))
3695 return failure();
3696
3697 // If this argument is being dropped, there is nothing left to do.
3698 if (convertedTypes.empty())
3699 return success();
3700
3701 // Otherwise, add the new inputs.
3702 result.addInputs(inputNo, convertedTypes);
3703 return success();
3704}
3705LogicalResult
3706TypeConverter::convertSignatureArgs(TypeRange types,
3707 SignatureConversion &result,
3708 unsigned origInputOffset) const {
3709 for (unsigned i = 0, e = types.size(); i != e; ++i)
3710 if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3711 return failure();
3712 return success();
3713}
3714LogicalResult
3715TypeConverter::convertSignatureArg(unsigned inputNo, Value value,
3716 SignatureConversion &result) const {
3717 // Try to convert the given input type.
3718 SmallVector<Type, 1> convertedTypes;
3719 if (failed(convertType(value, convertedTypes)))
3720 return failure();
3721
3722 // If this argument is being dropped, there is nothing left to do.
3723 if (convertedTypes.empty())
3724 return success();
3725
3726 // Otherwise, add the new inputs.
3727 result.addInputs(inputNo, convertedTypes);
3728 return success();
3729}
3730LogicalResult
3731TypeConverter::convertSignatureArgs(ValueRange values,
3732 SignatureConversion &result,
3733 unsigned origInputOffset) const {
3734 for (unsigned i = 0, e = values.size(); i != e; ++i)
3735 if (failed(convertSignatureArg(origInputOffset + i, values[i], result)))
3736 return failure();
3737 return success();
3738}
3739
3740Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
3741 Location loc, Type resultType,
3742 ValueRange inputs) const {
3743 for (const SourceMaterializationCallbackFn &fn :
3744 llvm::reverse(sourceMaterializations))
3745 if (Value result = fn(builder, resultType, inputs, loc))
3746 return result;
3747 return nullptr;
3748}
3749
3750Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
3751 Location loc, Type resultType,
3752 ValueRange inputs,
3753 Type originalType) const {
3754 SmallVector<Value> result = materializeTargetConversion(
3755 builder, loc, TypeRange(resultType), inputs, originalType);
3756 if (result.empty())
3757 return nullptr;
3758 assert(result.size() == 1 && "expected single result");
3759 return result.front();
3760}
3761
3762SmallVector<Value> TypeConverter::materializeTargetConversion(
3763 OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
3764 Type originalType) const {
3765 for (const TargetMaterializationCallbackFn &fn :
3766 llvm::reverse(targetMaterializations)) {
3767 SmallVector<Value> result =
3768 fn(builder, resultTypes, inputs, loc, originalType);
3769 if (result.empty())
3770 continue;
3771 assert(TypeRange(ValueRange(result)) == resultTypes &&
3772 "callback produced incorrect number of values or values with "
3773 "incorrect types");
3774 return result;
3775 }
3776 return {};
3777}
3778
3779std::optional<TypeConverter::SignatureConversion>
3780TypeConverter::convertBlockSignature(Block *block) const {
3781 SignatureConversion conversion(block->getNumArguments());
3782 if (failed(convertSignatureArgs(block->getArguments(), conversion)))
3783 return std::nullopt;
3784 return conversion;
3785}
3786
3787//===----------------------------------------------------------------------===//
3788// Type attribute conversion
3789//===----------------------------------------------------------------------===//
3790TypeConverter::AttributeConversionResult
3791TypeConverter::AttributeConversionResult::result(Attribute attr) {
3792 return AttributeConversionResult(attr, resultTag);
3793}
3794
3795TypeConverter::AttributeConversionResult
3796TypeConverter::AttributeConversionResult::na() {
3797 return AttributeConversionResult(nullptr, naTag);
3798}
3799
3800TypeConverter::AttributeConversionResult
3801TypeConverter::AttributeConversionResult::abort() {
3802 return AttributeConversionResult(nullptr, abortTag);
3803}
3804
3805bool TypeConverter::AttributeConversionResult::hasResult() const {
3806 return impl.getInt() == resultTag;
3807}
3808
3809bool TypeConverter::AttributeConversionResult::isNa() const {
3810 return impl.getInt() == naTag;
3811}
3812
3813bool TypeConverter::AttributeConversionResult::isAbort() const {
3814 return impl.getInt() == abortTag;
3815}
3816
3817Attribute TypeConverter::AttributeConversionResult::getResult() const {
3818 assert(hasResult() && "Cannot get result from N/A or abort");
3819 return impl.getPointer();
3820}
3821
3822std::optional<Attribute>
3823TypeConverter::convertTypeAttribute(Type type, Attribute attr) const {
3824 for (const TypeAttributeConversionCallbackFn &fn :
3825 llvm::reverse(typeAttributeConversions)) {
3826 AttributeConversionResult res = fn(type, attr);
3827 if (res.hasResult())
3828 return res.getResult();
3829 if (res.isAbort())
3830 return std::nullopt;
3831 }
3832 return std::nullopt;
3833}
3834
3835//===----------------------------------------------------------------------===//
3836// FunctionOpInterfaceSignatureConversion
3837//===----------------------------------------------------------------------===//
3838
3839static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3840 const TypeConverter &typeConverter,
3841 ConversionPatternRewriter &rewriter) {
3842 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3843 if (!type)
3844 return failure();
3845
3846 // Convert the original function types.
3847 TypeConverter::SignatureConversion result(type.getNumInputs());
3848 SmallVector<Type, 1> newResults;
3849 if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3850 failed(typeConverter.convertTypes(type.getResults(), newResults)))
3851 return failure();
3852 if (!funcOp.getFunctionBody().empty())
3853 rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result,
3854 &typeConverter);
3855
3856 // Update the function signature in-place.
3857 auto newType = FunctionType::get(rewriter.getContext(),
3858 result.getConvertedTypes(), newResults);
3859
3860 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3861
3862 return success();
3863}
3864
3865/// Create a default conversion pattern that rewrites the type signature of a
3866/// FunctionOpInterface op. This only supports ops which use FunctionType to
3867/// represent their type.
3868namespace {
3869struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3870 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3871 MLIRContext *ctx,
3872 const TypeConverter &converter,
3873 PatternBenefit benefit)
3874 : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3875
3876 LogicalResult
3877 matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3878 ConversionPatternRewriter &rewriter) const override {
3879 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3880 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3881 }
3882};
3883
3884struct AnyFunctionOpInterfaceSignatureConversion
3885 : public OpInterfaceConversionPattern<FunctionOpInterface> {
3886 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3887
3888 LogicalResult
3889 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3890 ConversionPatternRewriter &rewriter) const override {
3891 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3892 }
3893};
3894} // namespace
3895
3896FailureOr<Operation *>
3897mlir::convertOpResultTypes(Operation *op, ValueRange operands,
3898 const TypeConverter &converter,
3899 ConversionPatternRewriter &rewriter) {
3900 assert(op && "Invalid op");
3901 Location loc = op->getLoc();
3902 if (converter.isLegal(op))
3903 return rewriter.notifyMatchFailure(loc, "op already legal");
3904
3905 OperationState newOp(loc, op->getName());
3906 newOp.addOperands(operands);
3907
3908 SmallVector<Type> newResultTypes;
3909 if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
3910 return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3911
3912 newOp.addTypes(newResultTypes);
3913 newOp.addAttributes(op->getAttrs());
3914 return rewriter.create(newOp);
3915}
3916
3917void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3918 StringRef functionLikeOpName, RewritePatternSet &patterns,
3919 const TypeConverter &converter, PatternBenefit benefit) {
3920 patterns.add<FunctionOpInterfaceSignatureConversion>(
3921 functionLikeOpName, patterns.getContext(), converter, benefit);
3922}
3923
3924void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3925 RewritePatternSet &patterns, const TypeConverter &converter,
3926 PatternBenefit benefit) {
3927 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3928 converter, patterns.getContext(), benefit);
3929}
3930
3931//===----------------------------------------------------------------------===//
3932// ConversionTarget
3933//===----------------------------------------------------------------------===//
3934
3935void ConversionTarget::setOpAction(OperationName op,
3936 LegalizationAction action) {
3937 legalOperations[op].action = action;
3938}
3939
3940void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3941 LegalizationAction action) {
3942 for (StringRef dialect : dialectNames)
3943 legalDialects[dialect] = action;
3944}
3945
3946auto ConversionTarget::getOpAction(OperationName op) const
3947 -> std::optional<LegalizationAction> {
3948 std::optional<LegalizationInfo> info = getOpInfo(op);
3949 return info ? info->action : std::optional<LegalizationAction>();
3950}
3951
3952auto ConversionTarget::isLegal(Operation *op) const
3953 -> std::optional<LegalOpDetails> {
3954 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3955 if (!info)
3956 return std::nullopt;
3957
3958 // Returns true if this operation instance is known to be legal.
3959 auto isOpLegal = [&] {
3960 // Handle dynamic legality either with the provided legality function.
3961 if (info->action == LegalizationAction::Dynamic) {
3962 std::optional<bool> result = info->legalityFn(op);
3963 if (result)
3964 return *result;
3965 }
3966
3967 // Otherwise, the operation is only legal if it was marked 'Legal'.
3968 return info->action == LegalizationAction::Legal;
3969 };
3970 if (!isOpLegal())
3971 return std::nullopt;
3972
3973 // This operation is legal, compute any additional legality information.
3974 LegalOpDetails legalityDetails;
3975 if (info->isRecursivelyLegal) {
3976 auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3977 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3978 legalityDetails.isRecursivelyLegal =
3979 legalityFnIt->second(op).value_or(true);
3980 } else {
3981 legalityDetails.isRecursivelyLegal = true;
3982 }
3983 }
3984 return legalityDetails;
3985}
3986
3987bool ConversionTarget::isIllegal(Operation *op) const {
3988 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3989 if (!info)
3990 return false;
3991
3992 if (info->action == LegalizationAction::Dynamic) {
3993 std::optional<bool> result = info->legalityFn(op);
3994 if (!result)
3995 return false;
3996
3997 return !(*result);
3998 }
3999
4000 return info->action == LegalizationAction::Illegal;
4001}
4002
4003static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
4004 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
4005 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
4006 if (!oldCallback)
4007 return newCallback;
4008
4009 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
4010 Operation *op) -> std::optional<bool> {
4011 if (std::optional<bool> result = newCl(op))
4012 return *result;
4013
4014 return oldCl(op);
4015 };
4016 return chain;
4017}
4018
4019void ConversionTarget::setLegalityCallback(
4020 OperationName name, const DynamicLegalityCallbackFn &callback) {
4021 assert(callback && "expected valid legality callback");
4022 auto *infoIt = legalOperations.find(name);
4023 assert(infoIt != legalOperations.end() &&
4024 infoIt->second.action == LegalizationAction::Dynamic &&
4025 "expected operation to already be marked as dynamically legal");
4026 infoIt->second.legalityFn =
4027 composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
4028}
4029
4030void ConversionTarget::markOpRecursivelyLegal(
4031 OperationName name, const DynamicLegalityCallbackFn &callback) {
4032 auto *infoIt = legalOperations.find(name);
4033 assert(infoIt != legalOperations.end() &&
4034 infoIt->second.action != LegalizationAction::Illegal &&
4035 "expected operation to already be marked as legal");
4036 infoIt->second.isRecursivelyLegal = true;
4037 if (callback)
4038 opRecursiveLegalityFns[name] = composeLegalityCallbacks(
4039 std::move(opRecursiveLegalityFns[name]), callback);
4040 else
4041 opRecursiveLegalityFns.erase(name);
4042}
4043
4044void ConversionTarget::setLegalityCallback(
4045 ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
4046 assert(callback && "expected valid legality callback");
4047 for (StringRef dialect : dialects)
4048 dialectLegalityFns[dialect] = composeLegalityCallbacks(
4049 std::move(dialectLegalityFns[dialect]), callback);
4050}
4051
4052void ConversionTarget::setLegalityCallback(
4053 const DynamicLegalityCallbackFn &callback) {
4054 assert(callback && "expected valid legality callback");
4055 unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
4056}
4057
4058auto ConversionTarget::getOpInfo(OperationName op) const
4059 -> std::optional<LegalizationInfo> {
4060 // Check for info for this specific operation.
4061 const auto *it = legalOperations.find(op);
4062 if (it != legalOperations.end())
4063 return it->second;
4064 // Check for info for the parent dialect.
4065 auto dialectIt = legalDialects.find(op.getDialectNamespace());
4066 if (dialectIt != legalDialects.end()) {
4067 DynamicLegalityCallbackFn callback;
4068 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4069 if (dialectFn != dialectLegalityFns.end())
4070 callback = dialectFn->second;
4071 return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
4072 callback};
4073 }
4074 // Otherwise, check if we mark unknown operations as dynamic.
4075 if (unknownLegalityFn)
4076 return LegalizationInfo{LegalizationAction::Dynamic,
4077 /*isRecursivelyLegal=*/false, unknownLegalityFn};
4078 return std::nullopt;
4079}
4080
4081#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4082//===----------------------------------------------------------------------===//
4083// PDL Configuration
4084//===----------------------------------------------------------------------===//
4085
4086void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
4087 auto &rewriterImpl =
4088 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4089 rewriterImpl.currentTypeConverter = getTypeConverter();
4090}
4091
4092void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
4093 auto &rewriterImpl =
4094 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4095 rewriterImpl.currentTypeConverter = nullptr;
4096}
4097
4098/// Remap the given value using the rewriter and the type converter in the
4099/// provided config.
4100static FailureOr<SmallVector<Value>>
4101pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) {
4102 SmallVector<Value> mappedValues;
4103 if (failed(rewriter.getRemappedValues(values, mappedValues)))
4104 return failure();
4105 return std::move(mappedValues);
4106}
4107
4108void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
4109 patterns.getPDLPatterns().registerRewriteFunction(
4110 "convertValue",
4111 [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
4112 auto results = pdllConvertValues(
4113 static_cast<ConversionPatternRewriter &>(rewriter), value);
4114 if (failed(results))
4115 return failure();
4116 return results->front();
4117 });
4118 patterns.getPDLPatterns().registerRewriteFunction(
4119 "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
4120 return pdllConvertValues(
4121 static_cast<ConversionPatternRewriter &>(rewriter), values);
4122 });
4123 patterns.getPDLPatterns().registerRewriteFunction(
4124 "convertType",
4125 [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
4126 auto &rewriterImpl =
4127 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4128 if (const TypeConverter *converter =
4129 rewriterImpl.currentTypeConverter) {
4130 if (Type newType = converter->convertType(type))
4131 return newType;
4132 return failure();
4133 }
4134 return type;
4135 });
4136 patterns.getPDLPatterns().registerRewriteFunction(
4137 "convertTypes",
4138 [](PatternRewriter &rewriter,
4139 TypeRange types) -> FailureOr<SmallVector<Type>> {
4140 auto &rewriterImpl =
4141 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4142 const TypeConverter *converter = rewriterImpl.currentTypeConverter;
4143 if (!converter)
4144 return SmallVector<Type>(types);
4145
4146 SmallVector<Type> remappedTypes;
4147 if (failed(converter->convertTypes(types, remappedTypes)))
4148 return failure();
4149 return std::move(remappedTypes);
4150 });
4151}
4152#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
4153
4154//===----------------------------------------------------------------------===//
4155// Op Conversion Entry Points
4156//===----------------------------------------------------------------------===//
4157
4158/// This is the type of Action that is dispatched when a conversion is applied.
4160 : public tracing::ActionImpl<ApplyConversionAction> {
4161public:
4164 static constexpr StringLiteral tag = "apply-conversion";
4165 static constexpr StringLiteral desc =
4166 "Encapsulate the application of a dialect conversion";
4167
4168 void print(raw_ostream &os) const override { os << tag; }
4169};
4170
4172 const ConversionTarget &target,
4173 const FrozenRewritePatternSet &patterns,
4174 ConversionConfig config,
4175 OpConversionMode mode) {
4176 if (ops.empty())
4177 return success();
4178 MLIRContext *ctx = ops.front()->getContext();
4179 LogicalResult status = success();
4180 SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
4182 [&] {
4183 OperationConverter opConverter(ops.front()->getContext(), target,
4184 patterns, config, mode);
4185 status = opConverter.applyConversion(ops);
4186 },
4187 irUnits);
4188 return status;
4189}
4190
4191//===----------------------------------------------------------------------===//
4192// Partial Conversion
4193//===----------------------------------------------------------------------===//
4194
4195LogicalResult mlir::applyPartialConversion(
4196 ArrayRef<Operation *> ops, const ConversionTarget &target,
4197 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4198 return applyConversion(ops, target, patterns, config,
4199 OpConversionMode::Partial);
4200}
4201LogicalResult
4202mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
4203 const FrozenRewritePatternSet &patterns,
4204 ConversionConfig config) {
4205 return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
4206}
4207
4208//===----------------------------------------------------------------------===//
4209// Full Conversion
4210//===----------------------------------------------------------------------===//
4211
4212LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
4213 const ConversionTarget &target,
4214 const FrozenRewritePatternSet &patterns,
4215 ConversionConfig config) {
4216 return applyConversion(ops, target, patterns, config, OpConversionMode::Full);
4217}
4218LogicalResult mlir::applyFullConversion(Operation *op,
4219 const ConversionTarget &target,
4220 const FrozenRewritePatternSet &patterns,
4221 ConversionConfig config) {
4222 return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
4223}
4224
4225//===----------------------------------------------------------------------===//
4226// Analysis Conversion
4227//===----------------------------------------------------------------------===//
4228
4229/// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
4230/// op is a top-level module op (which is expected to be isolated from above),
4231/// return that op.
4233 // Check if there is a top-level operation within `ops`. If so, return that
4234 // op.
4235 for (Operation *op : ops) {
4236 if (!op->getParentOp()) {
4237#ifndef NDEBUG
4238 assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
4239 "expected top-level op to be isolated from above");
4240 for (Operation *other : ops)
4241 assert(op->isAncestor(other) &&
4242 "expected ops to have a common ancestor");
4243#endif // NDEBUG
4244 return op;
4245 }
4246 }
4247
4248 // No top-level op. Find a common ancestor.
4249 Operation *commonAncestor =
4250 ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
4251 for (Operation *op : ops.drop_front()) {
4252 while (!commonAncestor->isProperAncestor(op)) {
4253 commonAncestor =
4255 assert(commonAncestor &&
4256 "expected to find a common isolated from above ancestor");
4257 }
4258 }
4259
4260 return commonAncestor;
4261}
4262
4263LogicalResult mlir::applyAnalysisConversion(
4264 ArrayRef<Operation *> ops, ConversionTarget &target,
4265 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4266#ifndef NDEBUG
4267 if (config.legalizableOps)
4268 assert(config.legalizableOps->empty() && "expected empty set");
4269#endif // NDEBUG
4270
4271 // Clone closted common ancestor that is isolated from above.
4272 Operation *commonAncestor = findCommonAncestor(ops);
4273 IRMapping mapping;
4274 Operation *clonedAncestor = commonAncestor->clone(mapping);
4275 // Compute inverse IR mapping.
4276 DenseMap<Operation *, Operation *> inverseOperationMap;
4277 for (auto &it : mapping.getOperationMap())
4278 inverseOperationMap[it.second] = it.first;
4279
4280 // Convert the cloned operations. The original IR will remain unchanged.
4281 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4282 ops, [&](Operation *op) { return mapping.lookup(op); });
4283 LogicalResult status = applyConversion(opsToConvert, target, patterns, config,
4284 OpConversionMode::Analysis);
4285
4286 // Remap `legalizableOps`, so that they point to the original ops and not the
4287 // cloned ops.
4288 if (config.legalizableOps) {
4289 DenseSet<Operation *> originalLegalizableOps;
4290 for (Operation *op : *config.legalizableOps)
4291 originalLegalizableOps.insert(inverseOperationMap[op]);
4292 *config.legalizableOps = std::move(originalLegalizableOps);
4293 }
4294
4295 // Erase the cloned IR.
4296 clonedAncestor->erase();
4297 return status;
4298}
4299
4300LogicalResult
4301mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
4302 const FrozenRewritePatternSet &patterns,
4303 ConversionConfig config) {
4304 return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
4305}
return success()
static void setInsertionPointAfter(OpBuilder &b, Value value)
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl, function_ref< bool(OpOperand &)> functor=nullptr)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value > > &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
#define DEBUG_TYPE
b getContext())
static std::string diag(const llvm::Value &value)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition SCCP.cpp:67
This is the type of Action that is dispatched when a conversion is applied.
tracing::ActionImpl< ApplyConversionAction > Base
static constexpr StringLiteral desc
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
static constexpr StringLiteral tag
This class represents an argument of a Block.
Definition Value.h:309
Location getLoc() const
Return the location for this argument.
Definition Value.h:324
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition Block.h:318
BlockArgListType getArguments()
Definition Block.h:97
iterator end()
Definition Block.h:154
iterator begin()
Definition Block.h:153
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
UnitAttr getUnitAttr()
Definition Builders.cpp:102
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
MLIRContext * context
Definition Builders.h:204
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
Definition Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition Builders.h:329
Block::iterator getPoint() const
Definition Builders.h:342
bool isSet() const
Returns true if this insert point is set.
Definition Builders.h:339
Block * getBlock() const
Definition Builders.h:341
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:322
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Definition Builders.cpp:477
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition Builders.cpp:425
This class represents an operand of an operation.
Definition Value.h:257
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition Value.h:457
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1140
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
void destroyOpProperties(OpaqueProperties properties) const
This hooks destroy the op properties.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition Operation.h:247
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:881
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:778
void dropAllUses()
Drop all uses of results of this operation.
Definition Operation.h:863
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:589
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:541
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:234
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:277
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:703
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:255
OperandRange operand_range
Definition Operation.h:400
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:706
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:292
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:826
result_range getResults()
Definition Operation.h:444
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:237
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition Operation.h:929
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:433
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
BlockListType::iterator iterator
Definition Region.h:52
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
CRTP Implementation of an action.
Definition Action.h:76
ArrayRef< IRUnit > irUnits
Set of IR units (operations, regions, blocks, values) that are associated with this action.
Definition Action.h:66
AttrTypeReplacer.
Kind
An enumeration of the kinds of predicates.
Definition Predicate.h:44
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
static void reconcileUnrealizedCasts(const DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
This iterator enumerates elements according to their dominance relationship.
Definition Iterators.h:52
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition Builders.h:310
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition Builders.h:300
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
const ConversionTarget & getTarget()
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, bool isRecursiveLegalization=false)
LogicalResult convert(Operation *op, bool isRecursiveLegalization=false)
Converts a single operation.
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, Fn onFailure, bool isRecursiveLegalization=false)
Legalizes the given operations (and their nested operations) to the conversion target.
LogicalResult applyConversion(ArrayRef< Operation * > ops)
Applies the conversion to the given operations (and their nested operations).
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
llvm::impl::raw_ldbg_ostream os
A raw output stream used to prefix the debug log.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
void replaceValueUses(Value from, ValueRange to, const TypeConverter *converter, function_ref< bool(OpOperand &)> functor=nullptr)
Replace the uses of the given value with the given values.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config, OperationConverter &opConverter)
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
void replaceOp(Operation *op, SmallVector< SmallVector< Value > > &&newValues)
Replace the results of the given operation with the given values and erase the operation.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
OperationConverter & opConverter
The operation converter to use for recursive legalization.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.