MLIR 23.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/Config/mlir-config.h"
11#include "mlir/IR/Block.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "mlir/IR/Dominance.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Iterators.h"
17#include "mlir/IR/Operation.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/ErrorHandling.h"
25#include "llvm/Support/FormatVariadic.h"
26#include "llvm/Support/SaveAndRestore.h"
27#include "llvm/Support/ScopedPrinter.h"
28#include <optional>
29#include <utility>
30
31using namespace mlir;
32using namespace mlir::detail;
33
34#define DEBUG_TYPE "dialect-conversion"
35
36/// A utility function to log a successful result for the given reason.
37template <typename... Args>
38static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
39 LLVM_DEBUG({
40 os.unindent();
41 os.startLine() << "} -> SUCCESS";
42 if (!fmt.empty())
43 os.getOStream() << " : "
44 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
45 os.getOStream() << "\n";
46 });
47}
48
49/// A utility function to log a failure result for the given reason.
50template <typename... Args>
51static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
52 LLVM_DEBUG({
53 os.unindent();
54 os.startLine() << "} -> FAILURE : "
55 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
56 << "\n";
57 });
58}
59
60/// Helper function that computes an insertion point where the given value is
61/// defined and can be used without a dominance violation.
63 Block *insertBlock = value.getParentBlock();
64 Block::iterator insertPt = insertBlock->begin();
65 if (OpResult inputRes = dyn_cast<OpResult>(value))
66 insertPt = ++inputRes.getOwner()->getIterator();
67 return OpBuilder::InsertPoint(insertBlock, insertPt);
68}
69
70/// Helper function that computes an insertion point where the given values are
71/// defined and can be used without a dominance violation.
73 assert(!vals.empty() && "expected at least one value");
74 DominanceInfo domInfo;
76 for (Value v : vals.drop_front()) {
77 // Choose the "later" insertion point.
79 if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
80 nextPt.getPoint())) {
81 // pt is before nextPt => choose nextPt.
82 pt = nextPt;
83 } else {
84#ifndef NDEBUG
85 // nextPt should be before pt => choose pt.
86 // If pt, nextPt are no dominance relationship, then there is no valid
87 // insertion point at which all given values are defined.
88 bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
89 pt.getBlock(), pt.getPoint());
90 assert(dom && "unable to find valid insertion point");
91#endif // NDEBUG
92 }
93 }
94 return pt;
95}
96
97namespace {
/// Selects how strictly the conversion treats operations that could not be
/// legalized.
enum OpConversionMode {
  /// Failed conversions are tolerated, so illegal operations may remain in
  /// the IR next to legal ones.
  Partial = 0,

  /// The conversion only succeeds if every operation is legal for the given
  /// target afterwards.
  Full = 1,

  /// Operations are only analyzed for legality; no actual rewrites are
  /// applied to the operations on success.
  Analysis = 2,
};
111} // namespace
112
113//===----------------------------------------------------------------------===//
114// ConversionValueMapping
115//===----------------------------------------------------------------------===//
116
117/// A vector of SSA values, optimized for the most common case of a single
118/// value.
120
121namespace {
122
123/// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
124struct ValueVectorMapInfo {
125 static ValueVector getEmptyKey() { return ValueVector{Value()}; }
126 static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
127 static ::llvm::hash_code getHashValue(const ValueVector &val) {
128 return ::llvm::hash_combine_range(val);
129 }
130 static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
131 return LHS == RHS;
132 }
133};
134
135/// This class wraps a IRMapping to provide recursive lookup
136/// functionality, i.e. we will traverse if the mapped value also has a mapping.
137struct ConversionValueMapping {
138 /// Return "true" if an SSA value is mapped to the given value. May return
139 /// false positives.
140 bool isMappedTo(Value value) const { return mappedTo.contains(value); }
141
142 /// Lookup a value in the mapping.
143 ValueVector lookup(const ValueVector &from) const;
144
145 template <typename T>
146 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
147
148 /// Map a value vector to the one provided.
149 template <typename OldVal, typename NewVal>
150 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
151 map(OldVal &&oldVal, NewVal &&newVal) {
152 LLVM_DEBUG({
153 ValueVector next(newVal);
154 while (true) {
155 assert(next != oldVal && "inserting cyclic mapping");
156 auto it = mapping.find(next);
157 if (it == mapping.end())
158 break;
159 next = it->second;
160 }
161 });
162 mappedTo.insert_range(newVal);
163
164 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
165 }
166
167 /// Map a value vector or single value to the one provided.
168 template <typename OldVal, typename NewVal>
169 std::enable_if_t<!IsValueVector<OldVal>::value ||
170 !IsValueVector<NewVal>::value>
171 map(OldVal &&oldVal, NewVal &&newVal) {
172 if constexpr (IsValueVector<OldVal>{}) {
173 map(std::forward<OldVal>(oldVal), ValueVector{newVal});
174 } else if constexpr (IsValueVector<NewVal>{}) {
175 map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
176 } else {
177 map(ValueVector{oldVal}, ValueVector{newVal});
178 }
179 }
180
181 void map(Value oldVal, SmallVector<Value> &&newVal) {
182 map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
183 }
184
185 /// Drop the last mapping for the given values.
186 void erase(const ValueVector &value) { mapping.erase(value); }
187
188private:
189 /// Current value mappings.
191
192 /// All SSA values that are mapped to. May contain false positives.
193 DenseSet<Value> mappedTo;
194};
195} // namespace
196
197/// Marker attribute for pure type conversions. I.e., mappings whose only
198/// purpose is to resolve a type mismatch. (In contrast, mappings that point to
199/// the replacement values of a "replaceOp" call, etc., are not pure type
200/// conversions.)
201static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
202
203/// Return the operation that defines all values in the vector. Return nullptr
204/// if the values are not defined by the same operation.
206 assert(!values.empty() && "expected non-empty value vector");
207 Operation *op = values.front().getDefiningOp();
208 for (Value v : llvm::drop_begin(values)) {
209 if (v.getDefiningOp() != op)
210 return nullptr;
211 }
212 return op;
213}
214
215/// A vector of values is a pure type conversion if all values are defined by
216/// the same operation and the operation has the `kPureTypeConversionMarker`
217/// attribute.
218static bool isPureTypeConversion(const ValueVector &values) {
219 assert(!values.empty() && "expected non-empty value vector");
220 Operation *op = getCommonDefiningOp(values);
221 return op && op->hasAttr(kPureTypeConversionMarker);
222}
223
224ValueVector ConversionValueMapping::lookup(const ValueVector &from) const {
225 auto it = mapping.find(from);
226 if (it == mapping.end()) {
227 // No mapping found: The lookup stops here.
228 return {};
229 }
230 return it->second;
231}
232
233//===----------------------------------------------------------------------===//
234// Rewriter and Translation State
235//===----------------------------------------------------------------------===//
236namespace {
/// Snapshot of the conversion rewriter's bookkeeping counters. A snapshot is
/// taken before a set of rewrites so that those rewrites can later be undone
/// by returning to it.
struct RewriterState {
  RewriterState(unsigned rewrites, unsigned ignoredOps, unsigned replacedOps)
      : numRewrites(rewrites), numIgnoredOperations(ignoredOps),
        numReplacedOps(replacedOps) {}

  /// Number of rewrites performed when the snapshot was taken.
  unsigned numRewrites;

  /// Number of ignored operations when the snapshot was taken.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure when the snapshot was
  /// taken.
  unsigned numReplacedOps;
};
254
255//===----------------------------------------------------------------------===//
256// IR rewrites
257//===----------------------------------------------------------------------===//
258
259static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
260
261/// Notify the listener that the given block and its contents are being erased.
262static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
263 for (Operation &op : b)
264 notifyIRErased(listener, op);
265 listener->notifyBlockErased(&b);
266}
267
268/// Notify the listener that the given operation and its contents are being
269/// erased.
270static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
271 for (Region &r : op.getRegions()) {
272 for (Block &b : r) {
273 notifyIRErased(listener, b);
274 }
275 }
276 listener->notifyOperationErased(&op);
277}
278
279/// An IR rewrite that can be committed (upon success) or rolled back (upon
280/// failure).
281///
282/// The dialect conversion keeps track of IR modifications (requested by the
283/// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
284/// are directly applied to the IR as the rewriter API is used, some are applied
285/// partially, and some are delayed until the `IRRewrite` objects are committed.
286class IRRewrite {
287public:
288 /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
289 /// Enum values are ordered, so that they can be used in `classof`: first all
290 /// block rewrites, then all operation rewrites.
291 enum class Kind {
292 // Block rewrites
293 CreateBlock,
294 EraseBlock,
295 InlineBlock,
296 MoveBlock,
297 BlockTypeConversion,
298 // Operation rewrites
299 MoveOperation,
300 ModifyOperation,
301 ReplaceOperation,
302 CreateOperation,
303 UnresolvedMaterialization,
304 // Value rewrites
305 ReplaceValue
306 };
307
308 virtual ~IRRewrite() = default;
309
310 /// Roll back the rewrite. Operations may be erased during rollback.
311 virtual void rollback() = 0;
312
313 /// Commit the rewrite. At this point, it is certain that the dialect
314 /// conversion will succeed. All IR modifications, except for operation/block
315 /// erasure, must be performed through the given rewriter.
316 ///
317 /// Instead of erasing operations/blocks, they should merely be unlinked
318 /// commit phase and finally be erased during the cleanup phase. This is
319 /// because internal dialect conversion state (such as `mapping`) may still
320 /// be using them.
321 ///
322 /// Any IR modification that was already performed before the commit phase
323 /// (e.g., insertion of an op) must be communicated to the listener that may
324 /// be attached to the given rewriter.
325 virtual void commit(RewriterBase &rewriter) {}
326
327 /// Cleanup operations/blocks. Cleanup is called after commit.
328 virtual void cleanup(RewriterBase &rewriter) {}
329
330 Kind getKind() const { return kind; }
331
332 static bool classof(const IRRewrite *rewrite) { return true; }
333
334protected:
335 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
336 : kind(kind), rewriterImpl(rewriterImpl) {}
337
338 const ConversionConfig &getConfig() const;
339
340 const Kind kind;
341 ConversionPatternRewriterImpl &rewriterImpl;
342};
343
344/// A block rewrite.
345class BlockRewrite : public IRRewrite {
346public:
347 /// Return the block that this rewrite operates on.
348 Block *getBlock() const { return block; }
349
350 static bool classof(const IRRewrite *rewrite) {
351 return rewrite->getKind() >= Kind::CreateBlock &&
352 rewrite->getKind() <= Kind::BlockTypeConversion;
353 }
354
355protected:
356 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
357 Block *block)
358 : IRRewrite(kind, rewriterImpl), block(block) {}
359
360 // The block that this rewrite operates on.
361 Block *block;
362};
363
364/// A value rewrite.
365class ValueRewrite : public IRRewrite {
366public:
367 /// Return the value that this rewrite operates on.
368 Value getValue() const { return value; }
369
370 static bool classof(const IRRewrite *rewrite) {
371 return rewrite->getKind() == Kind::ReplaceValue;
372 }
373
374protected:
375 ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
376 Value value)
377 : IRRewrite(kind, rewriterImpl), value(value) {}
378
379 // The value that this rewrite operates on.
380 Value value;
381};
382
383/// Creation of a block. Block creations are immediately reflected in the IR.
384/// There is no extra work to commit the rewrite. During rollback, the newly
385/// created block is erased.
386class CreateBlockRewrite : public BlockRewrite {
387public:
388 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
389 : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
390
391 static bool classof(const IRRewrite *rewrite) {
392 return rewrite->getKind() == Kind::CreateBlock;
393 }
394
395 void commit(RewriterBase &rewriter) override {
396 // The block was already created and inserted. Just inform the listener.
397 if (auto *listener = rewriter.getListener())
398 listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
399 }
400
401 void rollback() override {
402 // Unlink all of the operations within this block, they will be deleted
403 // separately.
404 auto &blockOps = block->getOperations();
405 while (!blockOps.empty())
406 blockOps.remove(blockOps.begin());
407 block->dropAllUses();
408 if (block->getParent())
409 block->erase();
410 else
411 delete block;
412 }
413};
414
415/// Erasure of a block. Block erasures are partially reflected in the IR. Erased
416/// blocks are immediately unlinked, but only erased during cleanup. This makes
417/// it easier to rollback a block erasure: the block is simply inserted into its
418/// original location.
419class EraseBlockRewrite : public BlockRewrite {
420public:
421 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
422 : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
423 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
424
425 static bool classof(const IRRewrite *rewrite) {
426 return rewrite->getKind() == Kind::EraseBlock;
427 }
428
429 ~EraseBlockRewrite() override {
430 assert(!block &&
431 "rewrite was neither rolled back nor committed/cleaned up");
432 }
433
434 void rollback() override {
435 // The block (owned by this rewrite) was not actually erased yet. It was
436 // just unlinked. Put it back into its original position.
437 assert(block && "expected block");
438 auto &blockList = region->getBlocks();
439 Region::iterator before = insertBeforeBlock
440 ? Region::iterator(insertBeforeBlock)
441 : blockList.end();
442 blockList.insert(before, block);
443 block = nullptr;
444 }
445
446 void commit(RewriterBase &rewriter) override {
447 assert(block && "expected block");
448
449 // Notify the listener that the block and its contents are being erased.
450 if (auto *listener =
451 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
452 notifyIRErased(listener, *block);
453 }
454
455 void cleanup(RewriterBase &rewriter) override {
456 // Erase the contents of the block.
457 for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
458 rewriter.eraseOp(&op);
459 assert(block->empty() && "expected empty block");
460
461 // Erase the block.
462 block->dropAllDefinedValueUses();
463 delete block;
464 block = nullptr;
465 }
466
467private:
468 // The region in which this block was previously contained.
469 Region *region;
470
471 // The original successor of this block before it was unlinked. "nullptr" if
472 // this block was the only block in the region.
473 Block *insertBeforeBlock;
474};
475
476/// Inlining of a block. This rewrite is immediately reflected in the IR.
477/// Note: This rewrite represents only the inlining of the operations. The
478/// erasure of the inlined block is a separate rewrite.
479class InlineBlockRewrite : public BlockRewrite {
480public:
481 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
482 Block *sourceBlock, Block::iterator before)
483 : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
484 sourceBlock(sourceBlock),
485 firstInlinedInst(sourceBlock->empty() ? nullptr
486 : &sourceBlock->front()),
487 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
488 // If a listener is attached to the dialect conversion, ops must be moved
489 // one-by-one. When they are moved in bulk, notifications cannot be sent
490 // because the ops that used to be in the source block at the time of the
491 // inlining (before the "commit" phase) are unknown at the time when
492 // notifications are sent (which is during the "commit" phase).
493 assert(!getConfig().listener &&
494 "InlineBlockRewrite not supported if listener is attached");
495 }
496
497 static bool classof(const IRRewrite *rewrite) {
498 return rewrite->getKind() == Kind::InlineBlock;
499 }
500
501 void rollback() override {
502 // Put the operations from the destination block (owned by the rewrite)
503 // back into the source block.
504 if (firstInlinedInst) {
505 assert(lastInlinedInst && "expected operation");
506 sourceBlock->getOperations().splice(sourceBlock->begin(),
507 block->getOperations(),
508 Block::iterator(firstInlinedInst),
509 ++Block::iterator(lastInlinedInst));
510 }
511 }
512
513private:
514 // The block that originally contained the operations.
515 Block *sourceBlock;
516
517 // The first inlined operation.
518 Operation *firstInlinedInst;
519
520 // The last inlined operation.
521 Operation *lastInlinedInst;
522};
523
524/// Moving of a block. This rewrite is immediately reflected in the IR.
525class MoveBlockRewrite : public BlockRewrite {
526public:
527 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
528 Region *previousRegion, Region::iterator previousIt)
529 : BlockRewrite(Kind::MoveBlock, rewriterImpl, block),
530 region(previousRegion),
531 insertBeforeBlock(previousIt == previousRegion->end() ? nullptr
532 : &*previousIt) {}
533
534 static bool classof(const IRRewrite *rewrite) {
535 return rewrite->getKind() == Kind::MoveBlock;
536 }
537
538 void commit(RewriterBase &rewriter) override {
539 // The block was already moved. Just inform the listener.
540 if (auto *listener = rewriter.getListener()) {
541 // Note: `previousIt` cannot be passed because this is a delayed
542 // notification and iterators into past IR state cannot be represented.
543 listener->notifyBlockInserted(block, /*previous=*/region,
544 /*previousIt=*/{});
545 }
546 }
547
548 void rollback() override {
549 // Move the block back to its original position.
550 Region::iterator before =
551 insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
552 if (Region *currentParent = block->getParent()) {
553 // Block is still in a region, use cheap splice to move it back.
554 region->getBlocks().splice(before, currentParent->getBlocks(), block);
555 return;
556 }
557 // Block was orphaned by a prior rollback, can't splice.
558 region->getBlocks().insert(before, block);
559 }
560
561private:
562 // The region in which this block was previously contained.
563 Region *region;
564
565 // The original successor of this block before it was moved. "nullptr" if
566 // this block was the only block in the region.
567 Block *insertBeforeBlock;
568};
569
570/// Block type conversion. This rewrite is partially reflected in the IR.
571class BlockTypeConversionRewrite : public BlockRewrite {
572public:
573 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
574 Block *origBlock, Block *newBlock)
575 : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
576 newBlock(newBlock) {}
577
578 static bool classof(const IRRewrite *rewrite) {
579 return rewrite->getKind() == Kind::BlockTypeConversion;
580 }
581
582 Block *getOrigBlock() const { return block; }
583
584 Block *getNewBlock() const { return newBlock; }
585
586 void commit(RewriterBase &rewriter) override;
587
588 void rollback() override;
589
590private:
591 /// The new block that was created as part of this signature conversion.
592 Block *newBlock;
593};
594
595/// Replacing a value. This rewrite is not immediately reflected in the
596/// IR. An internal IR mapping is updated, but the actual replacement is delayed
597/// until the rewrite is committed.
598class ReplaceValueRewrite : public ValueRewrite {
599public:
600 ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
601 const TypeConverter *converter)
602 : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
603 converter(converter) {}
604
605 static bool classof(const IRRewrite *rewrite) {
606 return rewrite->getKind() == Kind::ReplaceValue;
607 }
608
609 void commit(RewriterBase &rewriter) override;
610
611 void rollback() override;
612
613private:
614 /// The current type converter when the value was replaced.
615 const TypeConverter *converter;
616};
617
618/// An operation rewrite.
619class OperationRewrite : public IRRewrite {
620public:
621 /// Return the operation that this rewrite operates on.
622 Operation *getOperation() const { return op; }
623
624 static bool classof(const IRRewrite *rewrite) {
625 return rewrite->getKind() >= Kind::MoveOperation &&
626 rewrite->getKind() <= Kind::UnresolvedMaterialization;
627 }
628
629protected:
630 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
631 Operation *op)
632 : IRRewrite(kind, rewriterImpl), op(op) {}
633
634 // The operation that this rewrite operates on.
635 Operation *op;
636};
637
638/// Moving of an operation. This rewrite is immediately reflected in the IR.
639class MoveOperationRewrite : public OperationRewrite {
640public:
641 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
642 Operation *op, OpBuilder::InsertPoint previous)
643 : OperationRewrite(Kind::MoveOperation, rewriterImpl, op),
644 block(previous.getBlock()),
645 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
646 ? nullptr
647 : &*previous.getPoint()) {}
648
649 static bool classof(const IRRewrite *rewrite) {
650 return rewrite->getKind() == Kind::MoveOperation;
651 }
652
653 void commit(RewriterBase &rewriter) override {
654 // The operation was already moved. Just inform the listener.
655 if (auto *listener = rewriter.getListener()) {
656 // Note: `previousIt` cannot be passed because this is a delayed
657 // notification and iterators into past IR state cannot be represented.
658 listener->notifyOperationInserted(
659 op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
660 /*insertPt=*/{}));
661 }
662 }
663
664 void rollback() override {
665 // Move the operation back to its original position.
666 Block::iterator before =
667 insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
668 block->getOperations().splice(before, op->getBlock()->getOperations(), op);
669 }
670
671private:
672 // The block in which this operation was previously contained.
673 Block *block;
674
675 // The original successor of this operation before it was moved. "nullptr"
676 // if this operation was the only operation in the region.
677 Operation *insertBeforeOp;
678};
679
680/// In-place modification of an op. This rewrite is immediately reflected in
681/// the IR. The previous state of the operation is stored in this object.
682class ModifyOperationRewrite : public OperationRewrite {
683public:
684 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
685 Operation *op)
686 : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
687 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
688 operands(op->operand_begin(), op->operand_end()),
689 successors(op->successor_begin(), op->successor_end()) {
690 if (OpaqueProperties prop = op->getPropertiesStorage()) {
691 // Make a copy of the properties.
692 propertiesStorage = operator new(op->getPropertiesStorageSize());
693 OpaqueProperties propCopy(propertiesStorage);
694 name.initOpProperties(propCopy, /*init=*/prop);
695 }
696 }
697
698 static bool classof(const IRRewrite *rewrite) {
699 return rewrite->getKind() == Kind::ModifyOperation;
700 }
701
702 ~ModifyOperationRewrite() override {
703 assert(!propertiesStorage &&
704 "rewrite was neither committed nor rolled back");
705 }
706
707 void commit(RewriterBase &rewriter) override {
708 // Notify the listener that the operation was modified in-place.
709 if (auto *listener =
710 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
711 listener->notifyOperationModified(op);
712
713 if (propertiesStorage) {
714 OpaqueProperties propCopy(propertiesStorage);
715 // Note: The operation may have been erased in the mean time, so
716 // OperationName must be stored in this object.
717 name.destroyOpProperties(propCopy);
718 operator delete(propertiesStorage);
719 propertiesStorage = nullptr;
720 }
721 }
722
723 void rollback() override {
724 op->setLoc(loc);
725 op->setAttrs(attrs);
726 op->setOperands(operands);
727 for (const auto &it : llvm::enumerate(successors))
728 op->setSuccessor(it.value(), it.index());
729 if (propertiesStorage) {
730 OpaqueProperties propCopy(propertiesStorage);
731 op->copyProperties(propCopy);
732 name.destroyOpProperties(propCopy);
733 operator delete(propertiesStorage);
734 propertiesStorage = nullptr;
735 }
736 }
737
738private:
739 OperationName name;
740 LocationAttr loc;
741 DictionaryAttr attrs;
742 SmallVector<Value, 8> operands;
743 SmallVector<Block *, 2> successors;
744 void *propertiesStorage = nullptr;
745};
746
747/// Replacing an operation. Erasing an operation is treated as a special case
748/// with "null" replacements. This rewrite is not immediately reflected in the
749/// IR. An internal IR mapping is updated, but values are not replaced and the
750/// original op is not erased until the rewrite is committed.
751class ReplaceOperationRewrite : public OperationRewrite {
752public:
753 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
754 Operation *op, const TypeConverter *converter)
755 : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
756 converter(converter) {}
757
758 static bool classof(const IRRewrite *rewrite) {
759 return rewrite->getKind() == Kind::ReplaceOperation;
760 }
761
762 void commit(RewriterBase &rewriter) override;
763
764 void rollback() override;
765
766 void cleanup(RewriterBase &rewriter) override;
767
768private:
769 /// An optional type converter that can be used to materialize conversions
770 /// between the new and old values if necessary.
771 const TypeConverter *converter;
772};
773
774class CreateOperationRewrite : public OperationRewrite {
775public:
776 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
777 Operation *op)
778 : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
779
780 static bool classof(const IRRewrite *rewrite) {
781 return rewrite->getKind() == Kind::CreateOperation;
782 }
783
784 void commit(RewriterBase &rewriter) override {
785 // The operation was already created and inserted. Just inform the listener.
786 if (auto *listener = rewriter.getListener())
787 listener->notifyOperationInserted(op, /*previous=*/{});
788 }
789
790 void rollback() override;
791};
792
/// Direction of a materialization. Stored in a 2-bit field of
/// `UnresolvedMaterializationInfo`, so the values must stay small.
enum MaterializationKind {
  /// Materializes a conversion from an illegal type to a legal one.
  Target = 0,

  /// Materializes a conversion from a legal type back to an illegal one.
  Source = 1
};
803
804/// Helper class that stores metadata about an unresolved materialization.
805class UnresolvedMaterializationInfo {
806public:
807 UnresolvedMaterializationInfo() = default;
808 UnresolvedMaterializationInfo(const TypeConverter *converter,
809 MaterializationKind kind, Type originalType)
810 : converterAndKind(converter, kind), originalType(originalType) {}
811
812 /// Return the type converter of this materialization (which may be null).
813 const TypeConverter *getConverter() const {
814 return converterAndKind.getPointer();
815 }
816
817 /// Return the kind of this materialization.
818 MaterializationKind getMaterializationKind() const {
819 return converterAndKind.getInt();
820 }
821
822 /// Return the original type of the SSA value.
823 Type getOriginalType() const { return originalType; }
824
825private:
826 /// The corresponding type converter to use when resolving this
827 /// materialization, and the kind of this materialization.
828 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
829 converterAndKind;
830
831 /// The original type of the SSA value. Only used for target
832 /// materializations.
833 Type originalType;
834};
835
836/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
837/// op. Unresolved materializations fold away or are replaced with
838/// source/target materializations at the end of the dialect conversion.
839class UnresolvedMaterializationRewrite : public OperationRewrite {
840public:
841 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
842 UnrealizedConversionCastOp op,
843 ValueVector mappedValues)
844 : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
845 mappedValues(std::move(mappedValues)) {}
846
847 static bool classof(const IRRewrite *rewrite) {
848 return rewrite->getKind() == Kind::UnresolvedMaterialization;
849 }
850
851 void rollback() override;
852
853 UnrealizedConversionCastOp getOperation() const {
854 return cast<UnrealizedConversionCastOp>(op);
855 }
856
857private:
858 /// The values in the conversion value mapping that are being replaced by the
859 /// results of this unresolved materialization.
860 ValueVector mappedValues;
861};
862} // namespace
863
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Return "true" if `rewrites` contains an operation rewrite of type
/// `RewriteTy` whose operation is `op`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Operation *op) {
  auto targetsOp = [&](auto &rewrite) {
    auto *typed = dyn_cast<RewriteTy>(rewrite.get());
    return typed && typed->getOperation() == op;
  };
  return llvm::any_of(std::forward<R>(rewrites), targetsOp);
}

/// Return "true" if `rewrites` contains a block rewrite of type `RewriteTy`
/// whose block is `block`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Block *block) {
  auto targetsBlock = [&](auto &rewrite) {
    auto *typed = dyn_cast<RewriteTy>(rewrite.get());
    return typed && typed->getBlock() == block;
  };
  return llvm::any_of(std::forward<R>(rewrites), targetsBlock);
}
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
885
886//===----------------------------------------------------------------------===//
887// ConversionPatternRewriterImpl
888//===----------------------------------------------------------------------===//
namespace mlir {
namespace detail {
  explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
                                         const ConversionConfig &config,

  //===--------------------------------------------------------------------===//
  // State Management
  //===--------------------------------------------------------------------===//

  /// Return the current state of the rewriter.
  RewriterState getCurrentState();

  /// Apply all requested operation rewrites. This method is invoked when the
  /// conversion process succeeds.
  void applyRewrites();

  /// Reset the state of the rewriter to a previously saved point. Optionally,
  /// the name of the pattern that triggered the rollback can be specified for
  /// debugging purposes.
  void resetState(RewriterState state, StringRef patternName = "");

  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
  /// failure. Appending is only allowed if `allowPatternRollback` is set.
  template <typename RewriteTy, typename... Args>
  void appendRewrite(Args &&...args) {
    assert(config.allowPatternRollback && "appending rewrites is not allowed");
    rewrites.push_back(
        std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
  }

  /// Undo the rewrites (motions, splits) one by one in reverse order until
  /// "numRewritesToKeep" rewrites remain. Optionally, the name of the pattern
  /// that triggered the rollback can be specified for debugging purposes.
  void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");

  /// Remap the given values to those with potentially different types. Returns
  /// success if the values could be remapped, failure otherwise. `valueDiagTag`
  /// is the tag used when describing a value within a diagnostic, e.g.
  /// "operand".
  LogicalResult remapValues(StringRef valueDiagTag,
                            std::optional<Location> inputLoc, ValueRange values,
                            SmallVector<ValueVector> &remapped);

  /// Return "true" if the given operation is ignored, and does not need to be
  /// converted.
  bool isOpIgnored(Operation *op) const;

  /// Return "true" if the given operation was replaced or erased.
  bool wasOpReplaced(Operation *op) const;

  /// Lookup the most recently mapped values with the desired types in the
  /// mapping, taking into account only replacements. Perform a best-effort
  /// search for existing materializations with the desired types.
  ///
  /// If `skipPureTypeConversions` is "true", materializations that are pure
  /// type conversions are not considered.
  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
                              bool skipPureTypeConversions = false) const;

  /// Lookup the given value within the map. Return an empty vector if the
  /// value is not mapped. Also return an empty vector if `desiredTypes` is
  /// non-empty and no mapped values with exactly those types were found.
  /// Otherwise, this follows the same behavior as `lookupOrDefault`.
  ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;

  //===--------------------------------------------------------------------===//
  // IR Rewrites / Type Conversion
  //===--------------------------------------------------------------------===//

  /// Convert the types of block arguments within the given region. The
  /// signatures of non-entry blocks are computed with `converter`. The entry
  /// block uses `entryConversion` if provided. Returns the (possibly new)
  /// entry block, nullptr for an empty region, or failure if a block
  /// signature could not be converted.
  FailureOr<Block *>
  convertRegionTypes(Region *region, const TypeConverter &converter,
                     TypeConverter::SignatureConversion *entryConversion);

  /// Apply the given signature conversion on the given block. The new block
  /// containing the updated signature is returned. If no conversions were
  /// necessary, e.g. if the block has no arguments, `block` is returned.
  /// `converter` is used to generate any necessary cast operations that
  /// translate between the origin argument types and those specified in the
  /// signature conversion.
  Block *applySignatureConversion(
      Block *block, const TypeConverter *converter,
      TypeConverter::SignatureConversion &signatureConversion);

  /// Replace the results of the given operation with the given values and
  /// erase the operation.
  ///
  /// There can be multiple replacement values for each result (1:N
  /// replacement). If the replacement values are empty, the respective result
  /// is dropped and a source materialization is built if the result still has
  /// uses.
  void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);

  /// Replace the uses of the given value with the given values. The specified
  /// converter is used to build materializations (if necessary). If `functor`
  /// is specified, only the uses that the functor returns "true" for are
  /// replaced.
  void replaceValueUses(Value from, ValueRange to,
                        const TypeConverter *converter,
                        function_ref<bool(OpOperand &)> functor = nullptr);

  /// Erase the given block and its contents.
  void eraseBlock(Block *block);

  /// Inline the source block into the destination block before the given
  /// iterator.
  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before);

  //===--------------------------------------------------------------------===//
  // Materializations
  //===--------------------------------------------------------------------===//

  /// Build an unresolved materialization operation (an
  /// unrealized_conversion_cast) that converts `inputs` to values of
  /// `outputTypes`, and return its results. It is invalid to call this
  /// function if the input types already match the output types (this is
  /// asserted). `originalType` may only be set for target materializations.
  ///
  /// If `valuesToMap` is non-empty and pattern rollback is enabled, those
  /// values are mapped to the results of the unresolved materialization in the
  /// conversion value mapping.
  ///
  /// If `isPureTypeConversion` is "true", the materialization is created only
  /// to resolve a type mismatch. That means it is not a regular value
  /// replacement issued by the user. (Replacement values that are created
  /// "out of thin air" appear like unresolved materializations because they are
  /// unrealized_conversion_cast ops. However, they must be treated like
  /// regular value replacements.)
  ValueRange buildUnresolvedMaterialization(
      MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
      ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
      Type originalType, const TypeConverter *converter,
      bool isPureTypeConversion = true);

  /// Find a replacement value for the given SSA value in the conversion value
  /// mapping. The replacement value must have the same type as the given SSA
  /// value. If there is no replacement value with the correct type, find the
  /// latest replacement value (regardless of the type) and build a source
  /// materialization. Only valid if pattern rollback is enabled (asserted).
  Value findOrBuildReplacementValue(Value value,
                                    const TypeConverter *converter);

  //===--------------------------------------------------------------------===//
  // Rewriter Notification Hooks
  //===--------------------------------------------------------------------===//

  /// Notifies that an op was inserted.
  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint previous) override;

  /// Notifies that a block was inserted.
  void notifyBlockInserted(Block *block, Region *previous,
                           Region::iterator previousIt) override;

  /// Notifies that a pattern match failed for the given reason.
  void
  notifyMatchFailure(Location loc,
                     function_ref<void(Diagnostic &)> reasonCallback) override;

  //===--------------------------------------------------------------------===//
  // IR Erasure
  //===--------------------------------------------------------------------===//

  /// A rewriter that keeps track of erased ops and blocks. It ensures that no
  /// operation or block is erased multiple times. This rewriter assumes that
  /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
  public:
        std::function<void(Operation *)> opErasedCallback = nullptr)
        : RewriterBase(context, /*listener=*/this),
          opErasedCallback(std::move(opErasedCallback)) {}

    /// Erase the given op (unless it was already erased).
    void eraseOp(Operation *op) override {
      if (wasErased(op))
        return;
      op->dropAllUses();
    }

    /// Erase the given block (unless it was already erased).
    void eraseBlock(Block *block) override {
      if (wasErased(block))
        return;
      assert(block->empty() && "expected empty block");
      block->dropAllDefinedValueUses();
    }

    /// Return "true" if the given operation or block (passed as an opaque
    /// pointer) was already erased by this rewriter.
    bool wasErased(void *ptr) const { return erased.contains(ptr); }

      erased.insert(op);
      if (opErasedCallback)
        opErasedCallback(op);
    }

    void notifyBlockErased(Block *block) override { erased.insert(block); }

  private:
    /// Pointers to all erased operations and blocks.
    DenseSet<void *> erased;

    /// A callback that is invoked when an operation is erased.
    std::function<void(Operation *)> opErasedCallback;
  };

  //===--------------------------------------------------------------------===//
  // State
  //===--------------------------------------------------------------------===//

  /// The rewriter that is used to perform the conversion.
  ConversionPatternRewriter &rewriter;

  /// Mapping between replaced values that differ in type. This happens when
  /// replacing a value with one of a different type.
  ConversionValueMapping mapping;

  /// Ordered list of IR rewrites (e.g., block creations, splits, motions).
  /// This vector is maintained only if `allowPatternRollback` is set to
  /// "true". Otherwise, all IR rewrites are materialized immediately and no
  /// bookkeeping is needed.

  /// A set of operations that should no longer be considered for legalization.
  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
  /// tracked separately.

  /// A set of operations that were replaced/erased. Such ops are not erased
  /// immediately but only when the dialect conversion succeeds. In the
  /// meantime, they should no longer be considered for legalization and any
  /// attempt to modify/access them is invalid rewriter API usage.

  /// A set of operations that were created by the current pattern.

  /// A set of operations that were modified by the current pattern.

  /// A list of unresolved materializations that were created by the current
  /// pattern.

  /// A mapping for looking up metadata of unresolved materializations.

  /// The current type converter, or nullptr if no type converter is currently
  /// active.

  /// A mapping of regions to type converters that should be used when
  /// converting the arguments of blocks within that region.

  /// Dialect conversion configuration.
  const ConversionConfig &config;

  /// The operation converter to use for recursive legalization.

  /// A set of erased operations. This set is utilized only if
  /// `allowPatternRollback` is set to "false". Conceptually, this set is
  /// similar to `replacedOps` (which is maintained when the flag is set to
  /// "true"). However, erasing from a DenseSet is more efficient than erasing
  /// from a SetVector.

  /// A set of erased blocks. This set is utilized only if
  /// `allowPatternRollback` is set to "false".

  /// A rewriter that notifies the listener (if any) about all IR
  /// modifications. This rewriter is utilized only if `allowPatternRollback`
  /// is set to "false". If the flag is set to "true", the listener is notified
  /// with a separate mechanism (e.g., in `IRRewrite::commit`).

#ifndef NDEBUG
  /// A set of replaced values. This set is for debugging purposes only and it
  /// is maintained only if `allowPatternRollback` is set to "true".

  /// A set of operations that have pending updates. This tracking isn't
  /// strictly necessary, and is thus only active during debug builds for extra
  /// verification.

  /// A raw output stream used to prefix the debug log.
  llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
                                  llvm::dbgs()};

  /// A logger used to emit diagnostics during the conversion process.
  llvm::ScopedPrinter logger{os};
  std::string logPrefix;
#endif
};
} // namespace detail
} // namespace mlir
1196
1197const ConversionConfig &IRRewrite::getConfig() const {
1198 return rewriterImpl.config;
1199}
1200
1201void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1202 // Inform the listener about all IR modifications that have already taken
1203 // place: References to the original block have been replaced with the new
1204 // block.
1205 if (auto *listener =
1206 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1207 for (Operation *op : getNewBlock()->getUsers())
1208 listener->notifyOperationModified(op);
1209}
1210
1211void BlockTypeConversionRewrite::rollback() {
1212 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1213}
1214
/// Replace the uses of `from` with `repl`. If `functor` is provided, only the
/// uses for which it returns "true" are considered.
///
/// If `repl` is a block argument, all (functor-approved) uses are replaced.
/// Otherwise, uses located in the block of `repl`'s defining op that do not
/// come after that op (including uses by the op itself) are left untouched.
static void
                    function_ref<bool(OpOperand &)> functor = nullptr) {
  if (isa<BlockArgument>(repl)) {
    // `repl` is a block argument. Directly replace all uses.
    if (functor) {
      rewriter.replaceUsesWithIf(from, repl, functor);
    } else {
      rewriter.replaceAllUsesWith(from, repl);
    }
    return;
  }

  // If the replacement value is an operation, only replace those uses that:
  // - are in a different block than the replacement operation, or
  // - are in the same block but after the replacement operation.
  //
  // Example:
  // ^bb0(%arg0: i32):
  //   %0 = "consumer"(%arg0) : (i32) -> (i32)
  //   "another_consumer"(%arg0) : (i32) -> ()
  //
  // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
  // use in "another_consumer" but not the use in "consumer". When using the
  // normal RewriterBase API, this would typically be done with
  // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
  // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
  // it cannot be supported efficiently with `allowPatternRollback` set to
  // "true". Therefore, the conversion driver is trying to be smart and replaces
  // only those uses that do not lead to a dominance violation. E.g., the
  // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
  // behavior.
  //
  // NOTE(review): a use nested inside a region of an op that precedes `replOp`
  // in `replBlock` has `user->getBlock() != replBlock` and is therefore
  // replaced, which may not be dominated by `repl`. Confirm whether callers
  // can hit this case.
  //
  // TODO: As we move more and more towards `allowPatternRollback` set to
  // "false", we should remove this special handling, in order to align the
  // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
  Operation *replOp = repl.getDefiningOp();
  Block *replBlock = replOp->getBlock();
  rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
    Operation *user = operand.getOwner();
    bool result =
        user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
    // The user-provided filter is only consulted for dominance-safe uses.
    if (result && functor)
      result &= functor(operand);
    return result;
  });
}
1263
1264void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1265 Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
1266 if (!repl)
1267 return;
1268 performReplaceValue(rewriter, value, repl);
1269}
1270
1271void ReplaceValueRewrite::rollback() {
1272 rewriterImpl.mapping.erase({value});
1273#ifndef NDEBUG
1274 rewriterImpl.replacedValues.erase(value);
1275#endif // NDEBUG
1276}
1277
1278void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1279 auto *listener =
1280 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1281
1282 // Compute replacement values.
1283 SmallVector<Value> replacements =
1284 llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1285 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1286 });
1287
1288 // Notify the listener that the operation is about to be replaced.
1289 if (listener)
1290 listener->notifyOperationReplaced(op, replacements);
1291
1292 // Replace all uses with the new values.
1293 for (auto [result, newValue] :
1294 llvm::zip_equal(op->getResults(), replacements))
1295 if (newValue)
1296 rewriter.replaceAllUsesWith(result, newValue);
1297
1298 // The original op will be erased, so remove it from the set of unlegalized
1299 // ops.
1300 if (getConfig().unlegalizedOps)
1301 getConfig().unlegalizedOps->erase(op);
1302
1303 // Notify the listener that the operation and its contents are being erased.
1304 if (listener)
1305 notifyIRErased(listener, *op);
1306
1307 // Do not erase the operation yet. It may still be referenced in `mapping`.
1308 // Just unlink it for now and erase it during cleanup.
1309 op->getBlock()->getOperations().remove(op);
1310}
1311
1312void ReplaceOperationRewrite::rollback() {
1313 for (auto result : op->getResults())
1314 rewriterImpl.mapping.erase({result});
1315}
1316
1317void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1318 rewriter.eraseOp(op);
1319}
1320
1321void CreateOperationRewrite::rollback() {
1322 for (Region &region : op->getRegions()) {
1323 while (!region.getBlocks().empty())
1324 region.getBlocks().remove(region.getBlocks().begin());
1325 }
1326 op->dropAllUses();
1327 op->erase();
1328}
1329
1330void UnresolvedMaterializationRewrite::rollback() {
1331 if (!mappedValues.empty())
1332 rewriterImpl.mapping.erase(mappedValues);
1333 rewriterImpl.unresolvedMaterializations.erase(getOperation());
1334 op->erase();
1335}
1336
  // This runs in two phases: (1) commit every recorded rewrite, and then
  // (2) let every rewrite clean up, e.g., erase ops that were only unlinked
  // during commit.
  //
  // Commit all rewrites. Use a new rewriter, so the modifications are not
  // tracked for rollback purposes etc.
  IRRewriter irRewriter(rewriter.getContext(), config.listener);
  // Note: New rewrites may be added during the "commit" phase and the
  // `rewrites` vector may reallocate. This is why an index-based loop that
  // re-reads `rewrites.size()` is used instead of a range-based for loop.
  for (size_t i = 0; i < rewrites.size(); ++i)
    rewrites[i]->commit(irRewriter);

  // Clean up all rewrites. Erased unrealized_conversion_cast ops are also
  // removed from `unresolvedMaterializations` to keep that map in sync.
  SingleEraseRewriter eraseRewriter(
      rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
        if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
          unresolvedMaterializations.erase(castOp);
      });
  for (auto &rewrite : rewrites)
    rewrite->cleanup(eraseRewriter);
}
1355
1356//===----------------------------------------------------------------------===//
1357// State Management
1358//===----------------------------------------------------------------------===//
1359
    Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
  // Helper function that performs a single lookup step for `values`. Returns
  // an empty vector if `values` is not mapped.
  auto lookup = [&](const ValueVector &values) -> ValueVector {
    assert(!values.empty() && "expected non-empty value vector");

    // If the pattern rollback is enabled, use the mapping to look up the
    // values.
    if (config.allowPatternRollback)
      return mapping.lookup(values);

    // Otherwise, look up values by examining the IR. All replacements have
    // already been materialized in IR. Only an unresolved materialization
    // (registered in `unresolvedMaterializations`) whose results are exactly
    // `values` counts as a mapping; its inputs are the mapped values.
    Operation *op = getCommonDefiningOp(values);
    if (!op)
      return {};
    auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
    if (!castOp)
      return {};
    if (!this->unresolvedMaterializations.contains(castOp))
      return {};
    if (castOp.getOutputs() != values)
      return {};
    return castOp.getInputs();
  };

  // Helper function that looks up each value in `values` individually and then
  // composes the results. If that fails, it tries to look up the entire vector
  // at once.
  auto composedLookup = [&](const ValueVector &values) -> ValueVector {
    // If possible, replace each value with (one or multiple) mapped values.
    ValueVector next;
    for (Value v : values) {
      ValueVector r = lookup({v});
      if (!r.empty()) {
        llvm::append_range(next, r);
      } else {
        next.push_back(v);
      }
    }
    if (next != values) {
      // At least one value was replaced.
      return next;
    }

    // Otherwise: Check if there is a mapping for the entire vector. Such
    // mappings are materializations. (N:M mapping are not supported for value
    // replacements.)
    //
    // Note: From a correctness point of view, materializations do not have to
    // be stored (and looked up) in the mapping. But for performance reasons,
    // we choose to reuse existing IR (when possible) instead of creating it
    // multiple times.
    ValueVector r = lookup(values);
    if (r.empty()) {
      // No mapping found: The lookup stops here.
      return {};
    }
    return r;
  };

  // Try to find the deepest values that have the desired types. If there is no
  // such mapping, simply return the deepest values.
  // NOTE(review): the loop follows the mapping chain until a lookup fails; it
  // assumes that the mapping contains no cycles.
  ValueVector desiredValue;
  ValueVector current{from};
  ValueVector lastNonMaterialization{from};
  do {
    // Store the current value if the types match.
    bool match = TypeRange(ValueRange(current)) == desiredTypes;
    if (skipPureTypeConversions) {
      // Skip pure type conversions, if requested.
      bool pureConversion = isPureTypeConversion(current);
      match &= !pureConversion;
      // Keep track of the last mapped value that was not a pure type
      // conversion.
      if (!pureConversion)
        lastNonMaterialization = current;
    }
    if (match)
      desiredValue = current;

    // Lookup next value in the mapping.
    ValueVector next = composedLookup(current);
    if (next.empty())
      break;
    current = std::move(next);
  } while (true);

  // If the desired values were found use them, otherwise default to the leaf
  // values. (Skip pure type conversions, if requested.) Note that with a
  // non-empty `desiredTypes`, the result is empty if no values of those types
  // were encountered along the chain.
  if (!desiredTypes.empty())
    return desiredValue;
  if (skipPureTypeConversions)
    return lastNonMaterialization;
  return current;
}
1456
    TypeRange desiredTypes) const {
  ValueVector result = lookupOrDefault(from, desiredTypes);
  // A result equal to `{from}` means that no (other) mapped values were
  // found. A result whose types differ from the requested `desiredTypes` is
  // treated as "not found" as well.
  if (result == ValueVector{from} ||
      (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
    return {};
  return result;
}
1466
  // Snapshot the sizes of the containers that `resetState` later truncates
  // back to these sizes.
  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
}
1470
                                                 StringRef patternName) {
  // Undo any rewrites.
  undoRewrites(state.numRewrites, patternName);

  // Pop all of the recorded ignored operations that are no longer valid.
  // NOTE(review): these loops assume the containers only grew since `state`
  // was taken; otherwise they would pop past the saved size.
  while (ignoredOps.size() != state.numIgnoredOperations)
    ignoredOps.pop_back();

  // Likewise, forget the ops that were marked as replaced/erased after the
  // snapshot was taken.
  while (replacedOps.size() != state.numReplacedOps)
    replacedOps.pop_back();
}
1483
1484void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1485 StringRef patternName) {
1486 for (auto &rewrite :
1487 llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1488 rewrite->rollback();
1489 rewrites.resize(numRewritesToKeep);
1490}
1491
    StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
    SmallVector<ValueVector> &remapped) {
  // On success, exactly one entry is appended to `remapped` per element of
  // `values`. The entry is an empty vector if the type converts to zero types.
  remapped.reserve(llvm::size(values));

  for (const auto &it : llvm::enumerate(values)) {
    Value operand = it.value();
    Type origType = operand.getType();
    Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();

    if (!currentTypeConverter) {
      // The current pattern does not have a type converter. Pass the most
      // recently mapped values, excluding materializations. Materializations
      // are intentionally excluded because their presence may depend on other
      // patterns. Including materializations would make the lookup fragile
      // and unpredictable.
      remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
                                         /*skipPureTypeConversions=*/true));
      continue;
    }

    // If there is no legal conversion, fail to match this pattern.
    SmallVector<Type, 1> legalTypes;
    if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
      notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
        diag << "unable to convert type for " << valueDiagTag << " #"
             << it.index() << ", type was " << origType;
      });
      return failure();
    }
    // If a type is converted to 0 types, there is nothing to do.
    if (legalTypes.empty()) {
      remapped.push_back({});
      continue;
    }

    ValueVector repl = lookupOrDefault(operand, legalTypes);
    if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
      // Mapped values have the correct type or there is an existing
      // materialization. Or the operand is not mapped at all and has the
      // correct type.
      remapped.push_back(std::move(repl));
      continue;
    }

    // Create a materialization for the most recently mapped values.
    repl = lookupOrDefault(operand, /*desiredTypes=*/{},
                           /*skipPureTypeConversions=*/true);
        MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
        /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
        /*originalType=*/origType, currentTypeConverter);
    remapped.push_back(castValues);
  }
  return success();
}
1548
  // Check to see if this operation is ignored or was replaced. Replaced/erased
  // ops are checked first; `ignoredOps` holds ops that are excluded from
  // legalization for other reasons (e.g., recursively legal ops).
  return wasOpReplaced(op) || ignoredOps.count(op);
}
1553
  // Check to see if this operation was replaced. Replaced/erased ops are
  // tracked in `replacedOps` when pattern rollback is enabled and in
  // `erasedOps` otherwise, so both sets are consulted.
  return replacedOps.count(op) || erasedOps.count(op);
}
1558
1559//===----------------------------------------------------------------------===//
1560// Type Conversion
1561//===----------------------------------------------------------------------===//
1562
    Region *region, const TypeConverter &converter,
    TypeConverter::SignatureConversion *entryConversion) {
  // Record `converter` as the type converter for the blocks of this region.
  regionToConverter[region] = &converter;
  // An empty region has no entry block: succeed with a null block.
  if (region->empty())
    return nullptr;

  // Convert the arguments of each non-entry block within the region. An
  // early-increment range is used because applySignatureConversion may insert
  // a new block right after `block` and erase `block`.
  for (Block &block :
       llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
    // Compute the signature for the block with the provided converter.
    std::optional<TypeConverter::SignatureConversion> conversion =
        converter.convertBlockSignature(&block);
    if (!conversion)
      return failure();
    // Convert the block with the computed signature.
    applySignatureConversion(&block, &converter, *conversion);
  }

  // Convert the entry block. If an entry signature conversion was provided,
  // use that one. Otherwise, compute the signature with the type converter.
  // The (possibly new) entry block is returned.
  if (entryConversion)
    return applySignatureConversion(&region->front(), &converter,
                                    *entryConversion);
  std::optional<TypeConverter::SignatureConversion> conversion =
      converter.convertBlockSignature(&region->front());
  if (!conversion)
    return failure();
  return applySignatureConversion(&region->front(), &converter, *conversion);
}
1593
    Block *block, const TypeConverter *converter,
    TypeConverter::SignatureConversion &signatureConversion) {
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
  // A block cannot be converted multiple times.
  if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
    llvm::reportFatalInternalError("block was already converted");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS


  // If no arguments are being changed or added, there is nothing to do.
  unsigned origArgCount = block->getNumArguments();
  auto convertedTypes = signatureConversion.getConvertedTypes();
  if (llvm::equal(block->getArgumentTypes(), convertedTypes))
    return block;

  // Compute the locations of all block arguments in the new block. New
  // arguments that do not stem from an original argument keep an unknown
  // location.
  SmallVector<Location> newLocs(convertedTypes.size(),
                                rewriter.getUnknownLoc());
  for (unsigned i = 0; i < origArgCount; ++i) {
    auto inputMap = signatureConversion.getInputMapping(i);
    if (!inputMap || inputMap->replacedWithValues())
      continue;
    Location origLoc = block->getArgument(i).getLoc();
    for (unsigned j = 0; j < inputMap->size; ++j)
      newLocs[inputMap->inputNo + j] = origLoc;
  }

  // Insert a new block with the converted block argument types and move all ops
  // from the old block to the new block.
  Block *newBlock =
      rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
                           convertedTypes, newLocs);

  // If a listener is attached to the dialect conversion, ops cannot be moved
  // to the destination block in bulk ("fast path"). This is because at the time
  // the notifications are sent, it is unknown which ops were moved. Instead,
  // ops should be moved one-by-one ("slow path"), so that a separate
  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
  // a bit more efficient, so we try to do that when possible.
  bool fastPath = !config.listener;
  if (fastPath) {
    if (config.allowPatternRollback)
      appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
    newBlock->getOperations().splice(newBlock->end(), block->getOperations());
  } else {
    while (!block->empty())
      rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
  }

  // Replace all uses of the old block with the new block. (In rollback mode,
  // `BlockTypeConversionRewrite::rollback` redirects them back.)
  block->replaceAllUsesWith(newBlock);

  // Replace each original block argument. There are three cases: the argument
  // was dropped without replacement, it was replaced with explicit values, or
  // it maps to one or more arguments of the new block.
  for (unsigned i = 0; i != origArgCount; ++i) {
    BlockArgument origArg = block->getArgument(i);
    Type origArgType = origArg.getType();

    std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
        signatureConversion.getInputMapping(i);
    if (!inputMap) {
      // This block argument was dropped and no replacement value was provided.
      // Materialize a replacement value "out of thin air".
      // Note: Materialization must be built here because we cannot find a
      // valid insertion point in the new block. (Will point to the old block.)
      Value mat =
              MaterializationKind::Source,
              OpBuilder::InsertPoint(newBlock, newBlock->begin()),
              origArg.getLoc(),
              /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
              /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
              /*isPureTypeConversion=*/false)
              .front();
      replaceValueUses(origArg, mat, converter);
      continue;
    }

    if (inputMap->replacedWithValues()) {
      // This block argument was dropped and replacement values were provided.
      assert(inputMap->size == 0 &&
             "invalid to provide a replacement value when the argument isn't "
             "dropped");
      replaceValueUses(origArg, inputMap->replacementValues, converter);
      continue;
    }

    // This is a 1->1+ mapping.
    auto replArgs =
        newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
    replaceValueUses(origArg, replArgs, converter);
  }

  if (config.allowPatternRollback)
    appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);

  // Erase the old block. (It is just unlinked for now and will be erased during
  // cleanup.)
  rewriter.eraseBlock(block);

  return newBlock;
}
1696
1697//===----------------------------------------------------------------------===//
1698// Materializations
1699//===----------------------------------------------------------------------===//
1700
/// Build an unresolved materialization operation (an
/// unrealized_conversion_cast) that converts `inputs` to values of
/// `outputTypes`, register its metadata, and return its results.
///
/// In rollback mode, `valuesToMap` (if non-empty) is mapped to the results in
/// the conversion value mapping. Otherwise, the cast is added to
/// `patternMaterializations`.
    MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
    ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
    Type originalType, const TypeConverter *converter,
    bool isPureTypeConversion) {
  assert((!originalType || kind == MaterializationKind::Target) &&
         "original type is valid only for target materializations");
  assert(TypeRange(inputs) != outputTypes &&
         "materialization is not necessary");

  // Create an unresolved materialization. We use a new OpBuilder to avoid
  // tracking the materialization like we do for other operations.
  // NOTE(review): `outputTypes.front()` requires `outputTypes` to be
  // non-empty, which the asserts above do not guarantee (empty outputs with
  // non-empty inputs pass them). `loc->getContext()` would not have this
  // requirement.
  OpBuilder builder(outputTypes.front().getContext());
  builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
  UnrealizedConversionCastOp convertOp =
      UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
  if (config.attachDebugMaterializationKind) {
    StringRef kindStr =
        kind == MaterializationKind::Source ? "source" : "target";
    convertOp->setAttr("__kind__", builder.getStringAttr(kindStr));
  }
    convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());

  // Register the materialization.
  unresolvedMaterializations[convertOp] =
      UnresolvedMaterializationInfo(converter, kind, originalType);
  if (config.allowPatternRollback) {
    if (!valuesToMap.empty())
      mapping.map(valuesToMap, convertOp.getResults());
                                                    std::move(valuesToMap));
  } else {
    patternMaterializations.insert(convertOp);
  }
  return convertOp.getResults();
}
1740
// Returns a value of `value`'s type that can stand in for `value`. This path
// is valid only in rollback mode (asserted below). Candidates are tried in
// this order:
//   1. An existing mapped value of the same type (including cached
//      materializations).
//   2. A null Value, if all users of `value` are replaced ops and nothing is
//      mapped to `value`.
//   3. A new source materialization built from the latest mapped values, or
//      from no inputs if there are none.
// NOTE(review): this listing omits source lines 1741 (start of the signature),
// 1768 (presumably the declaration of `ip`) and 1782 (presumably the start of
// the `buildUnresolvedMaterialization` call whose arguments follow).
1742 Value value, const TypeConverter *converter) {
1743 assert(config.allowPatternRollback &&
1744 "this code path is valid only in rollback mode");
1745
1746 // Try to find a replacement value with the same type in the conversion value
1747 // mapping. This includes cached materializations. We try to reuse those
1748 // instead of generating duplicate IR.
1749 ValueVector repl = lookupOrNull(value, value.getType());
1750 if (!repl.empty())
1751 return repl.front();
1752
1753 // Check if the value is dead. No replacement value is needed in that case.
1754 // This is an approximate check that may have false negatives but does not
1755 // require computing and traversing an inverse mapping. (We may end up
1756 // building source materializations that are never used and that fold away.)
1757 if (llvm::all_of(value.getUsers(),
1758 [&](Operation *op) { return replacedOps.contains(op); }) &&
1759 !mapping.isMappedTo(value))
1760 return Value();
1761
1762 // No replacement value was found. Get the latest replacement value
1763 // (regardless of the type) and build a source materialization to the
1764 // original type.
1765 repl = lookupOrNull(value);
1766
1767 // Compute the insertion point of the materialization.
1769 if (repl.empty()) {
1770 // The source materialization has no inputs. Insert it right before the
1771 // value that it is replacing.
1772 ip = computeInsertPoint(value);
1773 } else {
1774 // Compute the "earliest" insertion point at which all values in `repl` are
1775 // defined. It is important to emit the materialization at that location
1776 // because the same materialization may be reused in a different context.
1777 // (That's because materializations are cached in the conversion value
1778 // mapping.) The insertion point of the materialization must be valid for
1779 // all future users that may be created later in the conversion process.
1780 ip = computeInsertPoint(repl);
1781 }
// `repl` is passed as both the values to map and the inputs, which caches the
// new materialization in the mapping for later reuse.
1783 MaterializationKind::Source, ip, value.getLoc(),
1784 /*valuesToMap=*/repl, /*inputs=*/repl,
1785 /*outputTypes=*/value.getType(),
1786 /*originalType=*/Type(), converter,
1787 /*isPureTypeConversion=*/!repl.empty())
1788 .front();
1789 return castValue;
1790}
1791
1792//===----------------------------------------------------------------------===//
1793// Rewriter Notification Hooks
1794//===----------------------------------------------------------------------===//
1795
// Rewriter hook called when `op` is inserted; `previous` is unset if the op
// was detached before. The hook:
//   - logs the insertion;
//   - in "no rollback" mode, forwards the notification to the listener
//     immediately;
//   - adds previously detached ops to `patternNewOps` so they get legalized;
//   - in "no rollback" mode, keeps `erasedOps` up to date.
// NOTE(review): this listing omits source lines 1796 (start of the signature),
// 1828 and 1839 (presumably the rollback bookkeeping for the inserted and the
// moved op, respectively).
1797 Operation *op, OpBuilder::InsertPoint previous) {
1798 // If no previous insertion point is provided, the op used to be detached.
1799 bool wasDetached = !previous.isSet();
1800 LLVM_DEBUG({
1801 logger.startLine() << "** Insert : '" << op->getName() << "' (" << op
1802 << ")";
1803 if (wasDetached)
1804 logger.getOStream() << " (was detached)";
1805 logger.getOStream() << "\n";
1806 });
1807
1808 // In rollback mode, it is easier to misuse the API, so perform extra error
1809 // checking.
1810 assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
1811 "attempting to insert into a block within a replaced/erased op");
1812
1813 // In "no rollback" mode, the listener is always notified immediately.
1814 if (!config.allowPatternRollback && config.listener)
1815 config.listener->notifyOperationInserted(op, previous);
1816
1817 if (wasDetached) {
1818 // If the op was detached, it is most likely a newly created op. Add it to the
1819 // set of newly created ops, so that it will be legalized. If this op is
1820 // not a newly created op, it will be legalized a second time, which is
1821 // inefficient but harmless.
1822 patternNewOps.insert(op);
1823
1824 if (config.allowPatternRollback) {
1825 // TODO: If the same op is inserted multiple times from a detached
1826 // state, the rollback mechanism may erase the same op multiple times.
1827 // This is a bug in the rollback-based dialect conversion driver.
1829 } else {
1830 // In "no rollback" mode, there is an extra data structure for tracking
1831 // erased operations that must be kept up to date.
1832 erasedOps.erase(op);
1833 }
1834 return;
1835 }
1836
1837 // The op was moved from one place to another.
1838 if (config.allowPatternRollback)
1840}
1841
1842/// Given that `fromRange` is about to be replaced with `toRange`, compute
1843/// replacement values with the types of `fromRange`.
///
/// One value is produced per element of `fromRange`:
/// - a null Value if the replaced value has no uses;
/// - a source materialization with no inputs if its replacement list is empty;
/// - the replacement itself if it is a single value of the same type;
/// - otherwise, a source materialization from the replacement values.
/// Only valid in "no rollback" mode (asserted).
// NOTE(review): this listing omits source line 1845 (the function name and its
// leading parameters, presumably `impl` and `fromRange`).
1844static SmallVector<Value>
1846 const SmallVector<SmallVector<Value>> &toRange,
1847 const TypeConverter *converter) {
1848 assert(!impl.config.allowPatternRollback &&
1849 "this code path is valid only in 'no rollback' mode");
1850 SmallVector<Value> repls;
1851 for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1852 if (from.use_empty()) {
1853 // The replaced value is dead. No replacement value is needed.
1854 repls.push_back(Value());
1855 continue;
1856 }
1857
1858 if (to.empty()) {
1859 // The replaced value is dropped. Materialize a replacement value "out of
1860 // thin air".
1861 Value srcMat = impl.buildUnresolvedMaterialization(
1862 MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
1863 /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1864 /*outputTypes=*/from.getType(), /*originalType=*/Type(),
1865 converter)[0];
1866 repls.push_back(srcMat);
1867 continue;
1868 }
1869
1870 if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
1871 // The replacement value already has the correct type. Use it directly.
1872 repls.push_back(to[0]);
1873 continue;
1874 }
1875
1876 // The replacement value has the wrong type. Build a source materialization
1877 // to the original type.
// (This also covers 1:N replacements, i.e. `to` holding more than one value.)
1878 // TODO: This is a bit inefficient. We should try to reuse existing
1879 // materializations if possible. This would require an extension of the
1880 // `lookupOrDefault` API.
1881 Value srcMat = impl.buildUnresolvedMaterialization(
1882 MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
1883 /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
1884 /*originalType=*/Type(), converter)[0];
1885 repls.push_back(srcMat);
1886 }
1887
1888 return repls;
1889}
1890
// Replaces the results of `op` with `newValues`, which holds one vector of
// values per result.
// - "No rollback" mode: replacement values of the original types are computed
//   and the op is replaced immediately through the notifying rewriter. The
//   erased ops and blocks are first removed from, or recorded in, the
//   internal bookkeeping sets.
// - Rollback mode: each result is only mapped to its new values, and `op` plus
//   all nested ops are marked as replaced.
// NOTE(review): this listing omits source lines 1891 (start of the signature),
// 1898 (presumably a guard around the type-converter note below, since that
// code dereferences `currentTypeConverter`), 1936 (presumably the start of the
// `getReplacementValues` call that defines `repls`) and 1979.
1892 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1893 assert(newValues.size() == op->getNumResults() &&
1894 "incorrect number of replacement values");
1895 LLVM_DEBUG({
1896 logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
1897 << ")\n";
1899 // If the user-provided replacement types are different from the
1900 // legalized types, as per the current type converter, print a note.
1901 // In most cases, the replacement types are expected to match the types
1902 // produced by the type converter, so this could indicate a bug in the
1903 // user code.
1904 for (auto [result, repls] :
1905 llvm::zip_equal(op->getResults(), newValues)) {
1906 Type resultType = result.getType();
1907 auto logProlog = [&, repls = repls]() {
1908 logger.startLine() << " Note: Replacing op result of type "
1909 << resultType << " with value(s) of type (";
1910 llvm::interleaveComma(repls, logger.getOStream(), [&](Value v) {
1911 logger.getOStream() << v.getType();
1912 });
1913 logger.getOStream() << ")";
1914 };
1915 SmallVector<Type> convertedTypes;
1916 if (failed(currentTypeConverter->convertTypes(resultType,
1917 convertedTypes))) {
1918 logProlog();
1919 logger.getOStream() << ", but the type converter failed to legalize "
1920 "the original type.\n";
1921 continue;
1922 }
1923 if (TypeRange(convertedTypes) != TypeRange(ValueRange(repls))) {
1924 logProlog();
1925 logger.getOStream() << ", but the legalized type(s) is/are (";
1926 llvm::interleaveComma(convertedTypes, logger.getOStream(),
1927 [&](Type t) { logger.getOStream() << t; });
1928 logger.getOStream() << ")\n";
1929 }
1930 }
1931 }
1932 });
1933
1934 if (!config.allowPatternRollback) {
1935 // Pattern rollback is not allowed: materialize all IR changes immediately.
1937 *this, op->getResults(), newValues, currentTypeConverter);
1938 // Update internal data structures, so that there are no dangling pointers
1939 // to erased IR.
// NOTE(review): this walk is duplicated verbatim in the "no rollback" path of
// the block-erasure hook; a shared helper would keep the two in sync.
1940 op->walk([&](Operation *op) {
1941 erasedOps.insert(op);
1942 ignoredOps.remove(op);
1943 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1944 unresolvedMaterializations.erase(castOp);
1945 patternMaterializations.erase(castOp);
1946 }
1947 // The original op will be erased, so remove it from the set of
1948 // unlegalized ops.
1949 if (config.unlegalizedOps)
1950 config.unlegalizedOps->erase(op);
1951 });
1952 op->walk([&](Block *block) { erasedBlocks.insert(block); });
1953 // Replace the op with the replacement values and notify the listener.
1954 notifyingRewriter.replaceOp(op, repls);
1955 return;
1956 }
1957
1958 assert(!ignoredOps.contains(op) && "operation was already replaced");
1959#ifndef NDEBUG
1960 for (Value v : op->getResults())
1961 assert(!replacedValues.contains(v) &&
1962 "attempting to replace a value that was already replaced");
1963#endif // NDEBUG
1964
1965 // Check if replaced op is an unresolved materialization, i.e., an
1966 // unrealized_conversion_cast op that was created by the conversion driver.
1967 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1968 // Make sure that the user does not mess with unresolved materializations
1969 // that were inserted by the conversion driver. We keep track of these
1970 // ops in internal data structures.
1971 assert(!unresolvedMaterializations.contains(castOp) &&
1972 "attempting to replace/erase an unresolved materialization");
1973 }
1974
1975 // Create mappings for each of the new result values.
1976 for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults()))
1977 mapping.map(static_cast<Value>(result), std::move(repl));
1978
1980 // Mark this operation and all nested ops as replaced.
1981 op->walk([&](Operation *op) { replacedOps.insert(op); });
1982}
1983
1985 Value from, ValueRange to, const TypeConverter *converter,
1986 function_ref<bool(OpOperand &)> functor) {
1987 LLVM_DEBUG({
1988 logger.startLine() << "** Replace Value : '" << from << "'";
1989 if (auto blockArg = dyn_cast<BlockArgument>(from)) {
1990 if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
1991 logger.getOStream() << " (in region of '" << parentOp->getName()
1992 << "' (" << parentOp << ")";
1993 } else {
1994 logger.getOStream() << " (unlinked block)";
1995 }
1996 }
1997 if (functor) {
1998 logger.getOStream() << ", conditional replacement";
1999 }
2000 });
2001
2002 if (!config.allowPatternRollback) {
2003 SmallVector<Value> toConv = llvm::to_vector(to);
2004 SmallVector<Value> repls =
2005 getReplacementValues(*this, from, {toConv}, converter);
2006 IRRewriter r(from.getContext());
2007 Value repl = repls.front();
2008 if (!repl)
2009 return;
2010
2011 performReplaceValue(r, from, repl, functor);
2012 return;
2013 }
2014
2015#ifndef NDEBUG
2016 // Make sure that a value is not replaced multiple times. In rollback mode,
2017 // `replaceAllUsesWith` replaces not only all current uses of the given value,
2018 // but also all future uses that may be introduced by future pattern
2019 // applications. Therefore, it does not make sense to call
2020 // `replaceAllUsesWith` multiple times with the same value. Doing so would
2021 // overwrite the mapping and mess with the internal state of the dialect
2022 // conversion driver.
2023 assert(!replacedValues.contains(from) &&
2024 "attempting to replace a value that was already replaced");
2025 assert(!wasOpReplaced(from.getDefiningOp()) &&
2026 "attempting to replace a op result that was already replaced");
2027 replacedValues.insert(from);
2028#endif // NDEBUG
2029
2030 if (functor)
2031 llvm::reportFatalInternalError(
2032 "conditional value replacement is not supported in rollback mode");
2033 mapping.map(from, to);
2034 appendRewrite<ReplaceValueRewrite>(from, converter);
2035}
2036
// Erases `block`.
// - "No rollback" mode: every nested op is recorded in `erasedOps` and dropped
//   from `ignoredOps`, the unresolved/pattern materialization sets and
//   `config.unlegalizedOps`. Every nested block is recorded in
//   `erasedBlocks`. The block is then erased through the notifying rewriter.
// - Rollback mode: the block is only unlinked from its region, and its nested
//   ops are marked as replaced.
// NOTE(review): this listing omits source lines 2037 (the signature) and 2062
// (presumably where the erasure is recorded as a rewrite for rollback).
2038 if (!config.allowPatternRollback) {
2039 // Pattern rollback is not allowed: materialize all IR changes immediately.
2040 // Update internal data structures, so that there are no dangling pointers
2041 // to erased IR.
2042 block->walk([&](Operation *op) {
2043 erasedOps.insert(op);
2044 ignoredOps.remove(op);
2045 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
2046 unresolvedMaterializations.erase(castOp);
2047 patternMaterializations.erase(castOp);
2048 }
2049 // The original op will be erased, so remove it from the set of
2050 // unlegalized ops.
2051 if (config.unlegalizedOps)
2052 config.unlegalizedOps->erase(op);
2053 });
2054 block->walk([&](Block *block) { erasedBlocks.insert(block); });
2055 // Erase the block and notify the listener.
2056 notifyingRewriter.eraseBlock(block);
2057 return;
2058 }
2059
2060 assert(!wasOpReplaced(block->getParentOp()) &&
2061 "attempting to erase a block within a replaced/erased op");
2063
2064 // Unlink the block from its parent region. The block is kept in the rewrite
2065 // object and will be actually destroyed when rewrites are applied. This
2066 // allows us to keep the operations in the block live and undo the removal by
2067 // re-inserting the block.
2068 block->getParent()->getBlocks().remove(block);
2069
2070 // Mark all nested ops as erased.
2071 block->walk([&](Operation *op) { replacedOps.insert(op); });
2072}
2073
// Rewriter hook called when `block` is inserted; `previous` is null if the
// block was detached before. The hook:
//   - logs the insertion;
//   - in "no rollback" mode, forwards the notification to the listener
//     immediately and removes a previously detached block from
//     `erasedBlocks`;
//   - in rollback mode, records a `MoveBlockRewrite` when the block was moved.
// NOTE(review): this listing omits source lines 2074 (start of the signature)
// and 2110 (presumably the rollback bookkeeping for a newly inserted block).
2075 Block *block, Region *previous, Region::iterator previousIt) {
2076 // If no previous insertion point is provided, the block used to be detached.
2077 bool wasDetached = !previous;
2078 Operation *newParentOp = block->getParentOp();
2079 LLVM_DEBUG(
2080 {
2081 Operation *parent = newParentOp;
2082 if (parent) {
2083 logger.startLine() << "** Insert Block into : '" << parent->getName()
2084 << "' (" << parent << ")";
2085 } else {
2086 logger.startLine()
2087 << "** Insert Block into detached Region (nullptr parent op)";
2088 }
2089 if (wasDetached)
2090 logger.getOStream() << " (was detached)";
2091 logger.getOStream() << "\n";
2092 });
2093
2094 // In rollback mode, it is easier to misuse the API, so perform extra error
2095 // checking.
2096 assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) &&
2097 "attempting to insert into a region within a replaced/erased op");
// Silences the unused-variable warning when asserts and LLVM_DEBUG are
// compiled out.
2098 (void)newParentOp;
2099
2100 // In "no rollback" mode, the listener is always notified immediately.
2101 if (!config.allowPatternRollback && config.listener)
2102 config.listener->notifyBlockInserted(block, previous, previousIt);
2103
2104 if (wasDetached) {
2105 // If the block was detached, it is most likely a newly created block.
2106 if (config.allowPatternRollback) {
2107 // TODO: If the same block is inserted multiple times from a detached
2108 // state, the rollback mechanism may erase the same block multiple times.
2109 // This is a bug in the rollback-based dialect conversion driver.
2111 } else {
2112 // In "no rollback" mode, there is an extra data structure for tracking
2113 // erased blocks that must be kept up to date.
2114 erasedBlocks.erase(block);
2115 }
2116 return;
2117 }
2118
2119 // The block was moved from one place to another.
2120 if (config.allowPatternRollback)
2121 appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
2122}
2123
// Records an `InlineBlockRewrite` for inlining `source` into `dest` before
// `before`. No operations are moved here; this only appends the rewrite.
// NOTE(review): this listing omits source line 2124 (start of the signature,
// including the `source` parameter).
2125 Block *dest,
2126 Block::iterator before) {
2127 appendRewrite<InlineBlockRewrite>(dest, source, before);
2128}
2129
// Match-failure hook. Inside LLVM_DEBUG, it:
//   - builds a diagnostic at `loc` and fills it via `reasonCallback`;
//   - logs the diagnostic;
//   - forwards it to `config.notifyCallback` if one is set.
// The whole body is wrapped in LLVM_DEBUG, so nothing happens when debug
// output is disabled or compiled out.
// NOTE(review): because of that, `config.notifyCallback` is only invoked in
// builds/runs with debug output enabled — confirm this is intended.
// NOTE(review): this listing omits source lines 2130 (start of the signature)
// and 2133 (presumably the declaration of `diag`).
2131 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
2132 LLVM_DEBUG({
2134 reasonCallback(diag);
2135 logger.startLine() << "** Failure : " << diag.str() << "\n";
2136 if (config.notifyCallback)
2137 config.notifyCallback(diag);
2138 });
2139}
2140
2141//===----------------------------------------------------------------------===//
2142// ConversionPatternRewriter
2143//===----------------------------------------------------------------------===//
2144
// Creates the conversion rewriter and registers its implementation object
// (`impl`) as the rewriter's listener, so that IR change notifications reach
// the conversion bookkeeping.
2145ConversionPatternRewriter::ConversionPatternRewriter(
2146 MLIRContext *ctx, const ConversionConfig &config,
2147 OperationConverter &opConverter)
// NOTE(review): source line 2148 (the start of the member-initializer list
// that creates `impl`) is missing from this listing.
2149 *this, config, opConverter)) {
2150 setListener(impl.get());
2151}
2152
// Defined out of line; presumably so that the implementation type owned by
// `impl` only needs to be complete in this file.
2153ConversionPatternRewriter::~ConversionPatternRewriter() = default;
2154
2155const ConversionConfig &ConversionPatternRewriter::getConfig() const {
2156 return impl->config;
2157}
2158
2159void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
2160 assert(op && newOp && "expected non-null op");
2161 replaceOp(op, newOp->getResults());
2162}
2163
/// Replaces `op` with `newValues`, one value per result. A null entry in
/// `newValues` becomes an empty replacement list for that result.
2164void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
2165 assert(op->getNumResults() == newValues.size() &&
2166 "incorrect # of replacement values");
2167
2168 // If the current insertion point is before the erased operation, we adjust
2169 // the insertion point to be after the operation.
// NOTE(review): the statement guarded by this `if` (source line 2171) is
// missing from this listing; per the comment above it presumably moves the
// insertion point after `op`.
2170 if (getInsertionPoint() == op->getIterator())
2172
// Wrap every replacement value into a (possibly empty) value list.
2173 SmallVector<SmallVector<Value>> newVals =
2174 llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2175 return v ? SmallVector<Value>{v} : SmallVector<Value>();
2176 });
2177 impl->replaceOp(op, std::move(newVals));
2178}
2179
/// Replaces each result of `op` with the corresponding list of values in
/// `newValues`, which allows 1:N (and 1:0) replacements.
2180void ConversionPatternRewriter::replaceOpWithMultiple(
2181 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2182 assert(op->getNumResults() == newValues.size() &&
2183 "incorrect # of replacement values");
2184
2185 // If the current insertion point is before the erased operation, we adjust
2186 // the insertion point to be after the operation.
// NOTE(review): the statement guarded by this `if` (source line 2188) is
// missing from this listing; presumably it moves the insertion point after
// `op`.
2187 if (getInsertionPoint() == op->getIterator())
2189
2190 impl->replaceOp(op, std::move(newValues));
2191}
2192
/// Erases `op` by replacing each of its results with an empty value list.
2193void ConversionPatternRewriter::eraseOp(Operation *op) {
2194 LLVM_DEBUG({
2195 impl->logger.startLine()
2196 << "** Erase : '" << op->getName() << "'(" << op << ")\n";
2197 });
2198
2199 // If the current insertion point is before the erased operation, we adjust
2200 // the insertion point to be after the operation.
// NOTE(review): the statement guarded by this `if` (source line 2202) is
// missing from this listing; presumably it moves the insertion point after
// `op`.
2201 if (getInsertionPoint() == op->getIterator())
2203
2204 SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
2205 impl->replaceOp(op, std::move(nullRepls));
2206}
2207
2208void ConversionPatternRewriter::eraseBlock(Block *block) {
2209 impl->eraseBlock(block);
2210}
2211
2212Block *ConversionPatternRewriter::applySignatureConversion(
2213 Block *block, TypeConverter::SignatureConversion &conversion,
2214 const TypeConverter *converter) {
2215 assert(!impl->wasOpReplaced(block->getParentOp()) &&
2216 "attempting to apply a signature conversion to a block within a "
2217 "replaced/erased op");
2218 return impl->applySignatureConversion(block, converter, conversion);
2219}
2220
2221FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
2222 Region *region, const TypeConverter &converter,
2223 TypeConverter::SignatureConversion *entryConversion) {
2224 assert(!impl->wasOpReplaced(region->getParentOp()) &&
2225 "attempting to apply a signature conversion to a block within a "
2226 "replaced/erased op");
2227 return impl->convertRegionTypes(region, converter, entryConversion);
2228}
2229
2230void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
2231 impl->replaceValueUses(from, to, impl->currentTypeConverter);
2232}
2233
2234void ConversionPatternRewriter::replaceUsesWithIf(
2235 Value from, ValueRange to, function_ref<bool(OpOperand &)> functor,
2236 bool *allUsesReplaced) {
2237 assert(!allUsesReplaced &&
2238 "allUsesReplaced is not supported in a dialect conversion");
2239 impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
2240}
2241
2242Value ConversionPatternRewriter::getRemappedValue(Value key) {
2243 SmallVector<ValueVector> remappedValues;
2244 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
2245 remappedValues)))
2246 return nullptr;
2247 assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
2248 return remappedValues.front().front();
2249}
2250
2251LogicalResult
2252ConversionPatternRewriter::getRemappedValues(ValueRange keys,
2253 SmallVectorImpl<Value> &results) {
2254 if (keys.empty())
2255 return success();
2256 SmallVector<ValueVector> remapped;
2257 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
2258 remapped)))
2259 return failure();
2260 for (const auto &values : remapped) {
2261 assert(values.size() == 1 && "1:N conversion not supported");
2262 results.push_back(values.front());
2263 }
2264 return success();
2265}
2266
2267void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
2268 Block::iterator before,
2269 ValueRange argValues) {
2270#ifndef NDEBUG
2271 assert(argValues.size() == source->getNumArguments() &&
2272 "incorrect # of argument replacement values");
2273 assert(!impl->wasOpReplaced(source->getParentOp()) &&
2274 "attempting to inline a block from a replaced/erased op");
2275 assert(!impl->wasOpReplaced(dest->getParentOp()) &&
2276 "attempting to inline a block into a replaced/erased op");
2277 auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
2278 // The source block will be deleted, so it should not have any users (i.e.,
2279 // there should be no predecessors).
2280 assert(llvm::all_of(source->getUsers(), opIgnored) &&
2281 "expected 'source' to have no predecessors");
2282#endif // NDEBUG
2283
2284 // If a listener is attached to the dialect conversion, ops cannot be moved
2285 // to the destination block in bulk ("fast path"). This is because at the time
2286 // the notifications are sent, it is unknown which ops were moved. Instead,
2287 // ops should be moved one-by-one ("slow path"), so that a separate
2288 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
2289 // a bit more efficient, so we try to do that when possible.
2290 bool fastPath = !getConfig().listener;
2291
2292 if (fastPath && impl->config.allowPatternRollback)
2293 impl->inlineBlockBefore(source, dest, before);
2294
2295 // Replace all uses of block arguments.
2296 for (auto it : llvm::zip(source->getArguments(), argValues))
2297 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2298
2299 if (fastPath) {
2300 // Move all ops at once.
2301 dest->getOperations().splice(before, source->getOperations());
2302 } else {
2303 // Move op by op.
2304 while (!source->empty())
2305 moveOpBefore(&source->front(), dest, before);
2306 }
2307
2308 // If the current insertion point is within the source block, adjust the
2309 // insertion point to the destination block.
2310 if (getInsertionBlock() == source)
2311 setInsertionPoint(dest, getInsertionPoint());
2312
2313 // Erase the source block.
2314 eraseBlock(source);
2315}
2316
/// Begins an in-place modification of `op`. In rollback mode, a
/// `ModifyOperationRewrite` is recorded so that the modification can be
/// undone. Builds without NDEBUG also track `op` as a pending update.
2317void ConversionPatternRewriter::startOpModification(Operation *op) {
2318 if (!impl->config.allowPatternRollback) {
2319 // Pattern rollback is not allowed: no extra bookkeeping is needed.
// NOTE(review): source line 2320 is missing from this listing (presumably the
// forwarding call to the base-class implementation — confirm).
2321 return;
2322 }
2323 assert(!impl->wasOpReplaced(op) &&
2324 "attempting to modify a replaced/erased op");
2325#ifndef NDEBUG
2326 impl->pendingRootUpdates.insert(op);
2327#endif
2328 impl->appendRewrite<ModifyOperationRewrite>(op);
2329}
2330
/// Finishes an in-place modification of `op` and records `op` in
/// `patternModifiedOps`. In "no rollback" mode, the listener (if any) is
/// notified immediately.
2331void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
2332 impl->patternModifiedOps.insert(op);
2333 if (!impl->config.allowPatternRollback) {
// NOTE(review): source line 2334 is missing from this listing (presumably the
// forwarding call to the base-class implementation — confirm).
2335 if (getConfig().listener)
2336 getConfig().listener->notifyOperationModified(op);
2337 return;
2338 }
2339
2340 // There is nothing to do here, we only need to track the operation at the
2341 // start of the update.
// `pendingRootUpdates` is only touched under #ifndef NDEBUG, so the
// side-effecting erase inside assert() below is intentional.
2342#ifndef NDEBUG
2343 assert(!impl->wasOpReplaced(op) &&
2344 "attempting to modify a replaced/erased op");
2345 assert(impl->pendingRootUpdates.erase(op) &&
2346 "operation did not have a pending in-place update");
2347#endif
2348}
2349
/// Cancels an in-place modification of `op`. In rollback mode, the most
/// recent `ModifyOperationRewrite` for `op` is rolled back and removed from
/// the rewrite list.
2350void ConversionPatternRewriter::cancelOpModification(Operation *op) {
2351 if (!impl->config.allowPatternRollback) {
// NOTE(review): source line 2352 is missing from this listing (presumably the
// forwarding call to the base-class implementation — confirm).
2353 return;
2354 }
2355#ifndef NDEBUG
2356 assert(impl->pendingRootUpdates.erase(op) &&
2357 "operation did not have a pending in-place update");
2358#endif
2359 // Erase the last update for this operation.
2360 auto it = llvm::find_if(
2361 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
2362 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2363 return modifyRewrite && modifyRewrite->getOperation() == op;
2364 });
2365 assert(it != impl->rewrites.rend() && "no root update started on op");
2366 (*it)->rollback();
// Convert the reverse iterator into the forward index of the same element:
// `std::prev(rend())` refers to element 0.
2367 int updateIdx = std::prev(impl->rewrites.rend()) - it;
2368 impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2369}
2370
2371detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
2372 return *impl;
2373}
2374
2375//===----------------------------------------------------------------------===//
2376// ConversionPattern
2377//===----------------------------------------------------------------------===//
2378
2379FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
2380 ArrayRef<ValueRange> operands) const {
2381 SmallVector<Value> oneToOneOperands;
2382 oneToOneOperands.reserve(operands.size());
2383 for (ValueRange operand : operands) {
2384 if (operand.size() != 1)
2385 return failure();
2386
2387 oneToOneOperands.push_back(operand.front());
2388 }
2389 return std::move(oneToOneOperands);
2390}
2391
2392LogicalResult
2393ConversionPattern::matchAndRewrite(Operation *op,
2394 PatternRewriter &rewriter) const {
2395 auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
2396 auto &rewriterImpl = dialectRewriter.getImpl();
2397
2398 // Track the current conversion pattern type converter in the rewriter.
2399 llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
2400 getTypeConverter());
2401
2402 // Remap the operands of the operation.
2403 SmallVector<ValueVector> remapped;
2404 if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
2405 op->getOperands(), remapped))) {
2406 return failure();
2407 }
2408 SmallVector<ValueRange> remappedAsRange =
2409 llvm::to_vector_of<ValueRange>(remapped);
2410 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2411}
2412
2413//===----------------------------------------------------------------------===//
2414// OperationLegalizer
2415//===----------------------------------------------------------------------===//
2416
2417namespace {
2418/// A set of rewrite patterns that can be used to legalize a given operation.
2419using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2420
2421/// This class defines a recursive operation legalizer.
// It holds the rewriter, the conversion target and a pattern applicator built
// from the provided pattern set. Patterns are ordered by the cost model
// declared in the "Cost Model" section below.
2422class OperationLegalizer {
2423public:
2424 using LegalizationAction = ConversionTarget::LegalizationAction;
2425
2426 OperationLegalizer(ConversionPatternRewriter &rewriter,
2427 const ConversionTarget &targetInfo,
2428 const FrozenRewritePatternSet &patterns);
2429
2430 /// Returns true if the given operation is known to be illegal on the target.
2431 bool isIllegal(Operation *op) const;
2432
2433 /// Attempt to legalize the given operation. Returns success if the operation
2434 /// was legalized, failure otherwise.
2435 LogicalResult legalize(Operation *op);
2436
2437 /// Returns the conversion target in use by the legalizer.
// NOTE(review): this accessor does not modify state and could be `const`.
2438 const ConversionTarget &getTarget() { return target; }
2439
2440private:
2441 /// Attempt to legalize the given operation by folding it.
2442 LogicalResult legalizeWithFold(Operation *op);
2443
2444 /// Attempt to legalize the given operation by applying a pattern. Returns
2445 /// success if the operation was legalized, failure otherwise.
2446 LogicalResult legalizeWithPattern(Operation *op);
2447
2448 /// Return true if the given pattern may be applied to the given operation,
2449 /// false otherwise.
2450 bool canApplyPattern(Operation *op, const Pattern &pattern);
2451
2452 /// Legalize the resultant IR after successfully applying the given pattern.
2453 LogicalResult
2454 legalizePatternResult(Operation *op, const Pattern &pattern,
2455 const RewriterState &curState,
2456 const SetVector<Operation *> &newOps,
2457 const SetVector<Operation *> &modifiedOps);
2458
// Legalize the operations that a pattern created (`newOps`) and the operations
// it modified in place (`modifiedOps`), respectively.
2459 LogicalResult
2460 legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
2461 LogicalResult
2462 legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
2463
2464 //===--------------------------------------------------------------------===//
2465 // Cost Model
2466 //===--------------------------------------------------------------------===//
// NOTE(review): this listing omits source lines 2474, 2486, 2492 and 2500 —
// the final parameter of each of the four declarations below, presumably the
// `legalizerPatterns` map referenced in their doc comments.
2467
2468 /// Build an optimistic legalization graph given the provided patterns. This
2469 /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
2470 /// patterns for operations that are not directly legal, but may be
2471 /// transitively legal for the current target given the provided patterns.
2472 void buildLegalizationGraph(
2473 LegalizationPatterns &anyOpLegalizerPatterns,
2475
2476 /// Compute the benefit of each node within the computed legalization graph.
2477 /// This orders the patterns within 'legalizerPatterns' based upon two
2478 /// criteria:
2479 /// 1) Prefer patterns that have the lowest legalization depth, i.e.
2480 /// represent the more direct mapping to the target.
2481 /// 2) When comparing patterns with the same legalization depth, prefer the
2482 /// pattern with the highest PatternBenefit. This allows for users to
2483 /// prefer specific legalizations over others.
2484 void computeLegalizationGraphBenefit(
2485 LegalizationPatterns &anyOpLegalizerPatterns,
2487
2488 /// Compute the legalization depth when legalizing an operation of the given
2489 /// type.
2490 unsigned computeOpLegalizationDepth(
2491 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2493
2494 /// Apply the conversion cost model to the given set of patterns, and return
2495 /// the smallest legalization depth of any of the patterns. See
2496 /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
2497 unsigned applyCostModelToPatterns(
2498 LegalizationPatterns &patterns,
2499 DenseMap<OperationName, unsigned> &minOpPatternDepth,
2501
2502 /// The current set of patterns that have been applied.
2503 SmallPtrSet<const Pattern *, 8> appliedPatterns;
2504
2505 /// The rewriter to use when converting operations.
2506 ConversionPatternRewriter &rewriter;
2507
2508 /// The legalization information provided by the target.
2509 const ConversionTarget &target;
2510
2511 /// The pattern applicator to use for conversions.
2512 PatternApplicator applicator;
2513};
2514} // namespace
2515
// Stores the rewriter, target and pattern applicator, then builds the
// legalization graph and ranks its patterns with the cost model.
2516OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2517 const ConversionTarget &targetInfo,
2518 const FrozenRewritePatternSet &patterns)
2519 : rewriter(rewriter), target(targetInfo), applicator(patterns) {
2520 // The set of patterns that can be applied to illegal operations to transform
2521 // them into legal ones.
// NOTE(review): source line 2522 is missing from this listing; it presumably
// declares the `legalizerPatterns` map used below.
2523 LegalizationPatterns anyOpLegalizerPatterns;
2524
2525 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2526 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2527}
2528
2529bool OperationLegalizer::isIllegal(Operation *op) const {
2530 return target.isIllegal(op);
2531}
2532
2533LogicalResult OperationLegalizer::legalize(Operation *op) {
2534#ifndef NDEBUG
2535 const char *logLineComment =
2536 "//===-------------------------------------------===//\n";
2537
2538 auto &logger = rewriter.getImpl().logger;
2539#endif
2540
2541 // Check to see if the operation is ignored and doesn't need to be converted.
2542 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2543
2544 LLVM_DEBUG({
2545 logger.getOStream() << "\n";
2546 logger.startLine() << logLineComment;
2547 logger.startLine() << "Legalizing operation : ";
2548 // Do not print the operation name if the operation is ignored. Ignored ops
2549 // may have been erased and should not be accessed. The pointer can be
2550 // printed safely.
2551 if (!isIgnored)
2552 logger.getOStream() << "'" << op->getName() << "' ";
2553 logger.getOStream() << "(" << op << ") {\n";
2554 logger.indent();
2555
2556 // If the operation has no regions, just print it here.
2557 if (!isIgnored && op->getNumRegions() == 0) {
2558 logger.startLine() << OpWithFlags(op,
2559 OpPrintingFlags().printGenericOpForm())
2560 << "\n";
2561 }
2562 });
2563
2564 if (isIgnored) {
2565 LLVM_DEBUG({
2566 logSuccess(logger, "operation marked 'ignored' during conversion");
2567 logger.startLine() << logLineComment;
2568 });
2569 return success();
2570 }
2571
2572 // Check if this operation is legal on the target.
2573 if (auto legalityInfo = target.isLegal(op)) {
2574 LLVM_DEBUG({
2575 logSuccess(
2576 logger, "operation marked legal by the target{0}",
2577 legalityInfo->isRecursivelyLegal
2578 ? "; NOTE: operation is recursively legal; skipping internals"
2579 : "");
2580 logger.startLine() << logLineComment;
2581 });
2582
2583 // If this operation is recursively legal, mark its children as ignored so
2584 // that we don't consider them for legalization.
2585 if (legalityInfo->isRecursivelyLegal) {
2586 op->walk([&](Operation *nested) {
2587 if (op != nested)
2588 rewriter.getImpl().ignoredOps.insert(nested);
2589 });
2590 }
2591
2592 return success();
2593 }
2594
2595 // If the operation is not legal, try to fold it in-place if the folding mode
2596 // is 'BeforePatterns'. 'Never' will skip this.
2597 const ConversionConfig &config = rewriter.getConfig();
2598 if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2599 if (succeeded(legalizeWithFold(op))) {
2600 LLVM_DEBUG({
2601 logSuccess(logger, "operation was folded");
2602 logger.startLine() << logLineComment;
2603 });
2604 return success();
2605 }
2606 }
2607
2608 // Otherwise, we need to apply a legalization pattern to this operation.
2609 if (succeeded(legalizeWithPattern(op))) {
2610 LLVM_DEBUG({
2611 logSuccess(logger, "");
2612 logger.startLine() << logLineComment;
2613 });
2614 return success();
2615 }
2616
2617 // If the operation can't be legalized via patterns, try to fold it in-place
2618 // if the folding mode is 'AfterPatterns'.
2619 if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2620 if (succeeded(legalizeWithFold(op))) {
2621 LLVM_DEBUG({
2622 logSuccess(logger, "operation was folded");
2623 logger.startLine() << logLineComment;
2624 });
2625 return success();
2626 }
2627 }
2628
2629 LLVM_DEBUG({
2630 logFailure(logger, "no matched legalization pattern");
2631 logger.startLine() << logLineComment;
2632 });
2633 return failure();
2634}
2635
/// Moves the contents out of `obj` and returns them, leaving `obj` holding a
/// value-initialized `T` so it is in a valid, empty state for reuse.
template <typename T>
static T moveAndReset(T &obj) {
  return std::exchange(obj, T());
}
2644
2645LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2646 auto &rewriterImpl = rewriter.getImpl();
2647 LLVM_DEBUG({
2648 rewriterImpl.logger.startLine() << "* Fold {\n";
2649 rewriterImpl.logger.indent();
2650 });
2651
2652 // Clear pattern state, so that the next pattern application starts with a
2653 // clean slate. (The op/block sets are populated by listener notifications.)
2654 llvm::scope_exit cleanup([&]() {
2655 rewriterImpl.patternNewOps.clear();
2656 rewriterImpl.patternModifiedOps.clear();
2657 });
2658
2659 // Upon failure, undo all changes made by the folder.
2660 RewriterState curState = rewriterImpl.getCurrentState();
2661
2662 // Try to fold the operation.
2663 StringRef opName = op->getName().getStringRef();
2664 SmallVector<Value, 2> replacementValues;
2665 SmallVector<Operation *, 2> newOps;
2666 rewriter.setInsertionPoint(op);
2667 rewriter.startOpModification(op);
2668 if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
2669 LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2670 rewriter.cancelOpModification(op);
2671 return failure();
2672 }
2673 rewriter.finalizeOpModification(op);
2674
2675 // An empty list of replacement values indicates that the fold was in-place.
2676 // As the operation changed, a new legalization needs to be attempted.
2677 if (replacementValues.empty())
2678 return legalize(op);
2679
2680 // Insert a replacement for 'op' with the folded replacement values.
2681 rewriter.replaceOp(op, replacementValues);
2682
2683 // Recursively legalize any new constant operations.
2684 for (Operation *newOp : newOps) {
2685 if (failed(legalize(newOp))) {
2686 LLVM_DEBUG(logFailure(rewriterImpl.logger,
2687 "failed to legalize generated constant '{0}'",
2688 newOp->getName()));
2689 if (!rewriter.getConfig().allowPatternRollback) {
2690 // Rolling back a folder is like rolling back a pattern.
2691 llvm::reportFatalInternalError(
2692 "op '" + opName +
2693 "' folder rollback of IR modifications requested");
2694 }
2695 rewriterImpl.resetState(
2696 curState, std::string(op->getName().getStringRef()) + " folder");
2697 return failure();
2698 }
2699 }
2700
2701 LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2702 return success();
2703}
2704
2705/// Report a fatal error indicating that newly produced or modified IR could
2706/// not be legalized.
2707static void
2709 const SetVector<Operation *> &newOps,
2710 const SetVector<Operation *> &modifiedOps) {
2711 auto newOpNames = llvm::map_range(
2712 newOps, [](Operation *op) { return op->getName().getStringRef(); });
2713 auto modifiedOpNames = llvm::map_range(
2714 modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
2715 llvm::reportFatalInternalError("pattern '" + pattern.getDebugName() +
2716 "' produced IR that could not be legalized. " +
2717 "new ops: {" + llvm::join(newOpNames, ", ") +
2718 "}, " + "modified ops: {" +
2719 llvm::join(modifiedOpNames, ", ") + "}");
2720}
2721
2722LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2723 auto &rewriterImpl = rewriter.getImpl();
2724 const ConversionConfig &config = rewriter.getConfig();
2725
2726#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2727 Operation *checkOp;
2728 std::optional<OperationFingerPrint> topLevelFingerPrint;
2729 if (!rewriterImpl.config.allowPatternRollback) {
2730 // The op may be getting erased, so we have to check the parent op.
2731 // (In rare cases, a pattern may even erase the parent op, which will cause
2732 // a crash here. Expensive checks are "best effort".) Skip the check if the
2733 // op does not have a parent op.
2734 if ((checkOp = op->getParentOp())) {
2735 if (!op->getContext()->isMultithreadingEnabled()) {
2736 topLevelFingerPrint = OperationFingerPrint(checkOp);
2737 } else {
2738 // Another thread may be modifying a sibling operation. Therefore, the
2739 // fingerprinting mechanism of the parent op works only in
2740 // single-threaded mode.
2741 LLVM_DEBUG({
2742 rewriterImpl.logger.startLine()
2743 << "WARNING: Multi-threadeding is enabled. Some dialect "
2744 "conversion expensive checks are skipped in multithreading "
2745 "mode!\n";
2746 });
2747 }
2748 }
2749 }
2750#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2751
2752 // Functor that returns if the given pattern may be applied.
2753 auto canApply = [&](const Pattern &pattern) {
2754 bool canApply = canApplyPattern(op, pattern);
2755 if (canApply && config.listener)
2756 config.listener->notifyPatternBegin(pattern, op);
2757 return canApply;
2758 };
2759
2760 // Functor that cleans up the rewriter state after a pattern failed to match.
2761 RewriterState curState = rewriterImpl.getCurrentState();
2762 auto onFailure = [&](const Pattern &pattern) {
2763 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2764 if (!rewriterImpl.config.allowPatternRollback) {
2765 // Erase all unresolved materializations.
2766 for (auto op : rewriterImpl.patternMaterializations) {
2767 rewriterImpl.unresolvedMaterializations.erase(op);
2768 op.erase();
2769 }
2770 rewriterImpl.patternMaterializations.clear();
2771#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2772 // Expensive pattern check that can detect API violations.
2773 if (checkOp && topLevelFingerPrint) {
2774 OperationFingerPrint fingerPrintAfterPattern(checkOp);
2775 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2776 llvm::reportFatalInternalError(
2777 "pattern '" + pattern.getDebugName() +
2778 "' returned failure but IR did change");
2779 }
2780#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2781 }
2782 rewriterImpl.patternNewOps.clear();
2783 rewriterImpl.patternModifiedOps.clear();
2784 LLVM_DEBUG({
2785 logFailure(rewriterImpl.logger, "pattern failed to match");
2786 if (rewriterImpl.config.notifyCallback) {
2787 Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
2788 diag << "Failed to apply pattern \"" << pattern.getDebugName()
2789 << "\" on op:\n"
2790 << *op;
2791 rewriterImpl.config.notifyCallback(diag);
2792 }
2793 });
2794 if (config.listener)
2795 config.listener->notifyPatternEnd(pattern, failure());
2796 rewriterImpl.resetState(curState, pattern.getDebugName());
2797 appliedPatterns.erase(&pattern);
2798 };
2799
2800 // Functor that performs additional legalization when a pattern is
2801 // successfully applied.
2802 auto onSuccess = [&](const Pattern &pattern) {
2803 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2804 if (!rewriterImpl.config.allowPatternRollback) {
2805 // Eagerly erase unused materializations.
2806 for (auto op : rewriterImpl.patternMaterializations) {
2807 if (op->use_empty()) {
2808 rewriterImpl.unresolvedMaterializations.erase(op);
2809 op.erase();
2810 }
2811 }
2812 rewriterImpl.patternMaterializations.clear();
2813 }
2814 SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
2815 SetVector<Operation *> modifiedOps =
2816 moveAndReset(rewriterImpl.patternModifiedOps);
2817 auto result =
2818 legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
2819 appliedPatterns.erase(&pattern);
2820 if (failed(result)) {
2821 if (!rewriterImpl.config.allowPatternRollback)
2822 reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);
2823 rewriterImpl.resetState(curState, pattern.getDebugName());
2824 }
2825 if (config.listener)
2826 config.listener->notifyPatternEnd(pattern, result);
2827 return result;
2828 };
2829
2830 // Try to match and rewrite a pattern on this operation.
2831 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2832 onSuccess);
2833}
2834
2835bool OperationLegalizer::canApplyPattern(Operation *op,
2836 const Pattern &pattern) {
2837 LLVM_DEBUG({
2838 auto &os = rewriter.getImpl().logger;
2839 os.getOStream() << "\n";
2840 os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2841 llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2842 os.getOStream() << ")' {\n";
2843 os.indent();
2844 });
2845
2846 // Ensure that we don't cycle by not allowing the same pattern to be
2847 // applied twice in the same recursion stack if it is not known to be safe.
2848 if (!pattern.hasBoundedRewriteRecursion() &&
2849 !appliedPatterns.insert(&pattern).second) {
2850 LLVM_DEBUG(
2851 logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2852 return false;
2853 }
2854 return true;
2855}
2856
2857LogicalResult OperationLegalizer::legalizePatternResult(
2858 Operation *op, const Pattern &pattern, const RewriterState &curState,
2859 const SetVector<Operation *> &newOps,
2860 const SetVector<Operation *> &modifiedOps) {
2861 [[maybe_unused]] auto &impl = rewriter.getImpl();
2862 assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2863
2864#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2865 if (impl.config.allowPatternRollback) {
2866 // Check that the root was either replaced or updated in place.
2867 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2868 auto replacedRoot = [&] {
2869 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2870 };
2871 auto updatedRootInPlace = [&] {
2872 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2873 };
2874 if (!replacedRoot() && !updatedRootInPlace())
2875 llvm::reportFatalInternalError(
2876 "expected pattern to replace the root operation "
2877 "or modify it in place");
2878 }
2879#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2880
2881 // Legalize each of the actions registered during application.
2882 if (failed(legalizePatternRootUpdates(modifiedOps)) ||
2883 failed(legalizePatternCreatedOperations(newOps))) {
2884 return failure();
2885 }
2886
2887 LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2888 return success();
2889}
2890
2891LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2892 const SetVector<Operation *> &newOps) {
2893 for (Operation *op : newOps) {
2894 if (failed(legalize(op))) {
2895 LLVM_DEBUG(logFailure(rewriter.getImpl().logger,
2896 "failed to legalize generated operation '{0}'({1})",
2897 op->getName(), op));
2898 return failure();
2899 }
2900 }
2901 return success();
2902}
2903
2904LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2905 const SetVector<Operation *> &modifiedOps) {
2906 for (Operation *op : modifiedOps) {
2907 if (failed(legalize(op))) {
2908 LLVM_DEBUG(
2909 logFailure(rewriter.getImpl().logger,
2910 "failed to legalize operation updated in-place '{0}'",
2911 op->getName()));
2912 return failure();
2913 }
2914 }
2915 return success();
2916}
2917
2918//===----------------------------------------------------------------------===//
2919// Cost Model
2920//===----------------------------------------------------------------------===//
2921
2922void OperationLegalizer::buildLegalizationGraph(
2923 LegalizationPatterns &anyOpLegalizerPatterns,
2925 // A mapping between an operation and a set of operations that can be used to
2926 // generate it.
2928 // A mapping between an operation and any currently invalid patterns it has.
2930 // A worklist of patterns to consider for legality.
2931 SetVector<const Pattern *> patternWorklist;
2932
2933 // Build the mapping from operations to the parent ops that may generate them.
2934 applicator.walkAllPatterns([&](const Pattern &pattern) {
2935 std::optional<OperationName> root = pattern.getRootKind();
2936
2937 // If the pattern has no specific root, we can't analyze the relationship
2938 // between the root op and generated operations. Given that, add all such
2939 // patterns to the legalization set.
2940 if (!root) {
2941 anyOpLegalizerPatterns.push_back(&pattern);
2942 return;
2943 }
2944
2945 // Skip operations that are always known to be legal.
2946 if (target.getOpAction(*root) == LegalizationAction::Legal)
2947 return;
2948
2949 // Add this pattern to the invalid set for the root op and record this root
2950 // as a parent for any generated operations.
2951 invalidPatterns[*root].insert(&pattern);
2952 for (auto op : pattern.getGeneratedOps())
2953 parentOps[op].insert(*root);
2954
2955 // Add this pattern to the worklist.
2956 patternWorklist.insert(&pattern);
2957 });
2958
2959 // If there are any patterns that don't have a specific root kind, we can't
2960 // make direct assumptions about what operations will never be legalized.
2961 // Note: Technically we could, but it would require an analysis that may
2962 // recurse into itself. It would be better to perform this kind of filtering
2963 // at a higher level than here anyways.
2964 if (!anyOpLegalizerPatterns.empty()) {
2965 for (const Pattern *pattern : patternWorklist)
2966 legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2967 return;
2968 }
2969
2970 while (!patternWorklist.empty()) {
2971 auto *pattern = patternWorklist.pop_back_val();
2972
2973 // Check to see if any of the generated operations are invalid.
2974 if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2975 std::optional<LegalizationAction> action = target.getOpAction(op);
2976 return !legalizerPatterns.count(op) &&
2977 (!action || action == LegalizationAction::Illegal);
2978 }))
2979 continue;
2980
2981 // Otherwise, if all of the generated operation are valid, this op is now
2982 // legal so add all of the child patterns to the worklist.
2983 legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2984 invalidPatterns[*pattern->getRootKind()].erase(pattern);
2985
2986 // Add any invalid patterns of the parent operations to see if they have now
2987 // become legal.
2988 for (auto op : parentOps[*pattern->getRootKind()])
2989 patternWorklist.set_union(invalidPatterns[op]);
2990 }
2991}
2992
2993void OperationLegalizer::computeLegalizationGraphBenefit(
2994 LegalizationPatterns &anyOpLegalizerPatterns,
2996 // The smallest pattern depth, when legalizing an operation.
2997 DenseMap<OperationName, unsigned> minOpPatternDepth;
2998
2999 // For each operation that is transitively legal, compute a cost for it.
3000 for (auto &opIt : legalizerPatterns)
3001 if (!minOpPatternDepth.count(opIt.first))
3002 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
3003 legalizerPatterns);
3004
3005 // Apply the cost model to the patterns that can match any operation. Those
3006 // with a specific operation type are already resolved when computing the op
3007 // legalization depth.
3008 if (!anyOpLegalizerPatterns.empty())
3009 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
3010 legalizerPatterns);
3011
3012 // Apply a cost model to the pattern applicator. We order patterns first by
3013 // depth then benefit. `legalizerPatterns` contains per-op patterns by
3014 // decreasing benefit.
3015 applicator.applyCostModel([&](const Pattern &pattern) {
3016 ArrayRef<const Pattern *> orderedPatternList;
3017 if (std::optional<OperationName> rootName = pattern.getRootKind())
3018 orderedPatternList = legalizerPatterns[*rootName];
3019 else
3020 orderedPatternList = anyOpLegalizerPatterns;
3021
3022 // If the pattern is not found, then it was removed and cannot be matched.
3023 auto *it = llvm::find(orderedPatternList, &pattern);
3024 if (it == orderedPatternList.end())
3026
3027 // Patterns found earlier in the list have higher benefit.
3028 return PatternBenefit(std::distance(it, orderedPatternList.end()));
3029 });
3030}
3031
3032unsigned OperationLegalizer::computeOpLegalizationDepth(
3033 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
3035 // Check for existing depth.
3036 auto depthIt = minOpPatternDepth.find(op);
3037 if (depthIt != minOpPatternDepth.end())
3038 return depthIt->second;
3039
3040 // If a mapping for this operation does not exist, then this operation
3041 // is always legal. Return 0 as the depth for a directly legal operation.
3042 auto opPatternsIt = legalizerPatterns.find(op);
3043 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3044 return 0u;
3045
3046 // Record this initial depth in case we encounter this op again when
3047 // recursively computing the depth.
3048 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
3049
3050 // Apply the cost model to the operation patterns, and update the minimum
3051 // depth.
3052 unsigned minDepth = applyCostModelToPatterns(
3053 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3054 minOpPatternDepth[op] = minDepth;
3055 return minDepth;
3056}
3057
3058unsigned OperationLegalizer::applyCostModelToPatterns(
3059 LegalizationPatterns &patterns,
3060 DenseMap<OperationName, unsigned> &minOpPatternDepth,
3062 unsigned minDepth = std::numeric_limits<unsigned>::max();
3063
3064 // Compute the depth for each pattern within the set.
3065 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3066 patternsByDepth.reserve(patterns.size());
3067 for (const Pattern *pattern : patterns) {
3068 unsigned depth = 1;
3069 for (auto generatedOp : pattern->getGeneratedOps()) {
3070 unsigned generatedOpDepth = computeOpLegalizationDepth(
3071 generatedOp, minOpPatternDepth, legalizerPatterns);
3072 depth = std::max(depth, generatedOpDepth + 1);
3073 }
3074 patternsByDepth.emplace_back(pattern, depth);
3075
3076 // Update the minimum depth of the pattern list.
3077 minDepth = std::min(minDepth, depth);
3078 }
3079
3080 // If the operation only has one legalization pattern, there is no need to
3081 // sort them.
3082 if (patternsByDepth.size() == 1)
3083 return minDepth;
3084
3085 // Sort the patterns by those likely to be the most beneficial.
3086 llvm::stable_sort(patternsByDepth,
3087 [](const std::pair<const Pattern *, unsigned> &lhs,
3088 const std::pair<const Pattern *, unsigned> &rhs) {
3089 // First sort by the smaller pattern legalization
3090 // depth.
3091 if (lhs.second != rhs.second)
3092 return lhs.second < rhs.second;
3093
3094 // Then sort by the larger pattern benefit.
3095 auto lhsBenefit = lhs.first->getBenefit();
3096 auto rhsBenefit = rhs.first->getBenefit();
3097 return lhsBenefit > rhsBenefit;
3098 });
3099
3100 // Update the legalization pattern to use the new sorted list.
3101 patterns.clear();
3102 for (auto &patternIt : patternsByDepth)
3103 patterns.push_back(patternIt.first);
3104 return minDepth;
3105}
3106
3107//===----------------------------------------------------------------------===//
3108// Reconcile Unrealized Casts
3109//===----------------------------------------------------------------------===//
3110
3111/// Try to reconcile all given UnrealizedConversionCastOps and store the
3112/// left-over ops in `remainingCastOps` (if provided). See documentation in
3113/// DialectConversion.h for more details.
3114/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the
3115/// algorithm may visit an operand (or user) which is a cast op, but will not
3116/// try to reconcile it if not in the filtered set.
3117template <typename RangeT>
3119 RangeT castOps,
3120 function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3122 // A worklist of cast ops to process.
3123 SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
3124
3125 // Helper function that return the unrealized_conversion_cast op that
3126 // defines all inputs of the given op (in the same order). Return "nullptr"
3127 // if there is no such op.
3128 auto getInputCast =
3129 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3130 if (castOp.getInputs().empty())
3131 return {};
3132 auto inputCastOp =
3133 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3134 if (!inputCastOp)
3135 return {};
3136 if (inputCastOp.getOutputs() != castOp.getInputs())
3137 return {};
3138 return inputCastOp;
3139 };
3140
3141 // Process ops in the worklist bottom-to-top.
3142 while (!worklist.empty()) {
3143 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3144
3145 // Traverse the chain of input cast ops to see if an op with the same
3146 // input types can be found.
3147 UnrealizedConversionCastOp nextCast = castOp;
3148 while (nextCast) {
3149 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3150 if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
3151 return v.getDefiningOp() == castOp;
3152 })) {
3153 // Ran into a cycle.
3154 break;
3155 }
3156
3157 // Found a cast where the input types match the output types of the
3158 // matched op. We can directly use those inputs.
3159 castOp.replaceAllUsesWith(nextCast.getInputs());
3160 break;
3161 }
3162 nextCast = getInputCast(nextCast);
3163 }
3164 }
3165
3166 // A set of all alive cast ops. I.e., ops whose results are (transitively)
3167 // used by an op that is not a cast op.
3168 DenseSet<Operation *> liveOps;
3169
3170 // Helper function that marks the given op and transitively reachable input
3171 // cast ops as alive.
3172 auto markOpLive = [&](Operation *rootOp) {
3173 SmallVector<Operation *> worklist;
3174 worklist.push_back(rootOp);
3175 while (!worklist.empty()) {
3176 Operation *op = worklist.pop_back_val();
3177 if (liveOps.insert(op).second) {
3178 // Successfully inserted: process reachable input cast ops.
3179 for (Value v : op->getOperands())
3180 if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
3181 if (isCastOpOfInterestFn(castOp))
3182 worklist.push_back(castOp);
3183 }
3184 }
3185 };
3186
3187 // Find all alive cast ops.
3188 for (UnrealizedConversionCastOp op : castOps) {
3189 // The op may have been marked live already as being an operand of another
3190 // live cast op.
3191 if (liveOps.contains(op.getOperation()))
3192 continue;
3193 // If any of the users is not a cast op, mark the current op (and its
3194 // input ops) as live.
3195 if (llvm::any_of(op->getUsers(), [&](Operation *user) {
3196 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3197 return !castOp || !isCastOpOfInterestFn(castOp);
3198 }))
3199 markOpLive(op);
3200 }
3201
3202 // Erase all dead cast ops.
3203 for (UnrealizedConversionCastOp op : castOps) {
3204 if (liveOps.contains(op)) {
3205 // Op is alive and was not erased. Add it to the remaining cast ops.
3206 if (remainingCastOps)
3207 remainingCastOps->push_back(op);
3208 continue;
3209 }
3210
3211 // Op is dead. Erase it.
3212 op->dropAllUses();
3213 op->erase();
3214 }
3215}
3216
3218 ArrayRef<UnrealizedConversionCastOp> castOps,
3219 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3220 // Set of all cast ops for faster lookups.
3221 DenseSet<UnrealizedConversionCastOp> castOpSet;
3222 for (UnrealizedConversionCastOp op : castOps)
3223 castOpSet.insert(op);
3224 reconcileUnrealizedCasts(castOpSet, remainingCastOps);
3225}
3226
3228 const DenseSet<UnrealizedConversionCastOp> &castOps,
3229 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3231 llvm::make_range(castOps.begin(), castOps.end()),
3232 [&](UnrealizedConversionCastOp castOp) {
3233 return castOps.contains(castOp);
3234 },
3235 remainingCastOps);
3236}
3237
3238namespace mlir {
3241 &castOps,
3244 castOps.keys(),
3245 [&](UnrealizedConversionCastOp castOp) {
3246 return castOps.contains(castOp);
3247 },
3248 remainingCastOps);
3249}
3250} // namespace mlir
3251
3252//===----------------------------------------------------------------------===//
3253// OperationConverter
3254//===----------------------------------------------------------------------===//
3255
3256namespace mlir {
3257// This class converts operations to a given conversion target via a set of
3258// rewrite patterns. The conversion behaves differently depending on the
3259// conversion mode.
3263 const ConversionConfig &config,
3264 OpConversionMode mode)
3265 : rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
3266 mode(mode) {}
3267
3268 /// Applies the conversion to the given operations (and their nested
3269 /// operations).
3270 LogicalResult applyConversion(ArrayRef<Operation *> ops);
3271
3272 /// Legalizes the given operations (and their nested operations) to the
3273 /// conversion target.
3274 template <typename Fn>
3275 LogicalResult legalizeOperations(ArrayRef<Operation *> ops, Fn onFailure,
3276 bool isRecursiveLegalization = false);
3278 bool isRecursiveLegalization = false) {
3279 return legalizeOperations(
3280 ops, /*onFailure=*/[&]() {}, isRecursiveLegalization);
3281 }
3282
3283 /// Converts a single operation. If `isRecursiveLegalization` is "true", the
3284 /// conversion is a recursive legalization request, triggered from within a
3285 /// pattern. In that case, do not emit errors because there will be another
3286 /// attempt at legalizing the operation later (via the regular pre-order
3287 /// legalization mechanism).
3288 LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);
3289
3290 const ConversionTarget &getTarget() { return opLegalizer.getTarget(); }
3291
3292private:
3293 /// The rewriter to use when converting operations.
3294 ConversionPatternRewriter rewriter;
3295
3296 /// The legalizer to use when converting operations.
3297 OperationLegalizer opLegalizer;
3298
3299 /// The conversion mode to use when legalizing operations.
3300 OpConversionMode mode;
3301};
3302} // namespace mlir
3303
3305 bool isRecursiveLegalization) {
3306 const ConversionConfig &config = rewriter.getConfig();
3307
3308 // Legalize the given operation.
3309 if (failed(opLegalizer.legalize(op))) {
3310 // Handle the case of a failed conversion for each of the different modes.
3311 // Full conversions expect all operations to be converted.
3312 if (mode == OpConversionMode::Full) {
3313 if (!isRecursiveLegalization)
3314 op->emitError() << "failed to legalize operation '" << op->getName()
3315 << "'";
3316 return failure();
3317 }
3318 // Partial conversions allow conversions to fail iff the operation was not
3319 // explicitly marked as illegal. If the user provided a `unlegalizedOps`
3320 // set, non-legalizable ops are added to that set.
3321 if (mode == OpConversionMode::Partial) {
3322 if (opLegalizer.isIllegal(op)) {
3323 if (!isRecursiveLegalization)
3324 op->emitError() << "failed to legalize operation '" << op->getName()
3325 << "' that was explicitly marked illegal";
3326 return failure();
3327 }
3328 if (config.unlegalizedOps && !isRecursiveLegalization)
3329 config.unlegalizedOps->insert(op);
3330 }
3331 } else if (mode == OpConversionMode::Analysis) {
3332 // Analysis conversions don't fail if any operations fail to legalize,
3333 // they are only interested in the operations that were successfully
3334 // legalized.
3335 if (config.legalizableOps && !isRecursiveLegalization)
3336 config.legalizableOps->insert(op);
3337 }
3338 return success();
3339}
3340
3341static LogicalResult
3343 UnrealizedConversionCastOp op,
3344 const UnresolvedMaterializationInfo &info) {
3345 assert(!op.use_empty() &&
3346 "expected that dead materializations have already been DCE'd");
3347 Operation::operand_range inputOperands = op.getOperands();
3348
3349 // Try to materialize the conversion.
3350 if (const TypeConverter *converter = info.getConverter()) {
3351 rewriter.setInsertionPoint(op);
3352 SmallVector<Value> newMaterialization;
3353 switch (info.getMaterializationKind()) {
3354 case MaterializationKind::Target:
3355 newMaterialization = converter->materializeTargetConversion(
3356 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3357 info.getOriginalType());
3358 break;
3359 case MaterializationKind::Source:
3360 assert(op->getNumResults() == 1 && "expected single result");
3361 Value sourceMat = converter->materializeSourceConversion(
3362 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3363 if (sourceMat)
3364 newMaterialization.push_back(sourceMat);
3365 break;
3366 }
3367 if (!newMaterialization.empty()) {
3368#ifndef NDEBUG
3369 ValueRange newMaterializationRange(newMaterialization);
3370 assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
3371 "materialization callback produced value of incorrect type");
3372#endif // NDEBUG
3373 rewriter.replaceOp(op, newMaterialization);
3374 return success();
3375 }
3376 }
3377
3378 InFlightDiagnostic diag = op->emitError()
3379 << "failed to legalize unresolved materialization "
3380 "from ("
3381 << inputOperands.getTypes() << ") to ("
3382 << op.getResultTypes()
3383 << ") that remained live after conversion";
3384 diag.attachNote(op->getUsers().begin()->getLoc())
3385 << "see existing live user here: " << *op->getUsers().begin();
3386 return failure();
3387}
3388
/// Collect `ops` together with their nested operations into a worklist, then
/// convert each worklist entry via `convert(op, isRecursiveLegalization)`.
/// The worklist is computed up front, before any conversion runs. Nested
/// operations of an op that the target reports as recursively legal are not
/// collected. On the first conversion failure, `onFailure()` is invoked and
/// failure is returned; the remaining worklist entries are not visited.
3389template <typename Fn>
3390LogicalResult
// NOTE(review): the declarator line (orig. 3391) is missing from this listing.
// Judging by the body it names OperationConverter::legalizeOperations and
// declares `ops` and the `onFailure` callback; confirm against the source.
3392 bool isRecursiveLegalization) {
3393 const ConversionTarget &target = opLegalizer.getTarget();
3394
3395 // Compute the set of operations and blocks to convert.
// (Only operations are collected into `toConvert`.)
3396 SmallVector<Operation *> toConvert;
3397 for (Operation *op : ops) {
// NOTE(review): the line that starts the walk over `op` (orig. 3398) is
// missing. Since the callback may return WalkResult::skip(), this is
// presumably a pre-order walk; confirm.
3399 [&](Operation *op) {
3400 toConvert.push_back(op);
3401 // Don't check this operation's children for conversion if the
3402 // operation is recursively legal.
3403 auto legalityInfo = target.isLegal(op);
3404 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3405 return WalkResult::skip();
3406 return WalkResult::advance();
3407 });
3408 }
3409 for (Operation *op : toConvert) {
3410 if (failed(convert(op, isRecursiveLegalization))) {
3411 // Failed to convert an operation.
// Let the caller decide what to do with the partially converted IR.
3412 onFailure();
3413 return failure();
3414 }
3415 }
3416 return success();
3417}
3418
3419LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
3420 return impl->opConverter.legalizeOperations(op,
3421 /*isRecursiveLegalization=*/true);
3422}
3423
3424LogicalResult ConversionPatternRewriter::legalize(Region *r) {
3425 // Fast path: If the region is empty, there is nothing to legalize.
3426 if (r->empty())
3427 return success();
3428
3429 // Gather a list of all operations to legalize. This is done before
3430 // converting the entry block signature because unrealized_conversion_cast
3431 // ops should not be included.
3433 for (Block &b : *r)
3434 for (Operation &op : b)
3435 ops.push_back(&op);
3436
3437 // If the current pattern runs with a type converter, convert the entry block
3438 // signature.
3439 if (const TypeConverter *converter = impl->currentTypeConverter) {
3440 std::optional<TypeConverter::SignatureConversion> conversion =
3441 converter->convertBlockSignature(&r->front());
3442 if (!conversion)
3443 return failure();
3444 applySignatureConversion(&r->front(), *conversion, converter);
3445 }
3446
3447 // Legalize all operations in the region. This includes all nested
3448 // operations.
3449 return impl->opConverter.legalizeOperations(ops,
3450 /*isRecursiveLegalization=*/true);
3451}
3452
// NOTE(review): the opening line of this definition (orig. 3453) is missing
// from this listing. From the body it converts `ops` using `rewriter` and
// returns a LogicalResult; presumably it is OperationConverter::applyConversion,
// which the static applyConversion() entry point calls. Confirm against source.
3454 // Convert each operation and discard rewrites on failure.
// On failure, rewrites are undone only when pattern rollback is allowed;
// otherwise the modifications performed so far are applied (see below).
3455 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
3456 LogicalResult status = legalizeOperations(ops, /*onFailure=*/[&]() {
3457 // Dialect conversion failed.
3458 if (rewriterImpl.config.allowPatternRollback) {
3459 // Rollback is allowed: restore the original IR.
3460 rewriterImpl.undoRewrites();
3461 } else {
3462 // Rollback is not allowed: apply all modifications that have been
3463 // performed so far.
3464 rewriterImpl.applyRewrites();
3465 }
3466 });
3467 if (failed(status))
3468 return failure();
3469
3470 // After a successful conversion, apply rewrites.
3471 rewriterImpl.applyRewrites();
3472
3473 // Reconcile all UnrealizedConversionCastOps that were inserted by the
3474 // dialect conversion frameworks. (Not the ones that were inserted by
3475 // patterns.)
// NOTE(review): orig. line 3476 is missing; it begins the declaration that
// binds `materializations`, which (judging by find()/->second below) maps each
// cast op to its UnresolvedMaterializationInfo.
3477 &materializations = rewriterImpl.unresolvedMaterializations;
// NOTE(review): orig. line 3478 is missing; it presumably declares
// `remainingCastOps`, which receives the casts left over after reconciliation.
3479 reconcileUnrealizedCasts(materializations, &remainingCastOps);
3480
3481 // Drop markers.
3482 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3483 castOp->removeAttr(kPureTypeConversionMarker);
3484
3485 // Try to legalize all unresolved materializations.
// Skipped entirely when the config disables building materializations.
3486 if (rewriter.getConfig().buildMaterializations) {
3487 // Use a new rewriter, so the modifications are not tracked for rollback
3488 // purposes etc.
3489 IRRewriter irRewriter(rewriterImpl.rewriter.getContext(),
3490 rewriter.getConfig().listener);
3491 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3492 auto it = materializations.find(castOp);
3493 assert(it != materializations.end() && "inconsistent state");
// Stops at the first cast that cannot be materialized (an error is emitted).
3494 if (failed(legalizeUnresolvedMaterialization(irRewriter, castOp,
3495 it->second)))
3496 return failure();
3497 }
3498 }
3499
3500 return success();
3501}
3502
3503//===----------------------------------------------------------------------===//
3504// Type Conversion
3505//===----------------------------------------------------------------------===//
3506
3507void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
3508 ArrayRef<Type> types) {
3509 assert(!types.empty() && "expected valid types");
3510 remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
3511 addInputs(types);
3512}
3513
3514void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
3515 assert(!types.empty() &&
3516 "1->0 type remappings don't need to be added explicitly");
3517 argTypes.append(types.begin(), types.end());
3518}
3519
3520void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
3521 unsigned newInputNo,
3522 unsigned newInputCount) {
3523 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3524 assert(newInputCount != 0 && "expected valid input count");
3525 remappedInputs[origInputNo] =
3526 InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
3527}
3528
3529void TypeConverter::SignatureConversion::remapInput(
3530 unsigned origInputNo, ArrayRef<Value> replacements) {
3531 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3532 remappedInputs[origInputNo] = InputMapping{
3533 origInputNo, /*size=*/0,
3534 SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3535}
3536
3537/// Internal implementation of the type conversion.
3538/// This is used with either a Type or a Value as the first argument.
3539/// - we can cache the context-free conversions until the last registered
3540/// context-aware conversion.
3541/// - we can't cache the result of type conversion happening after context-aware
3542/// conversions, because the type converter may return different results for the
3543/// same input type.
3544LogicalResult
3545TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3546 SmallVectorImpl<Type> &results) const {
3547 assert(typeOrValue && "expected non-null type");
3548 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3549 : cast<Type>(typeOrValue);
3550 {
3551 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3552 std::defer_lock);
3554 cacheReadLock.lock();
3555 auto existingIt = cachedDirectConversions.find(t);
3556 if (existingIt != cachedDirectConversions.end()) {
3557 if (existingIt->second)
3558 results.push_back(existingIt->second);
3559 return success(existingIt->second != nullptr);
3560 }
3561 auto multiIt = cachedMultiConversions.find(t);
3562 if (multiIt != cachedMultiConversions.end()) {
3563 results.append(multiIt->second.begin(), multiIt->second.end());
3564 return success();
3565 }
3566 }
3567 // Walk the added converters in reverse order to apply the most recently
3568 // registered first.
3569 size_t currentCount = results.size();
3570
3571 // We can cache the context-free conversions until the last registered
3572 // context-aware conversion. But only if we're processing a Value right now.
3573 auto isCacheable = [&](int index) {
3574 int numberOfConversionsUntilContextAware =
3575 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3576 return index < numberOfConversionsUntilContextAware;
3577 };
3578
3579 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3580 std::defer_lock);
3581
3582 for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3583 const ConversionCallbackFn &converter = indexedConverter.value();
3584 std::optional<LogicalResult> result = converter(typeOrValue, results);
3585 if (!result) {
3586 assert(results.size() == currentCount &&
3587 "failed type conversion should not change results");
3588 continue;
3589 }
3590 if (!isCacheable(indexedConverter.index()))
3591 return success();
3593 cacheWriteLock.lock();
3594 if (!succeeded(*result)) {
3595 assert(results.size() == currentCount &&
3596 "failed type conversion should not change results");
3597 cachedDirectConversions.try_emplace(t, nullptr);
3598 return failure();
3599 }
3600 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3601 if (newTypes.size() == 1)
3602 cachedDirectConversions.try_emplace(t, newTypes.front());
3603 else
3604 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3605 return success();
3606 }
3607 return failure();
3608}
3609
3610LogicalResult TypeConverter::convertType(Type t,
3611 SmallVectorImpl<Type> &results) const {
3612 return convertTypeImpl(t, results);
3613}
3614
3615LogicalResult TypeConverter::convertType(Value v,
3616 SmallVectorImpl<Type> &results) const {
3617 return convertTypeImpl(v, results);
3618}
3619
3620Type TypeConverter::convertType(Type t) const {
3621 // Use the multi-type result version to convert the type.
3622 SmallVector<Type, 1> results;
3623 if (failed(convertType(t, results)))
3624 return nullptr;
3625
3626 // Check to ensure that only one type was produced.
3627 return results.size() == 1 ? results.front() : nullptr;
3628}
3629
3630Type TypeConverter::convertType(Value v) const {
3631 // Use the multi-type result version to convert the type.
3632 SmallVector<Type, 1> results;
3633 if (failed(convertType(v, results)))
3634 return nullptr;
3635
3636 // Check to ensure that only one type was produced.
3637 return results.size() == 1 ? results.front() : nullptr;
3638}
3639
3640LogicalResult
3641TypeConverter::convertTypes(TypeRange types,
3642 SmallVectorImpl<Type> &results) const {
3643 for (Type type : types)
3644 if (failed(convertType(type, results)))
3645 return failure();
3646 return success();
3647}
3648
3649LogicalResult
3650TypeConverter::convertTypes(ValueRange values,
3651 SmallVectorImpl<Type> &results) const {
3652 for (Value value : values)
3653 if (failed(convertType(value, results)))
3654 return failure();
3655 return success();
3656}
3657
3658bool TypeConverter::isLegal(Type type) const {
3659 return convertType(type) == type;
3660}
3661
3662bool TypeConverter::isLegal(Value value) const {
3663 return convertType(value) == value.getType();
3664}
3665
3666bool TypeConverter::isLegal(Operation *op) const {
3667 return isLegal(op->getOperands()) && isLegal(op->getResults());
3668}
3669
3670bool TypeConverter::isLegal(Region *region) const {
3671 return llvm::all_of(
3672 *region, [this](Block &block) { return isLegal(block.getArguments()); });
3673}
3674
3675bool TypeConverter::isSignatureLegal(FunctionType ty) const {
3676 if (!isLegal(ty.getInputs()))
3677 return false;
3678 if (!isLegal(ty.getResults()))
3679 return false;
3680 return true;
3681}
3682
3683LogicalResult
3684TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
3685 SignatureConversion &result) const {
3686 // Try to convert the given input type.
3687 SmallVector<Type, 1> convertedTypes;
3688 if (failed(convertType(type, convertedTypes)))
3689 return failure();
3690
3691 // If this argument is being dropped, there is nothing left to do.
3692 if (convertedTypes.empty())
3693 return success();
3694
3695 // Otherwise, add the new inputs.
3696 result.addInputs(inputNo, convertedTypes);
3697 return success();
3698}
3699LogicalResult
3700TypeConverter::convertSignatureArgs(TypeRange types,
3701 SignatureConversion &result,
3702 unsigned origInputOffset) const {
3703 for (unsigned i = 0, e = types.size(); i != e; ++i)
3704 if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3705 return failure();
3706 return success();
3707}
3708LogicalResult
3709TypeConverter::convertSignatureArg(unsigned inputNo, Value value,
3710 SignatureConversion &result) const {
3711 // Try to convert the given input type.
3712 SmallVector<Type, 1> convertedTypes;
3713 if (failed(convertType(value, convertedTypes)))
3714 return failure();
3715
3716 // If this argument is being dropped, there is nothing left to do.
3717 if (convertedTypes.empty())
3718 return success();
3719
3720 // Otherwise, add the new inputs.
3721 result.addInputs(inputNo, convertedTypes);
3722 return success();
3723}
3724LogicalResult
3725TypeConverter::convertSignatureArgs(ValueRange values,
3726 SignatureConversion &result,
3727 unsigned origInputOffset) const {
3728 for (unsigned i = 0, e = values.size(); i != e; ++i)
3729 if (failed(convertSignatureArg(origInputOffset + i, values[i], result)))
3730 return failure();
3731 return success();
3732}
3733
3734Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
3735 Location loc, Type resultType,
3736 ValueRange inputs) const {
3737 for (const SourceMaterializationCallbackFn &fn :
3738 llvm::reverse(sourceMaterializations))
3739 if (Value result = fn(builder, resultType, inputs, loc))
3740 return result;
3741 return nullptr;
3742}
3743
3744Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
3745 Location loc, Type resultType,
3746 ValueRange inputs,
3747 Type originalType) const {
3748 SmallVector<Value> result = materializeTargetConversion(
3749 builder, loc, TypeRange(resultType), inputs, originalType);
3750 if (result.empty())
3751 return nullptr;
3752 assert(result.size() == 1 && "expected single result");
3753 return result.front();
3754}
3755
3756SmallVector<Value> TypeConverter::materializeTargetConversion(
3757 OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
3758 Type originalType) const {
3759 for (const TargetMaterializationCallbackFn &fn :
3760 llvm::reverse(targetMaterializations)) {
3761 SmallVector<Value> result =
3762 fn(builder, resultTypes, inputs, loc, originalType);
3763 if (result.empty())
3764 continue;
3765 assert(TypeRange(ValueRange(result)) == resultTypes &&
3766 "callback produced incorrect number of values or values with "
3767 "incorrect types");
3768 return result;
3769 }
3770 return {};
3771}
3772
3773std::optional<TypeConverter::SignatureConversion>
3774TypeConverter::convertBlockSignature(Block *block) const {
3775 SignatureConversion conversion(block->getNumArguments());
3776 if (failed(convertSignatureArgs(block->getArguments(), conversion)))
3777 return std::nullopt;
3778 return conversion;
3779}
3780
3781//===----------------------------------------------------------------------===//
3782// Type attribute conversion
3783//===----------------------------------------------------------------------===//
3784TypeConverter::AttributeConversionResult
3785TypeConverter::AttributeConversionResult::result(Attribute attr) {
3786 return AttributeConversionResult(attr, resultTag);
3787}
3788
3789TypeConverter::AttributeConversionResult
3790TypeConverter::AttributeConversionResult::na() {
3791 return AttributeConversionResult(nullptr, naTag);
3792}
3793
3794TypeConverter::AttributeConversionResult
3795TypeConverter::AttributeConversionResult::abort() {
3796 return AttributeConversionResult(nullptr, abortTag);
3797}
3798
3799bool TypeConverter::AttributeConversionResult::hasResult() const {
3800 return impl.getInt() == resultTag;
3801}
3802
3803bool TypeConverter::AttributeConversionResult::isNa() const {
3804 return impl.getInt() == naTag;
3805}
3806
3807bool TypeConverter::AttributeConversionResult::isAbort() const {
3808 return impl.getInt() == abortTag;
3809}
3810
3811Attribute TypeConverter::AttributeConversionResult::getResult() const {
3812 assert(hasResult() && "Cannot get result from N/A or abort");
3813 return impl.getPointer();
3814}
3815
3816std::optional<Attribute>
3817TypeConverter::convertTypeAttribute(Type type, Attribute attr) const {
3818 for (const TypeAttributeConversionCallbackFn &fn :
3819 llvm::reverse(typeAttributeConversions)) {
3820 AttributeConversionResult res = fn(type, attr);
3821 if (res.hasResult())
3822 return res.getResult();
3823 if (res.isAbort())
3824 return std::nullopt;
3825 }
3826 return std::nullopt;
3827}
3828
3829//===----------------------------------------------------------------------===//
3830// FunctionOpInterfaceSignatureConversion
3831//===----------------------------------------------------------------------===//
3832
3833static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3834 const TypeConverter &typeConverter,
3835 ConversionPatternRewriter &rewriter) {
3836 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3837 if (!type)
3838 return failure();
3839
3840 // Convert the original function types.
3841 TypeConverter::SignatureConversion result(type.getNumInputs());
3842 SmallVector<Type, 1> newResults;
3843 if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3844 failed(typeConverter.convertTypes(type.getResults(), newResults)))
3845 return failure();
3846 if (!funcOp.getFunctionBody().empty())
3847 rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result,
3848 &typeConverter);
3849
3850 // Update the function signature in-place.
3851 auto newType = FunctionType::get(rewriter.getContext(),
3852 result.getConvertedTypes(), newResults);
3853
3854 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3855
3856 return success();
3857}
3858
3859/// Create a default conversion pattern that rewrites the type signature of a
3860/// FunctionOpInterface op. This only supports ops which use FunctionType to
3861/// represent their type.
3862namespace {
3863struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3864 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3865 MLIRContext *ctx,
3866 const TypeConverter &converter,
3867 PatternBenefit benefit)
3868 : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3869
3870 LogicalResult
3871 matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3872 ConversionPatternRewriter &rewriter) const override {
3873 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3874 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3875 }
3876};
3877
3878struct AnyFunctionOpInterfaceSignatureConversion
3879 : public OpInterfaceConversionPattern<FunctionOpInterface> {
3880 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3881
3882 LogicalResult
3883 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3884 ConversionPatternRewriter &rewriter) const override {
3885 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3886 }
3887};
3888} // namespace
3889
3890FailureOr<Operation *>
3891mlir::convertOpResultTypes(Operation *op, ValueRange operands,
3892 const TypeConverter &converter,
3893 ConversionPatternRewriter &rewriter) {
3894 assert(op && "Invalid op");
3895 Location loc = op->getLoc();
3896 if (converter.isLegal(op))
3897 return rewriter.notifyMatchFailure(loc, "op already legal");
3898
3899 OperationState newOp(loc, op->getName());
3900 newOp.addOperands(operands);
3901
3902 SmallVector<Type> newResultTypes;
3903 if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
3904 return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3905
3906 newOp.addTypes(newResultTypes);
3907 newOp.addAttributes(op->getAttrs());
3908 return rewriter.create(newOp);
3909}
3910
3911void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3912 StringRef functionLikeOpName, RewritePatternSet &patterns,
3913 const TypeConverter &converter, PatternBenefit benefit) {
3914 patterns.add<FunctionOpInterfaceSignatureConversion>(
3915 functionLikeOpName, patterns.getContext(), converter, benefit);
3916}
3917
3918void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3919 RewritePatternSet &patterns, const TypeConverter &converter,
3920 PatternBenefit benefit) {
3921 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3922 converter, patterns.getContext(), benefit);
3923}
3924
3925//===----------------------------------------------------------------------===//
3926// ConversionTarget
3927//===----------------------------------------------------------------------===//
3928
3929void ConversionTarget::setOpAction(OperationName op,
3930 LegalizationAction action) {
3931 legalOperations[op].action = action;
3932}
3933
3934void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3935 LegalizationAction action) {
3936 for (StringRef dialect : dialectNames)
3937 legalDialects[dialect] = action;
3938}
3939
3940auto ConversionTarget::getOpAction(OperationName op) const
3941 -> std::optional<LegalizationAction> {
3942 std::optional<LegalizationInfo> info = getOpInfo(op);
3943 return info ? info->action : std::optional<LegalizationAction>();
3944}
3945
3946auto ConversionTarget::isLegal(Operation *op) const
3947 -> std::optional<LegalOpDetails> {
3948 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3949 if (!info)
3950 return std::nullopt;
3951
3952 // Returns true if this operation instance is known to be legal.
3953 auto isOpLegal = [&] {
3954 // Handle dynamic legality either with the provided legality function.
3955 if (info->action == LegalizationAction::Dynamic) {
3956 std::optional<bool> result = info->legalityFn(op);
3957 if (result)
3958 return *result;
3959 }
3960
3961 // Otherwise, the operation is only legal if it was marked 'Legal'.
3962 return info->action == LegalizationAction::Legal;
3963 };
3964 if (!isOpLegal())
3965 return std::nullopt;
3966
3967 // This operation is legal, compute any additional legality information.
3968 LegalOpDetails legalityDetails;
3969 if (info->isRecursivelyLegal) {
3970 auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
3971 if (legalityFnIt != opRecursiveLegalityFns.end()) {
3972 legalityDetails.isRecursivelyLegal =
3973 legalityFnIt->second(op).value_or(true);
3974 } else {
3975 legalityDetails.isRecursivelyLegal = true;
3976 }
3977 }
3978 return legalityDetails;
3979}
3980
3981bool ConversionTarget::isIllegal(Operation *op) const {
3982 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3983 if (!info)
3984 return false;
3985
3986 if (info->action == LegalizationAction::Dynamic) {
3987 std::optional<bool> result = info->legalityFn(op);
3988 if (!result)
3989 return false;
3990
3991 return !(*result);
3992 }
3993
3994 return info->action == LegalizationAction::Illegal;
3995}
3996
3997static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
3998 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
3999 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
4000 if (!oldCallback)
4001 return newCallback;
4002
4003 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
4004 Operation *op) -> std::optional<bool> {
4005 if (std::optional<bool> result = newCl(op))
4006 return *result;
4007
4008 return oldCl(op);
4009 };
4010 return chain;
4011}
4012
4013void ConversionTarget::setLegalityCallback(
4014 OperationName name, const DynamicLegalityCallbackFn &callback) {
4015 assert(callback && "expected valid legality callback");
4016 auto *infoIt = legalOperations.find(name);
4017 assert(infoIt != legalOperations.end() &&
4018 infoIt->second.action == LegalizationAction::Dynamic &&
4019 "expected operation to already be marked as dynamically legal");
4020 infoIt->second.legalityFn =
4021 composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
4022}
4023
4024void ConversionTarget::markOpRecursivelyLegal(
4025 OperationName name, const DynamicLegalityCallbackFn &callback) {
4026 auto *infoIt = legalOperations.find(name);
4027 assert(infoIt != legalOperations.end() &&
4028 infoIt->second.action != LegalizationAction::Illegal &&
4029 "expected operation to already be marked as legal");
4030 infoIt->second.isRecursivelyLegal = true;
4031 if (callback)
4032 opRecursiveLegalityFns[name] = composeLegalityCallbacks(
4033 std::move(opRecursiveLegalityFns[name]), callback);
4034 else
4035 opRecursiveLegalityFns.erase(name);
4036}
4037
4038void ConversionTarget::setLegalityCallback(
4039 ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
4040 assert(callback && "expected valid legality callback");
4041 for (StringRef dialect : dialects)
4042 dialectLegalityFns[dialect] = composeLegalityCallbacks(
4043 std::move(dialectLegalityFns[dialect]), callback);
4044}
4045
4046void ConversionTarget::setLegalityCallback(
4047 const DynamicLegalityCallbackFn &callback) {
4048 assert(callback && "expected valid legality callback");
4049 unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
4050}
4051
4052auto ConversionTarget::getOpInfo(OperationName op) const
4053 -> std::optional<LegalizationInfo> {
4054 // Check for info for this specific operation.
4055 const auto *it = legalOperations.find(op);
4056 if (it != legalOperations.end())
4057 return it->second;
4058 // Check for info for the parent dialect.
4059 auto dialectIt = legalDialects.find(op.getDialectNamespace());
4060 if (dialectIt != legalDialects.end()) {
4061 DynamicLegalityCallbackFn callback;
4062 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4063 if (dialectFn != dialectLegalityFns.end())
4064 callback = dialectFn->second;
4065 return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
4066 callback};
4067 }
4068 // Otherwise, check if we mark unknown operations as dynamic.
4069 if (unknownLegalityFn)
4070 return LegalizationInfo{LegalizationAction::Dynamic,
4071 /*isRecursivelyLegal=*/false, unknownLegalityFn};
4072 return std::nullopt;
4073}
4074
4075#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4076//===----------------------------------------------------------------------===//
4077// PDL Configuration
4078//===----------------------------------------------------------------------===//
4079
4080void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
4081 auto &rewriterImpl =
4082 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4083 rewriterImpl.currentTypeConverter = getTypeConverter();
4084}
4085
4086void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
4087 auto &rewriterImpl =
4088 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4089 rewriterImpl.currentTypeConverter = nullptr;
4090}
4091
4092/// Remap the given value using the rewriter and the type converter in the
4093/// provided config.
4094static FailureOr<SmallVector<Value>>
4095pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) {
4096 SmallVector<Value> mappedValues;
4097 if (failed(rewriter.getRemappedValues(values, mappedValues)))
4098 return failure();
4099 return std::move(mappedValues);
4100}
4101
/// Register the PDL rewrite functions "convertValue", "convertValues",
/// "convertType" and "convertTypes". Each one static_casts the PatternRewriter
/// to ConversionPatternRewriter without checking, so they are only valid when
/// the PDL patterns run under a dialect conversion.
4102void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
4103 patterns.getPDLPatterns().registerRewriteFunction(
4104 "convertValue",
4105 [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
4106 auto results = pdllConvertValues(
4107 static_cast<ConversionPatternRewriter &>(rewriter), value);
4108 if (failed(results))
4109 return failure();
// NOTE(review): assumes the remapping produced at least one value; confirm
// for values replaced 1:0.
4110 return results->front();
4111 });
4112 patterns.getPDLPatterns().registerRewriteFunction(
4113 "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
4114 return pdllConvertValues(
4115 static_cast<ConversionPatternRewriter &>(rewriter), values);
4116 });
4117 patterns.getPDLPatterns().registerRewriteFunction(
4118 "convertType",
4119 [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
4120 auto &rewriterImpl =
4121 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4122 if (const TypeConverter *converter =
4123 rewriterImpl.currentTypeConverter) {
// convertType(Type) yields null for failed, 1:0 and 1:N conversions.
4124 if (Type newType = converter->convertType(type))
4125 return newType;
4126 return failure();
4127 }
// No type converter is active: the type is returned unchanged.
4128 return type;
4129 });
4130 patterns.getPDLPatterns().registerRewriteFunction(
4131 "convertTypes",
4132 [](PatternRewriter &rewriter,
4133 TypeRange types) -> FailureOr<SmallVector<Type>> {
4134 auto &rewriterImpl =
4135 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4136 const TypeConverter *converter = rewriterImpl.currentTypeConverter;
// No type converter is active: the types are returned unchanged.
4137 if (!converter)
4138 return SmallVector<Type>(types);
4139
4140 SmallVector<Type> remappedTypes;
4141 if (failed(converter->convertTypes(types, remappedTypes)))
4142 return failure();
4143 return std::move(remappedTypes);
4144 });
4145}
4146#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
4147
4148//===----------------------------------------------------------------------===//
4149// Op Conversion Entry Points
4150//===----------------------------------------------------------------------===//
4151
4152/// This is the type of Action that is dispatched when a conversion is applied.
// NOTE(review): the class head (orig. line 4153, presumably
// `class ApplyConversionAction`) is missing from this listing.
4154 : public tracing::ActionImpl<ApplyConversionAction> {
4155public:
// NOTE(review): orig. lines 4156-4157 are missing from this listing. They
// presumably declare how the action is constructed from the IR units it
// covers; confirm against the real source.
/// Identifier of this action kind; print() emits exactly this string.
4158 static constexpr StringLiteral tag = "apply-conversion";
/// Human-readable description of the action.
4159 static constexpr StringLiteral desc =
4160 "Encapsulate the application of a dialect conversion";
4161
4162 void print(raw_ostream &os) const override { os << tag; }
4163};
4164
4166 const ConversionTarget &target,
4168 ConversionConfig config,
4169 OpConversionMode mode) {
4170 if (ops.empty())
4171 return success();
4172 MLIRContext *ctx = ops.front()->getContext();
4173 LogicalResult status = success();
4174 SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
4176 [&] {
4177 OperationConverter opConverter(ops.front()->getContext(), target,
4178 patterns, config, mode);
4179 status = opConverter.applyConversion(ops);
4180 },
4181 irUnits);
4182 return status;
4183}
4184
4185//===----------------------------------------------------------------------===//
4186// Partial Conversion
4187//===----------------------------------------------------------------------===//
4188
4189LogicalResult mlir::applyPartialConversion(
4190 ArrayRef<Operation *> ops, const ConversionTarget &target,
4191 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4192 return applyConversion(ops, target, patterns, config,
4193 OpConversionMode::Partial);
4194}
4195LogicalResult
4196mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
4197 const FrozenRewritePatternSet &patterns,
4198 ConversionConfig config) {
4199 return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
4200}
4201
4202//===----------------------------------------------------------------------===//
4203// Full Conversion
4204//===----------------------------------------------------------------------===//
4205
4206LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
4207 const ConversionTarget &target,
4208 const FrozenRewritePatternSet &patterns,
4209 ConversionConfig config) {
4210 return applyConversion(ops, target, patterns, config, OpConversionMode::Full);
4211}
4212LogicalResult mlir::applyFullConversion(Operation *op,
4213 const ConversionTarget &target,
4214 const FrozenRewritePatternSet &patterns,
4215 ConversionConfig config) {
4216 return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
4217}
4218
4219//===----------------------------------------------------------------------===//
4220// Analysis Conversion
4221//===----------------------------------------------------------------------===//
4222
4223/// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
4224/// op is a top-level module op (which is expected to be isolated from above),
4225/// return that op.
4227 // Check if there is a top-level operation within `ops`. If so, return that
4228 // op.
4229 for (Operation *op : ops) {
4230 if (!op->getParentOp()) {
4231#ifndef NDEBUG
4232 assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
4233 "expected top-level op to be isolated from above");
4234 for (Operation *other : ops)
4235 assert(op->isAncestor(other) &&
4236 "expected ops to have a common ancestor");
4237#endif // NDEBUG
4238 return op;
4239 }
4240 }
4241
4242 // No top-level op. Find a common ancestor.
4243 Operation *commonAncestor =
4244 ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
4245 for (Operation *op : ops.drop_front()) {
4246 while (!commonAncestor->isProperAncestor(op)) {
4247 commonAncestor =
4249 assert(commonAncestor &&
4250 "expected to find a common isolated from above ancestor");
4251 }
4252 }
4253
4254 return commonAncestor;
4255}
4256
4257LogicalResult mlir::applyAnalysisConversion(
4258 ArrayRef<Operation *> ops, ConversionTarget &target,
4259 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4260#ifndef NDEBUG
4261 if (config.legalizableOps)
4262 assert(config.legalizableOps->empty() && "expected empty set");
4263#endif // NDEBUG
4264
4265 // Clone closted common ancestor that is isolated from above.
4266 Operation *commonAncestor = findCommonAncestor(ops);
4267 IRMapping mapping;
4268 Operation *clonedAncestor = commonAncestor->clone(mapping);
4269 // Compute inverse IR mapping.
4270 DenseMap<Operation *, Operation *> inverseOperationMap;
4271 for (auto &it : mapping.getOperationMap())
4272 inverseOperationMap[it.second] = it.first;
4273
4274 // Convert the cloned operations. The original IR will remain unchanged.
4275 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4276 ops, [&](Operation *op) { return mapping.lookup(op); });
4277 LogicalResult status = applyConversion(opsToConvert, target, patterns, config,
4278 OpConversionMode::Analysis);
4279
4280 // Remap `legalizableOps`, so that they point to the original ops and not the
4281 // cloned ops.
4282 if (config.legalizableOps) {
4283 DenseSet<Operation *> originalLegalizableOps;
4284 for (Operation *op : *config.legalizableOps)
4285 originalLegalizableOps.insert(inverseOperationMap[op]);
4286 *config.legalizableOps = std::move(originalLegalizableOps);
4287 }
4288
4289 // Erase the cloned IR.
4290 clonedAncestor->erase();
4291 return status;
4292}
4293
4294LogicalResult
4295mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
4296 const FrozenRewritePatternSet &patterns,
4297 ConversionConfig config) {
4298 return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
4299}
return success()
static void setInsertionPointAfter(OpBuilder &b, Value value)
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
SmallVector< Value, 1 > ValueVector
A vector of SSA values, optimized for the most common case of a single value.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl, function_ref< bool(OpOperand &)> functor=nullptr)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value > > &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
#define DEBUG_TYPE
b getContext())
static std::string diag(const llvm::Value &value)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Definition SCCP.cpp:67
This is the type of Action that is dispatched when a conversion is applied.
tracing::ActionImpl< ApplyConversionAction > Base
static constexpr StringLiteral desc
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
static constexpr StringLiteral tag
This class represents an argument of a Block.
Definition Value.h:309
Location getLoc() const
Return the location for this argument.
Definition Value.h:324
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition Block.h:318
BlockArgListType getArguments()
Definition Block.h:97
iterator end()
Definition Block.h:154
iterator begin()
Definition Block.h:153
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
UnitAttr getUnitAttr()
Definition Builders.cpp:102
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
MLIRContext * context
Definition Builders.h:204
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
Definition Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition Builders.h:329
Block::iterator getPoint() const
Definition Builders.h:342
bool isSet() const
Returns true if this insert point is set.
Definition Builders.h:339
Block * getBlock() const
Definition Builders.h:341
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:322
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Definition Builders.cpp:477
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition Builders.cpp:425
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
This class provides the API for ops that are known to be isolated from above.
type_range getTypes() const
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
void destroyOpProperties(OpaqueProperties properties) const
This hook destroys the op properties.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition Operation.h:226
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:852
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
void dropAllUses()
Drop all uses of results of this operation.
Definition Operation.h:834
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:560
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:248
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperandRange operand_range
Definition Operation.h:371
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:263
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
result_range getResults()
Definition Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition Operation.h:900
void erase()
Remove this operation from its parent block and delete it.
void copyProperties(OpaqueProperties rhs)
Copy properties from an existing other properties object.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
BlockListType::iterator iterator
Definition Region.h:52
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
CRTP Implementation of an action.
Definition Action.h:76
ArrayRef< IRUnit > irUnits
Set of IR units (operations, regions, blocks, values) that are associated with this action.
Definition Action.h:66
AttrTypeReplacer.
Kind
An enumeration of the kinds of predicates.
Definition Predicate.h:44
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
static void reconcileUnrealizedCasts(const DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
This iterator enumerates elements according to their dominance relationship.
Definition Iterators.h:48
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition Builders.h:310
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition Builders.h:300
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
const ConversionTarget & getTarget()
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, bool isRecursiveLegalization=false)
LogicalResult convert(Operation *op, bool isRecursiveLegalization=false)
Converts a single operation.
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, Fn onFailure, bool isRecursiveLegalization=false)
Legalizes the given operations (and their nested operations) to the conversion target.
LogicalResult applyConversion(ArrayRef< Operation * > ops)
Applies the conversion to the given operations (and their nested operations).
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
llvm::impl::raw_ldbg_ostream os
A raw output stream used to prefix the debug log.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
void replaceValueUses(Value from, ValueRange to, const TypeConverter *converter, function_ref< bool(OpOperand &)> functor=nullptr)
Replace the uses of the given value with the given values.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config, OperationConverter &opConverter)
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
void replaceOp(Operation *op, SmallVector< SmallVector< Value > > &&newValues)
Replace the results of the given operation with the given values and erase the operation.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
OperationConverter & opConverter
The operation converter to use for recursive legalization.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.