MLIR 23.0.0git
DialectConversion.cpp
Go to the documentation of this file.
1//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/Config/mlir-config.h"
11#include "mlir/IR/Block.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "mlir/IR/Dominance.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Iterators.h"
17#include "mlir/IR/Operation.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/DebugLog.h"
24#include "llvm/Support/ErrorHandling.h"
25#include "llvm/Support/FormatVariadic.h"
26#include "llvm/Support/SaveAndRestore.h"
27#include "llvm/Support/ScopedPrinter.h"
28#include <optional>
29#include <utility>
30
31using namespace mlir;
32using namespace mlir::detail;
33
34#define DEBUG_TYPE "dialect-conversion"
35
36/// A utility function to log a successful result for the given reason.
37template <typename... Args>
38static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
39 LLVM_DEBUG({
40 os.unindent();
41 os.startLine() << "} -> SUCCESS";
42 if (!fmt.empty())
43 os.getOStream() << " : "
44 << llvm::formatv(fmt.data(), std::forward<Args>(args)...);
45 os.getOStream() << "\n";
46 });
47}
48
49/// A utility function to log a failure result for the given reason.
50template <typename... Args>
51static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
52 LLVM_DEBUG({
53 os.unindent();
54 os.startLine() << "} -> FAILURE : "
55 << llvm::formatv(fmt.data(), std::forward<Args>(args)...)
56 << "\n";
57 });
58}
59
60/// Helper function that computes an insertion point where the given value is
61/// defined and can be used without a dominance violation.
63 Block *insertBlock = value.getParentBlock();
64 Block::iterator insertPt = insertBlock->begin();
65 if (OpResult inputRes = dyn_cast<OpResult>(value))
66 insertPt = ++inputRes.getOwner()->getIterator();
67 return OpBuilder::InsertPoint(insertBlock, insertPt);
68}
69
70/// Helper function that computes an insertion point where the given values are
71/// defined and can be used without a dominance violation.
73 assert(!vals.empty() && "expected at least one value");
74 DominanceInfo domInfo;
76 for (Value v : vals.drop_front()) {
77 // Choose the "later" insertion point.
79 if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(),
80 nextPt.getPoint())) {
81 // pt is before nextPt => choose nextPt.
82 pt = nextPt;
83 } else {
84#ifndef NDEBUG
85 // nextPt should be before pt => choose pt.
86 // If pt, nextPt are no dominance relationship, then there is no valid
87 // insertion point at which all given values are defined.
88 bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(),
89 pt.getBlock(), pt.getPoint());
90 assert(dom && "unable to find valid insertion point");
91#endif // NDEBUG
92 }
93 }
94 return pt;
95}
96
namespace {
/// Selects how the conversion driver treats operations that could not be
/// legalized.
enum OpConversionMode {
  /// Failed conversions are tolerated: illegal operations may remain in the IR
  /// next to legal ones.
  Partial,

  /// Every operation must end up legal for the given target, otherwise the
  /// conversion fails.
  Full,

  /// Operations are only analyzed for legality; on success, no rewrites are
  /// actually applied to them.
  Analysis,
};
} // namespace
112
113//===----------------------------------------------------------------------===//
114// ConversionValueMapping
115//===----------------------------------------------------------------------===//
116
117/// A vector of SSA values, optimized for the most common case of one or two
118/// values. Inline size chosen empirically based on compilation profiling.
119/// Profiled: 2.3M calls, avg=2.0+-0.3. N=2 covers 98% of cases inline.
121
122namespace {
123
124/// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
125struct ValueVectorMapInfo {
126 static ValueVector getEmptyKey() { return ValueVector{Value()}; }
127 static ::llvm::hash_code getHashValue(const ValueVector &val) {
128 return ::llvm::hash_combine_range(val);
129 }
130 static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
131 return LHS == RHS;
132 }
133};
134
135/// This class wraps a IRMapping to provide recursive lookup
136/// functionality, i.e. we will traverse if the mapped value also has a mapping.
137struct ConversionValueMapping {
138 /// Return "true" if an SSA value is mapped to the given value. May return
139 /// false positives.
140 bool isMappedTo(Value value) const { return mappedTo.contains(value); }
141
142 /// Lookup a value in the mapping.
143 ValueVector lookup(const ValueVector &from) const;
144
145 template <typename T>
146 struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
147
148 /// Map a value vector to the one provided.
149 template <typename OldVal, typename NewVal>
150 std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
151 map(OldVal &&oldVal, NewVal &&newVal) {
152 LLVM_DEBUG({
153 ValueVector next(newVal);
154 while (true) {
155 assert(next != oldVal && "inserting cyclic mapping");
156 auto it = mapping.find(next);
157 if (it == mapping.end())
158 break;
159 next = it->second;
160 }
161 });
162 mappedTo.insert_range(newVal);
163
164 mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
165 }
166
167 /// Map a value vector or single value to the one provided.
168 template <typename OldVal, typename NewVal>
169 std::enable_if_t<!IsValueVector<OldVal>::value ||
170 !IsValueVector<NewVal>::value>
171 map(OldVal &&oldVal, NewVal &&newVal) {
172 if constexpr (IsValueVector<OldVal>{}) {
173 map(std::forward<OldVal>(oldVal), ValueVector{newVal});
174 } else if constexpr (IsValueVector<NewVal>{}) {
175 map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
176 } else {
177 map(ValueVector{oldVal}, ValueVector{newVal});
178 }
179 }
180
181 void map(Value oldVal, SmallVector<Value> &&newVal) {
182 map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
183 }
184
185 /// Drop the last mapping for the given values.
186 void erase(const ValueVector &value) { mapping.erase(value); }
187
188private:
189 /// Current value mappings.
191
192 /// All SSA values that are mapped to. May contain false positives.
193 DenseSet<Value> mappedTo;
194};
195} // namespace
196
197/// Marker attribute for pure type conversions. I.e., mappings whose only
198/// purpose is to resolve a type mismatch. (In contrast, mappings that point to
199/// the replacement values of a "replaceOp" call, etc., are not pure type
200/// conversions.)
201static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
202
203/// Return the operation that defines all values in the vector. Return nullptr
204/// if the values are not defined by the same operation.
206 assert(!values.empty() && "expected non-empty value vector");
207 Operation *op = values.front().getDefiningOp();
208 for (Value v : llvm::drop_begin(values)) {
209 if (v.getDefiningOp() != op)
210 return nullptr;
211 }
212 return op;
213}
214
215/// A vector of values is a pure type conversion if all values are defined by
216/// the same operation and the operation has the `kPureTypeConversionMarker`
217/// attribute.
218static bool isPureTypeConversion(const ValueVector &values) {
219 assert(!values.empty() && "expected non-empty value vector");
220 Operation *op = getCommonDefiningOp(values);
221 return op && op->hasAttr(kPureTypeConversionMarker);
222}
223
224ValueVector ConversionValueMapping::lookup(const ValueVector &from) const {
225 auto it = mapping.find(from);
226 if (it == mapping.end()) {
227 // No mapping found: The lookup stops here.
228 return {};
229 }
230 return it->second;
231}
232
233//===----------------------------------------------------------------------===//
234// Rewriter and Translation State
235//===----------------------------------------------------------------------===//
236namespace {
/// Snapshot of the conversion rewriter's bookkeeping counters. A saved
/// snapshot can later be restored to undo the rewrites made since it was taken.
struct RewriterState {
  /// Number of rewrites performed when the snapshot was taken.
  unsigned numRewrites;

  /// Number of ignored operations when the snapshot was taken.
  unsigned numIgnoredOperations;

  /// Number of replaced ops scheduled for erasure when the snapshot was taken.
  unsigned numReplacedOps;

  RewriterState(unsigned rewrites, unsigned ignoredOps, unsigned replacedOps)
      : numRewrites(rewrites), numIgnoredOperations(ignoredOps),
        numReplacedOps(replacedOps) {}
};
254
255//===----------------------------------------------------------------------===//
256// IR rewrites
257//===----------------------------------------------------------------------===//
258
259static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
260
261/// Notify the listener that the given block and its contents are being erased.
262static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
263 for (Operation &op : b)
264 notifyIRErased(listener, op);
265 listener->notifyBlockErased(&b);
266}
267
268/// Notify the listener that the given operation and its contents are being
269/// erased.
270static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
271 for (Region &r : op.getRegions()) {
272 for (Block &b : r) {
273 notifyIRErased(listener, b);
274 }
275 }
276 listener->notifyOperationErased(&op);
277}
278
279/// An IR rewrite that can be committed (upon success) or rolled back (upon
280/// failure).
281///
282/// The dialect conversion keeps track of IR modifications (requested by the
283/// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
284/// are directly applied to the IR as the rewriter API is used, some are applied
285/// partially, and some are delayed until the `IRRewrite` objects are committed.
286class IRRewrite {
287public:
288 /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
289 /// Enum values are ordered, so that they can be used in `classof`: first all
290 /// block rewrites, then all operation rewrites.
291 enum class Kind {
292 // Block rewrites
293 CreateBlock,
294 EraseBlock,
295 InlineBlock,
296 MoveBlock,
297 BlockTypeConversion,
298 // Operation rewrites
299 MoveOperation,
300 ModifyOperation,
301 ReplaceOperation,
302 CreateOperation,
303 UnresolvedMaterialization,
304 // Value rewrites
305 ReplaceValue
306 };
307
308 virtual ~IRRewrite() = default;
309
310 /// Roll back the rewrite. Operations may be erased during rollback.
311 virtual void rollback() = 0;
312
313 /// Commit the rewrite. At this point, it is certain that the dialect
314 /// conversion will succeed. All IR modifications, except for operation/block
315 /// erasure, must be performed through the given rewriter.
316 ///
317 /// Instead of erasing operations/blocks, they should merely be unlinked
318 /// commit phase and finally be erased during the cleanup phase. This is
319 /// because internal dialect conversion state (such as `mapping`) may still
320 /// be using them.
321 ///
322 /// Any IR modification that was already performed before the commit phase
323 /// (e.g., insertion of an op) must be communicated to the listener that may
324 /// be attached to the given rewriter.
325 virtual void commit(RewriterBase &rewriter) {}
326
327 /// Cleanup operations/blocks. Cleanup is called after commit.
328 virtual void cleanup(RewriterBase &rewriter) {}
329
330 Kind getKind() const { return kind; }
331
332 static bool classof(const IRRewrite *rewrite) { return true; }
333
334protected:
335 IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
336 : kind(kind), rewriterImpl(rewriterImpl) {}
337
338 const ConversionConfig &getConfig() const;
339
340 const Kind kind;
341 ConversionPatternRewriterImpl &rewriterImpl;
342};
343
344/// A block rewrite.
345class BlockRewrite : public IRRewrite {
346public:
347 /// Return the block that this rewrite operates on.
348 Block *getBlock() const { return block; }
349
350 static bool classof(const IRRewrite *rewrite) {
351 return rewrite->getKind() >= Kind::CreateBlock &&
352 rewrite->getKind() <= Kind::BlockTypeConversion;
353 }
354
355protected:
356 BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
357 Block *block)
358 : IRRewrite(kind, rewriterImpl), block(block) {}
359
360 // The block that this rewrite operates on.
361 Block *block;
362};
363
364/// A value rewrite.
365class ValueRewrite : public IRRewrite {
366public:
367 /// Return the value that this rewrite operates on.
368 Value getValue() const { return value; }
369
370 static bool classof(const IRRewrite *rewrite) {
371 return rewrite->getKind() == Kind::ReplaceValue;
372 }
373
374protected:
375 ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
376 Value value)
377 : IRRewrite(kind, rewriterImpl), value(value) {}
378
379 // The value that this rewrite operates on.
380 Value value;
381};
382
383/// Creation of a block. Block creations are immediately reflected in the IR.
384/// There is no extra work to commit the rewrite. During rollback, the newly
385/// created block is erased.
386class CreateBlockRewrite : public BlockRewrite {
387public:
388 CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
389 : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
390
391 static bool classof(const IRRewrite *rewrite) {
392 return rewrite->getKind() == Kind::CreateBlock;
393 }
394
395 void commit(RewriterBase &rewriter) override {
396 // The block was already created and inserted. Just inform the listener.
397 if (auto *listener = rewriter.getListener())
398 listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
399 }
400
401 void rollback() override {
402 // Unlink all of the operations within this block, they will be deleted
403 // separately.
404 auto &blockOps = block->getOperations();
405 while (!blockOps.empty())
406 blockOps.remove(blockOps.begin());
407 block->dropAllUses();
408 if (block->getParent())
409 block->erase();
410 else
411 delete block;
412 }
413};
414
415/// Erasure of a block. Block erasures are partially reflected in the IR. Erased
416/// blocks are immediately unlinked, but only erased during cleanup. This makes
417/// it easier to rollback a block erasure: the block is simply inserted into its
418/// original location.
419class EraseBlockRewrite : public BlockRewrite {
420public:
421 EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
422 : BlockRewrite(Kind::EraseBlock, rewriterImpl, block),
423 region(block->getParent()), insertBeforeBlock(block->getNextNode()) {}
424
425 static bool classof(const IRRewrite *rewrite) {
426 return rewrite->getKind() == Kind::EraseBlock;
427 }
428
429 ~EraseBlockRewrite() override {
430 assert(!block &&
431 "rewrite was neither rolled back nor committed/cleaned up");
432 }
433
434 void rollback() override {
435 // The block (owned by this rewrite) was not actually erased yet. It was
436 // just unlinked. Put it back into its original position.
437 assert(block && "expected block");
438 auto &blockList = region->getBlocks();
439 Region::iterator before = insertBeforeBlock
440 ? Region::iterator(insertBeforeBlock)
441 : blockList.end();
442 blockList.insert(before, block);
443 block = nullptr;
444 }
445
446 void commit(RewriterBase &rewriter) override {
447 assert(block && "expected block");
448
449 // Notify the listener that the block and its contents are being erased.
450 if (auto *listener =
451 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
452 notifyIRErased(listener, *block);
453 }
454
455 void cleanup(RewriterBase &rewriter) override {
456 // Erase the contents of the block.
457 for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
458 rewriter.eraseOp(&op);
459 assert(block->empty() && "expected empty block");
460
461 // Erase the block.
462 block->dropAllDefinedValueUses();
463 delete block;
464 block = nullptr;
465 }
466
467private:
468 // The region in which this block was previously contained.
469 Region *region;
470
471 // The original successor of this block before it was unlinked. "nullptr" if
472 // this block was the only block in the region.
473 Block *insertBeforeBlock;
474};
475
476/// Inlining of a block. This rewrite is immediately reflected in the IR.
477/// Note: This rewrite represents only the inlining of the operations. The
478/// erasure of the inlined block is a separate rewrite.
479class InlineBlockRewrite : public BlockRewrite {
480public:
481 InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
482 Block *sourceBlock, Block::iterator before)
483 : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
484 sourceBlock(sourceBlock),
485 firstInlinedInst(sourceBlock->empty() ? nullptr
486 : &sourceBlock->front()),
487 lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
488 // If a listener is attached to the dialect conversion, ops must be moved
489 // one-by-one. When they are moved in bulk, notifications cannot be sent
490 // because the ops that used to be in the source block at the time of the
491 // inlining (before the "commit" phase) are unknown at the time when
492 // notifications are sent (which is during the "commit" phase).
493 assert(!getConfig().listener &&
494 "InlineBlockRewrite not supported if listener is attached");
495 }
496
497 static bool classof(const IRRewrite *rewrite) {
498 return rewrite->getKind() == Kind::InlineBlock;
499 }
500
501 void rollback() override {
502 // Put the operations from the destination block (owned by the rewrite)
503 // back into the source block.
504 if (firstInlinedInst) {
505 assert(lastInlinedInst && "expected operation");
506 sourceBlock->getOperations().splice(sourceBlock->begin(),
507 block->getOperations(),
508 Block::iterator(firstInlinedInst),
509 ++Block::iterator(lastInlinedInst));
510 }
511 }
512
513private:
514 // The block that originally contained the operations.
515 Block *sourceBlock;
516
517 // The first inlined operation.
518 Operation *firstInlinedInst;
519
520 // The last inlined operation.
521 Operation *lastInlinedInst;
522};
523
524/// Moving of a block. This rewrite is immediately reflected in the IR.
525class MoveBlockRewrite : public BlockRewrite {
526public:
527 MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
528 Region *previousRegion, Region::iterator previousIt)
529 : BlockRewrite(Kind::MoveBlock, rewriterImpl, block),
530 region(previousRegion),
531 insertBeforeBlock(previousIt == previousRegion->end() ? nullptr
532 : &*previousIt) {}
533
534 static bool classof(const IRRewrite *rewrite) {
535 return rewrite->getKind() == Kind::MoveBlock;
536 }
537
538 void commit(RewriterBase &rewriter) override {
539 // The block was already moved. Just inform the listener.
540 if (auto *listener = rewriter.getListener()) {
541 // Note: `previousIt` cannot be passed because this is a delayed
542 // notification and iterators into past IR state cannot be represented.
543 listener->notifyBlockInserted(block, /*previous=*/region,
544 /*previousIt=*/{});
545 }
546 }
547
548 void rollback() override {
549 // Move the block back to its original position.
550 Region::iterator before =
551 insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
552 if (Region *currentParent = block->getParent()) {
553 // Block is still in a region, use cheap splice to move it back.
554 region->getBlocks().splice(before, currentParent->getBlocks(), block);
555 return;
556 }
557 // Block was orphaned by a prior rollback, can't splice.
558 region->getBlocks().insert(before, block);
559 }
560
561private:
562 // The region in which this block was previously contained.
563 Region *region;
564
565 // The original successor of this block before it was moved. "nullptr" if
566 // this block was the only block in the region.
567 Block *insertBeforeBlock;
568};
569
570/// Block type conversion. This rewrite is partially reflected in the IR.
571class BlockTypeConversionRewrite : public BlockRewrite {
572public:
573 BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
574 Block *origBlock, Block *newBlock)
575 : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
576 newBlock(newBlock) {}
577
578 static bool classof(const IRRewrite *rewrite) {
579 return rewrite->getKind() == Kind::BlockTypeConversion;
580 }
581
582 Block *getOrigBlock() const { return block; }
583
584 Block *getNewBlock() const { return newBlock; }
585
586 void commit(RewriterBase &rewriter) override;
587
588 void rollback() override;
589
590private:
591 /// The new block that was created as part of this signature conversion.
592 Block *newBlock;
593};
594
595/// Replacing a value. This rewrite is not immediately reflected in the
596/// IR. An internal IR mapping is updated, but the actual replacement is delayed
597/// until the rewrite is committed.
598class ReplaceValueRewrite : public ValueRewrite {
599public:
600 ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
601 const TypeConverter *converter)
602 : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
603 converter(converter) {}
604
605 static bool classof(const IRRewrite *rewrite) {
606 return rewrite->getKind() == Kind::ReplaceValue;
607 }
608
609 void commit(RewriterBase &rewriter) override;
610
611 void rollback() override;
612
613private:
614 /// The current type converter when the value was replaced.
615 const TypeConverter *converter;
616};
617
618/// An operation rewrite.
619class OperationRewrite : public IRRewrite {
620public:
621 /// Return the operation that this rewrite operates on.
622 Operation *getOperation() const { return op; }
623
624 static bool classof(const IRRewrite *rewrite) {
625 return rewrite->getKind() >= Kind::MoveOperation &&
626 rewrite->getKind() <= Kind::UnresolvedMaterialization;
627 }
628
629protected:
630 OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
631 Operation *op)
632 : IRRewrite(kind, rewriterImpl), op(op) {}
633
634 // The operation that this rewrite operates on.
635 Operation *op;
636};
637
638/// Moving of an operation. This rewrite is immediately reflected in the IR.
639class MoveOperationRewrite : public OperationRewrite {
640public:
641 MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
642 Operation *op, OpBuilder::InsertPoint previous)
643 : OperationRewrite(Kind::MoveOperation, rewriterImpl, op),
644 block(previous.getBlock()),
645 insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
646 ? nullptr
647 : &*previous.getPoint()) {}
648
649 static bool classof(const IRRewrite *rewrite) {
650 return rewrite->getKind() == Kind::MoveOperation;
651 }
652
653 void commit(RewriterBase &rewriter) override {
654 // The operation was already moved. Just inform the listener.
655 if (auto *listener = rewriter.getListener()) {
656 // Note: `previousIt` cannot be passed because this is a delayed
657 // notification and iterators into past IR state cannot be represented.
658 listener->notifyOperationInserted(
659 op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
660 /*insertPt=*/{}));
661 }
662 }
663
664 void rollback() override {
665 // Move the operation back to its original position.
666 Block::iterator before =
667 insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
668 block->getOperations().splice(before, op->getBlock()->getOperations(), op);
669 }
670
671private:
672 // The block in which this operation was previously contained.
673 Block *block;
674
675 // The original successor of this operation before it was moved. "nullptr"
676 // if this operation was the only operation in the region.
677 Operation *insertBeforeOp;
678};
679
680/// In-place modification of an op. This rewrite is immediately reflected in
681/// the IR. The previous state of the operation is stored in this object.
682class ModifyOperationRewrite : public OperationRewrite {
683public:
684 ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
685 Operation *op)
686 : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
687 name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()),
688 operands(op->operand_begin(), op->operand_end()),
689 successors(op->successor_begin(), op->successor_end()) {
690 if (PropertyRef prop = op->getPropertiesStorage()) {
691 // Make a copy of the properties.
692 propertiesStorage = operator new(op->getPropertiesStorageSize());
693 PropertyRef propCopy(name.getOpPropertiesTypeID(), propertiesStorage);
694 name.initOpProperties(propCopy, /*init=*/prop);
695 }
696 }
697
698 static bool classof(const IRRewrite *rewrite) {
699 return rewrite->getKind() == Kind::ModifyOperation;
700 }
701
702 ~ModifyOperationRewrite() override {
703 assert(!propertiesStorage &&
704 "rewrite was neither committed nor rolled back");
705 }
706
707 void commit(RewriterBase &rewriter) override {
708 // Notify the listener that the operation was modified in-place.
709 if (auto *listener =
710 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
711 listener->notifyOperationModified(op);
712
713 if (propertiesStorage) {
714 PropertyRef propCopy(name.getOpPropertiesTypeID(), propertiesStorage);
715 // Note: The operation may have been erased in the mean time, so
716 // OperationName must be stored in this object.
717 name.destroyOpProperties(propCopy);
718 operator delete(propertiesStorage);
719 propertiesStorage = nullptr;
720 }
721 }
722
723 void rollback() override {
724 op->setLoc(loc);
725 op->setAttrs(attrs);
726 op->setOperands(operands);
727 for (const auto &it : llvm::enumerate(successors))
728 op->setSuccessor(it.value(), it.index());
729 if (propertiesStorage) {
730 PropertyRef propCopy(name.getOpPropertiesTypeID(), propertiesStorage);
731 op->copyProperties(propCopy);
732 name.destroyOpProperties(propCopy);
733 operator delete(propertiesStorage);
734 propertiesStorage = nullptr;
735 }
736 }
737
738private:
739 OperationName name;
740 LocationAttr loc;
741 DictionaryAttr attrs;
742 SmallVector<Value, 8> operands;
743 SmallVector<Block *, 2> successors;
744 void *propertiesStorage = nullptr;
745};
746
747/// Replacing an operation. Erasing an operation is treated as a special case
748/// with "null" replacements. This rewrite is not immediately reflected in the
749/// IR. An internal IR mapping is updated, but values are not replaced and the
750/// original op is not erased until the rewrite is committed.
751class ReplaceOperationRewrite : public OperationRewrite {
752public:
753 ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
754 Operation *op, const TypeConverter *converter)
755 : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
756 converter(converter) {}
757
758 static bool classof(const IRRewrite *rewrite) {
759 return rewrite->getKind() == Kind::ReplaceOperation;
760 }
761
762 void commit(RewriterBase &rewriter) override;
763
764 void rollback() override;
765
766 void cleanup(RewriterBase &rewriter) override;
767
768private:
769 /// An optional type converter that can be used to materialize conversions
770 /// between the new and old values if necessary.
771 const TypeConverter *converter;
772};
773
774class CreateOperationRewrite : public OperationRewrite {
775public:
776 CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
777 Operation *op)
778 : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
779
780 static bool classof(const IRRewrite *rewrite) {
781 return rewrite->getKind() == Kind::CreateOperation;
782 }
783
784 void commit(RewriterBase &rewriter) override {
785 // The operation was already created and inserted. Just inform the listener.
786 if (auto *listener = rewriter.getListener())
787 listener->notifyOperationInserted(op, /*previous=*/{});
788 }
789
790 void rollback() override;
791};
792
/// The direction of a materialization.
enum MaterializationKind {
  /// Converts a value from an illegal type to a legal one.
  Target,

  /// Converts a value from a legal type back to an illegal one.
  Source
};
803
804/// Helper class that stores metadata about an unresolved materialization.
805class UnresolvedMaterializationInfo {
806public:
807 UnresolvedMaterializationInfo() = default;
808 UnresolvedMaterializationInfo(const TypeConverter *converter,
809 MaterializationKind kind, Type originalType)
810 : converterAndKind(converter, kind), originalType(originalType) {}
811
812 /// Return the type converter of this materialization (which may be null).
813 const TypeConverter *getConverter() const {
814 return converterAndKind.getPointer();
815 }
816
817 /// Return the kind of this materialization.
818 MaterializationKind getMaterializationKind() const {
819 return converterAndKind.getInt();
820 }
821
822 /// Return the original type of the SSA value.
823 Type getOriginalType() const { return originalType; }
824
825private:
826 /// The corresponding type converter to use when resolving this
827 /// materialization, and the kind of this materialization.
828 llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
829 converterAndKind;
830
831 /// The original type of the SSA value. Only used for target
832 /// materializations.
833 Type originalType;
834};
835
836/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
837/// op. Unresolved materializations fold away or are replaced with
838/// source/target materializations at the end of the dialect conversion.
839class UnresolvedMaterializationRewrite : public OperationRewrite {
840public:
841 UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
842 UnrealizedConversionCastOp op,
843 ValueVector mappedValues)
844 : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
845 mappedValues(std::move(mappedValues)) {}
846
847 static bool classof(const IRRewrite *rewrite) {
848 return rewrite->getKind() == Kind::UnresolvedMaterialization;
849 }
850
851 void rollback() override;
852
853 UnrealizedConversionCastOp getOperation() const {
854 return cast<UnrealizedConversionCastOp>(op);
855 }
856
857private:
858 /// The values in the conversion value mapping that are being replaced by the
859 /// results of this unresolved materialization.
860 ValueVector mappedValues;
861};
862} // namespace
863
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` that
/// operates on the operation `op`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Operation *op) {
  for (auto &rewrite : rewrites) {
    if (auto *typed = dyn_cast<RewriteTy>(rewrite.get());
        typed && typed->getOperation() == op)
      return true;
  }
  return false;
}

/// Return "true" if `rewrites` contains a rewrite of type `RewriteTy` that
/// operates on the block `block`.
template <typename RewriteTy, typename R>
static bool hasRewrite(R &&rewrites, Block *block) {
  for (auto &rewrite : rewrites) {
    if (auto *typed = dyn_cast<RewriteTy>(rewrite.get());
        typed && typed->getBlock() == block)
      return true;
  }
  return false;
}
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
885
886//===----------------------------------------------------------------------===//
887// ConversionPatternRewriterImpl
888//===----------------------------------------------------------------------===//
889namespace mlir {
890namespace detail {
892 explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
893 const ConversionConfig &config,
897
898 //===--------------------------------------------------------------------===//
899 // State Management
900 //===--------------------------------------------------------------------===//
901
902 /// Return the current state of the rewriter.
903 RewriterState getCurrentState();
904
905 /// Apply all requested operation rewrites. This method is invoked when the
906 /// conversion process succeeds.
907 void applyRewrites();
908
909 /// Reset the state of the rewriter to a previously saved point. Optionally,
910 /// the name of the pattern that triggered the rollback can specified for
911 /// debugging purposes.
912 void resetState(RewriterState state, StringRef patternName = "");
913
914 /// Append a rewrite. Rewrites are committed upon success and rolled back upon
915 /// failure.
916 template <typename RewriteTy, typename... Args>
917 void appendRewrite(Args &&...args) {
918 assert(config.allowPatternRollback && "appending rewrites is not allowed");
919 rewrites.push_back(
920 std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
921 }
922
923 /// Undo the rewrites (motions, splits) one by one in reverse order until
924 /// "numRewritesToKeep" rewrites remains. Optionally, the name of the pattern
925 /// that triggered the rollback can specified for debugging purposes.
926 void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
927
928 /// Remap the given values to those with potentially different types. Returns
929 /// success if the values could be remapped, failure otherwise. `valueDiagTag`
930 /// is the tag used when describing a value within a diagnostic, e.g.
931 /// "operand".
932 LogicalResult remapValues(StringRef valueDiagTag,
933 std::optional<Location> inputLoc, ValueRange values,
934 SmallVector<ValueVector> &remapped);
935
936 /// Return "true" if the given operation is ignored, and does not need to be
937 /// converted.
938 bool isOpIgnored(Operation *op) const;
939
940 /// Return "true" if the given operation was replaced or erased.
941 bool wasOpReplaced(Operation *op) const;
942
943 /// Lookup the most recently mapped values with the desired types in the
944 /// mapping, taking into account only replacements. Perform a best-effort
945 /// search for existing materializations with the desired types.
946 ///
947 /// If `skipPureTypeConversions` is "true", materializations that are pure
948 /// type conversions are not considered.
949 ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
950 bool skipPureTypeConversions = false) const;
951
952 /// Lookup the given value within the map, or return an empty vector if the
953 /// value is not mapped. If it is mapped, this follows the same behavior
954 /// as `lookupOrDefault`.
955 ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
956
957 //===--------------------------------------------------------------------===//
958 // IR Rewrites / Type Conversion
959 //===--------------------------------------------------------------------===//
960
961 /// Convert the types of block arguments within the given region.
962 FailureOr<Block *>
963 convertRegionTypes(Region *region, const TypeConverter &converter,
964 TypeConverter::SignatureConversion *entryConversion);
965
966 /// Apply the given signature conversion on the given block. The new block
967 /// containing the updated signature is returned. If no conversions were
968 /// necessary, e.g. if the block has no arguments, `block` is returned.
969 /// `converter` is used to generate any necessary cast operations that
970 /// translate between the origin argument types and those specified in the
971 /// signature conversion.
972 Block *applySignatureConversion(
973 Block *block, const TypeConverter *converter,
974 TypeConverter::SignatureConversion &signatureConversion);
975
976 /// Replace the results of the given operation with the given values and
977 /// erase the operation.
978 ///
979 /// There can be multiple replacement values for each result (1:N
980 /// replacement). If the replacement values are empty, the respective result
981 /// is dropped and a source materialization is built if the result still has
982 /// uses.
983 void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
984
985 /// Replace the uses of the given value with the given values. The specified
986 /// converter is used to build materializations (if necessary). If `functor`
987 /// is specified, only the uses that the functor returns "true" for are
988 /// replaced.
989 void replaceValueUses(Value from, ValueRange to,
990 const TypeConverter *converter,
991 function_ref<bool(OpOperand &)> functor = nullptr);
992
993 /// Erase the given block and its contents.
994 void eraseBlock(Block *block);
995
996 /// Inline the source block into the destination block before the given
997 /// iterator.
998 void inlineBlockBefore(Block *source, Block *dest, Block::iterator before);
999
1000 //===--------------------------------------------------------------------===//
1001 // Materializations
1002 //===--------------------------------------------------------------------===//
1003
1004 /// Build an unresolved materialization operation given a range of output
1005 /// types and a list of input operands. Returns the inputs if they their
1006 /// types match the output types.
1007 ///
1008 /// If a cast op was built, it can optionally be returned with the `castOp`
1009 /// output argument.
1010 ///
1011 /// If `valuesToMap` is set to a non-null Value, then that value is mapped to
1012 /// the results of the unresolved materialization in the conversion value
1013 /// mapping.
1014 ///
1015 /// If `isPureTypeConversion` is "true", the materialization is created only
1016 /// to resolve a type mismatch. That means it is not a regular value
1017 /// replacement issued by the user. (Replacement values that are created
1018 /// "out of thin air" appear like unresolved materializations because they are
1019 /// unrealized_conversion_cast ops. However, they must be treated like
1020 /// regular value replacements.)
1021 ValueRange buildUnresolvedMaterialization(
1022 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1023 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1024 Type originalType, const TypeConverter *converter,
1025 bool isPureTypeConversion = true);
1026
1027 /// Find a replacement value for the given SSA value in the conversion value
1028 /// mapping. The replacement value must have the same type as the given SSA
1029 /// value. If there is no replacement value with the correct type, find the
1030 /// latest replacement value (regardless of the type) and build a source
1031 /// materialization.
1032 Value findOrBuildReplacementValue(Value value,
1033 const TypeConverter *converter);
1034
1035 //===--------------------------------------------------------------------===//
1036 // Rewriter Notification Hooks
1037 //===--------------------------------------------------------------------===//
1038
1039 //// Notifies that an op was inserted.
1040 void notifyOperationInserted(Operation *op,
1041 OpBuilder::InsertPoint previous) override;
1042
1043 /// Notifies that a block was inserted.
1044 void notifyBlockInserted(Block *block, Region *previous,
1045 Region::iterator previousIt) override;
1046
1047 /// Notifies that a pattern match failed for the given reason.
1048 void
1049 notifyMatchFailure(Location loc,
1050 function_ref<void(Diagnostic &)> reasonCallback) override;
1051
1052 //===--------------------------------------------------------------------===//
1053 // IR Erasure
1054 //===--------------------------------------------------------------------===//
1055
1056 /// A rewriter that keeps track of erased ops and blocks. It ensures that no
1057 /// operation or block is erased multiple times. This rewriter assumes that
1058 /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
1060 public:
1063 std::function<void(Operation *)> opErasedCallback = nullptr)
1064 : RewriterBase(context, /*listener=*/this),
1065 opErasedCallback(std::move(opErasedCallback)) {}
1066
1067 /// Erase the given op (unless it was already erased).
1068 void eraseOp(Operation *op) override {
1069 if (wasErased(op))
1070 return;
1071 op->dropAllUses();
1073 }
1074
1075 /// Erase the given block (unless it was already erased).
1076 void eraseBlock(Block *block) override {
1077 if (wasErased(block))
1078 return;
1079 assert(block->empty() && "expected empty block");
1080 block->dropAllDefinedValueUses();
1082 }
1083
1084 bool wasErased(void *ptr) const { return erased.contains(ptr); }
1085
1087 erased.insert(op);
1088 if (opErasedCallback)
1089 opErasedCallback(op);
1090 }
1091
1092 void notifyBlockErased(Block *block) override { erased.insert(block); }
1093
1094 private:
1095 /// Pointers to all erased operations and blocks.
1096 DenseSet<void *> erased;
1097
1098 /// A callback that is invoked when an operation is erased.
1099 std::function<void(Operation *)> opErasedCallback;
1100 };
1101
1102 //===--------------------------------------------------------------------===//
1103 // State
1104 //===--------------------------------------------------------------------===//
1105
1106 /// The rewriter that is used to perform the conversion.
1107 ConversionPatternRewriter &rewriter;
1108
1109 // Mapping between replaced values that differ in type. This happens when
1110 // replacing a value with one of a different type.
1111 ConversionValueMapping mapping;
1112
1113 /// Ordered list of block operations (creations, splits, motions).
1114 /// This vector is maintained only if `allowPatternRollback` is set to
1115 /// "true". Otherwise, all IR rewrites are materialized immediately and no
1116 /// bookkeeping is needed.
1118
1119 /// A set of operations that should no longer be considered for legalization.
1120 /// E.g., ops that are recursively legal. Ops that were replaced/erased are
1121 /// tracked separately.
1123
1124 /// A set of operations that were replaced/erased. Such ops are not erased
1125 /// immediately but only when the dialect conversion succeeds. In the mean
1126 /// time, they should no longer be considered for legalization and any attempt
1127 /// to modify/access them is invalid rewriter API usage.
1129
1130 /// A set of operations that were created by the current pattern.
1132
1133 /// A set of operations that were modified by the current pattern.
1135
1136 /// A list of unresolved materializations that were created by the current
1137 /// pattern.
1139
1140 /// A mapping for looking up metadata of unresolved materializations.
1141 llvm::MapVector<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
1143
1144 /// The current type converter, or nullptr if no type converter is currently
1145 /// active.
1147
1148 /// A mapping of regions to type converters that should be used when
1149 /// converting the arguments of blocks within that region.
1151
1152 /// Dialect conversion configuration.
1153 const ConversionConfig &config;
1154
1155 /// The operation converter to use for recursive legalization.
1157
1158 /// A set of erased operations. This set is utilized only if
1159 /// `allowPatternRollback` is set to "false". Conceptually, this set is
1160 /// similar to `replacedOps` (which is maintained when the flag is set to
1161 /// "true"). However, erasing from a DenseSet is more efficient than erasing
1162 /// from a SetVector.
1164
1165 /// A set of erased blocks. This set is utilized only if
1166 /// `allowPatternRollback` is set to "false".
1168
1169 /// A rewriter that notifies the listener (if any) about all IR
1170 /// modifications. This rewriter is utilized only if `allowPatternRollback`
1171 /// is set to "false". If the flag is set to "true", the listener is notified
1172 /// with a separate mechanism (e.g., in `IRRewrite::commit`).
1174
1175#ifndef NDEBUG
1176 /// A set of replaced values. This set is for debugging purposes only and it
1177 /// is maintained only if `allowPatternRollback` is set to "true".
1179
1180 /// A set of operations that have pending updates. This tracking isn't
1181 /// strictly necessary, and is thus only active during debug builds for extra
1182 /// verification.
1184
1185 /// A raw output stream used to prefix the debug log.
1186 llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
1187 llvm::dbgs()};
1188
1189 /// A logger used to emit diagnostics during the conversion process.
1190 llvm::ScopedPrinter logger{os};
1191 std::string logPrefix;
1192#endif
1193};
1194} // namespace detail
1195} // namespace mlir
1196
1197const ConversionConfig &IRRewrite::getConfig() const {
1198 return rewriterImpl.config;
1199}
1200
1201void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
1202 // Inform the listener about all IR modifications that have already taken
1203 // place: References to the original block have been replaced with the new
1204 // block.
1205 if (auto *listener =
1206 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
1207 for (Operation *op : getNewBlock()->getUsers())
1208 listener->notifyOperationModified(op);
1209}
1210
1211void BlockTypeConversionRewrite::rollback() {
1212 getNewBlock()->replaceAllUsesWith(getOrigBlock());
1213}
1214
1215/// Replace all uses of `from` with `repl`.
1216static void
1218 function_ref<bool(OpOperand &)> functor = nullptr) {
1219 if (isa<BlockArgument>(repl)) {
1220 // `repl` is a block argument. Directly replace all uses.
1221 if (functor) {
1222 rewriter.replaceUsesWithIf(from, repl, functor);
1223 } else {
1224 rewriter.replaceAllUsesWith(from, repl);
1225 }
1226 return;
1227 }
1228
1229 // If the replacement value is an operation, only replace those uses that:
1230 // - are in a different block than the replacement operation, or
1231 // - are in the same block but after the replacement operation.
1232 //
1233 // Example:
1234 // ^bb0(%arg0: i32):
1235 // %0 = "consumer"(%arg0) : (i32) -> (i32)
1236 // "another_consumer"(%arg0) : (i32) -> ()
1237 //
1238 // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
1239 // use in "another_consumer" but not the use in "consumer". When using the
1240 // normal RewriterBase API, this would typically be done with
1241 // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
1242 // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
1243 // it cannot be supported efficiently with `allowPatternRollback` set to
1244 // "true". Therefore, the conversion driver is trying to be smart and replaces
1245 // only those uses that do not lead to a dominance violation. E.g., the
1246 // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
1247 // behavior.
1248 //
1249 // TODO: As we move more and more towards `allowPatternRollback` set to
1250 // "false", we should remove this special handling, in order to align the
1251 // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
1252 Operation *replOp = repl.getDefiningOp();
1253 Block *replBlock = replOp->getBlock();
1254 rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
1255 Operation *user = operand.getOwner();
1256 bool result =
1257 user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
1258 if (result && functor)
1259 result &= functor(operand);
1260 return result;
1261 });
1262}
1263
1264void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1265 Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
1266 if (!repl)
1267 return;
1268 performReplaceValue(rewriter, value, repl);
1269}
1270
1271void ReplaceValueRewrite::rollback() {
1272 rewriterImpl.mapping.erase({value});
1273#ifndef NDEBUG
1274 rewriterImpl.replacedValues.erase(value);
1275#endif // NDEBUG
1276}
1277
1278void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
1279 auto *listener =
1280 dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
1281
1282 // Compute replacement values.
1283 SmallVector<Value> replacements =
1284 llvm::map_to_vector(op->getResults(), [&](OpResult result) {
1285 return rewriterImpl.findOrBuildReplacementValue(result, converter);
1286 });
1287
1288 // Notify the listener that the operation is about to be replaced.
1289 if (listener)
1290 listener->notifyOperationReplaced(op, replacements);
1291
1292 // Replace all uses with the new values.
1293 for (auto [result, newValue] :
1294 llvm::zip_equal(op->getResults(), replacements))
1295 if (newValue)
1296 rewriter.replaceAllUsesWith(result, newValue);
1297
1298 // The original op will be erased, so remove it from the set of unlegalized
1299 // ops.
1300 if (getConfig().unlegalizedOps)
1301 getConfig().unlegalizedOps->erase(op);
1302
1303 // Notify the listener that the operation and its contents are being erased.
1304 if (listener)
1305 notifyIRErased(listener, *op);
1306
1307 // Do not erase the operation yet. It may still be referenced in `mapping`.
1308 // Just unlink it for now and erase it during cleanup.
1309 if (!op->getBlock())
1310 llvm::reportFatalInternalError(
1311 "dialect conversion attempted to replace a root operation that has no "
1312 "parent block; the pass must ensure its target op is nested in a "
1313 "block");
1314 op->getBlock()->getOperations().remove(op);
1315}
1316
1317void ReplaceOperationRewrite::rollback() {
1318 for (auto result : op->getResults())
1319 rewriterImpl.mapping.erase({result});
1320}
1321
1322void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
1323 rewriter.eraseOp(op);
1324}
1325
1326void CreateOperationRewrite::rollback() {
1327 for (Region &region : op->getRegions()) {
1328 while (!region.getBlocks().empty())
1329 region.getBlocks().remove(region.getBlocks().begin());
1330 }
1331 op->dropAllUses();
1332 op->erase();
1333}
1334
1335void UnresolvedMaterializationRewrite::rollback() {
1336 if (!mappedValues.empty())
1337 rewriterImpl.mapping.erase(mappedValues);
1338 rewriterImpl.unresolvedMaterializations.erase(getOperation());
1339 op->erase();
1340}
1341
1343 // Commit all rewrites. Use a new rewriter, so the modifications are not
1344 // tracked for rollback purposes etc.
1345 IRRewriter irRewriter(rewriter.getContext(), config.listener);
1346 // Note: New rewrites may be added during the "commit" phase and the
1347 // `rewrites` vector may reallocate.
1348 for (size_t i = 0; i < rewrites.size(); ++i)
1349 rewrites[i]->commit(irRewriter);
1350
1351 // Clean up all rewrites.
1352 SingleEraseRewriter eraseRewriter(
1353 rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
1354 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1355 unresolvedMaterializations.erase(castOp);
1356 });
1357 for (auto &rewrite : rewrites)
1358 rewrite->cleanup(eraseRewriter);
1359}
1360
1361//===----------------------------------------------------------------------===//
1362// State Management
1363//===----------------------------------------------------------------------===//
1364
1366 Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
1367 // Helper function that looks up a single value.
1368 auto lookup = [&](const ValueVector &values) -> ValueVector {
1369 assert(!values.empty() && "expected non-empty value vector");
1370
1371 // If the pattern rollback is enabled, use the mapping to look up the
1372 // values.
1373 if (config.allowPatternRollback)
1374 return mapping.lookup(values);
1375
1376 // Otherwise, look up values by examining the IR. All replacements have
1377 // already been materialized in IR.
1378 Operation *op = getCommonDefiningOp(values);
1379 if (!op)
1380 return {};
1381 auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
1382 if (!castOp)
1383 return {};
1384 if (!this->unresolvedMaterializations.contains(castOp))
1385 return {};
1386 if (castOp.getOutputs() != values)
1387 return {};
1388 return castOp.getInputs();
1389 };
1390
1391 // Helper function that looks up each value in `values` individually and then
1392 // composes the results. If that fails, it tries to look up the entire vector
1393 // at once.
1394 auto composedLookup = [&](const ValueVector &values) -> ValueVector {
1395 // If possible, replace each value with (one or multiple) mapped values.
1396 ValueVector next;
1397 for (Value v : values) {
1398 ValueVector r = lookup({v});
1399 if (!r.empty()) {
1400 llvm::append_range(next, r);
1401 } else {
1402 next.push_back(v);
1403 }
1404 }
1405 if (next != values) {
1406 // At least one value was replaced.
1407 return next;
1408 }
1409
1410 // Otherwise: Check if there is a mapping for the entire vector. Such
1411 // mappings are materializations. (N:M mapping are not supported for value
1412 // replacements.)
1413 //
1414 // Note: From a correctness point of view, materializations do not have to
1415 // be stored (and looked up) in the mapping. But for performance reasons,
1416 // we choose to reuse existing IR (when possible) instead of creating it
1417 // multiple times.
1418 ValueVector r = lookup(values);
1419 if (r.empty()) {
1420 // No mapping found: The lookup stops here.
1421 return {};
1422 }
1423 return r;
1424 };
1425
1426 // Try to find the deepest values that have the desired types. If there is no
1427 // such mapping, simply return the deepest values.
1428 ValueVector desiredValue;
1429 ValueVector current{from};
1430 ValueVector lastNonMaterialization{from};
1431 do {
1432 // Store the current value if the types match.
1433 bool match = TypeRange(ValueRange(current)) == desiredTypes;
1434 if (skipPureTypeConversions) {
1435 // Skip pure type conversions, if requested.
1436 bool pureConversion = isPureTypeConversion(current);
1437 match &= !pureConversion;
1438 // Keep track of the last mapped value that was not a pure type
1439 // conversion.
1440 if (!pureConversion)
1441 lastNonMaterialization = current;
1442 }
1443 if (match)
1444 desiredValue = current;
1445
1446 // Lookup next value in the mapping.
1447 ValueVector next = composedLookup(current);
1448 if (next.empty())
1449 break;
1450 current = std::move(next);
1451 } while (true);
1452
1453 // If the desired values were found use them, otherwise default to the leaf
1454 // values. (Skip pure type conversions, if requested.)
1455 if (!desiredTypes.empty())
1456 return desiredValue;
1457 if (skipPureTypeConversions)
1458 return lastNonMaterialization;
1459 return current;
1460}
1461
1464 TypeRange desiredTypes) const {
1465 ValueVector result = lookupOrDefault(from, desiredTypes);
1466 if (result == ValueVector{from} ||
1467 (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
1468 return {};
1469 return result;
1470}
1471
1473 return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
1474}
1475
1477 StringRef patternName) {
1478 // Undo any rewrites.
1479 undoRewrites(state.numRewrites, patternName);
1480
1481 // Pop all of the recorded ignored operations that are no longer valid.
1482 while (ignoredOps.size() != state.numIgnoredOperations)
1483 ignoredOps.pop_back();
1484
1485 while (replacedOps.size() != state.numReplacedOps)
1486 replacedOps.pop_back();
1487}
1488
1489void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
1490 StringRef patternName) {
1491 for (auto &rewrite :
1492 llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
1493 rewrite->rollback();
1494 rewrites.resize(numRewritesToKeep);
1495}
1496
1498 StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
1499 SmallVector<ValueVector> &remapped) {
1500 remapped.reserve(llvm::size(values));
1501
1502 for (const auto &it : llvm::enumerate(values)) {
1503 Value operand = it.value();
1504 Type origType = operand.getType();
1505 Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1506
1507 if (!currentTypeConverter) {
1508 // The current pattern does not have a type converter. Pass the most
1509 // recently mapped values, excluding materializations. Materializations
1510 // are intentionally excluded because their presence may depend on other
1511 // patterns. Including materializations would make the lookup fragile
1512 // and unpredictable.
1513 remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
1514 /*skipPureTypeConversions=*/true));
1515 continue;
1516 }
1517
1518 // If there is no legal conversion, fail to match this pattern.
1519 SmallVector<Type, 1> legalTypes;
1520 if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
1521 notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
1522 diag << "unable to convert type for " << valueDiagTag << " #"
1523 << it.index() << ", type was " << origType;
1524 });
1525 return failure();
1526 }
1527 // If a type is converted to 0 types, there is nothing to do.
1528 if (legalTypes.empty()) {
1529 remapped.push_back({});
1530 continue;
1531 }
1532
1533 ValueVector repl = lookupOrDefault(operand, legalTypes);
1534 if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
1535 // Mapped values have the correct type or there is an existing
1536 // materialization. Or the operand is not mapped at all and has the
1537 // correct type.
1538 remapped.push_back(std::move(repl));
1539 continue;
1540 }
1541
1542 // Create a materialization for the most recently mapped values.
1543 repl = lookupOrDefault(operand, /*desiredTypes=*/{},
1544 /*skipPureTypeConversions=*/true);
1546 MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
1547 /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
1548 /*originalType=*/origType, currentTypeConverter);
1549 remapped.push_back(castValues);
1550 }
1551 return success();
1552}
1553
1555 // Check to see if this operation is ignored or was replaced.
1556 return wasOpReplaced(op) || ignoredOps.count(op);
1557}
1558
1560 // Check to see if this operation was replaced.
1561 return replacedOps.count(op) || erasedOps.count(op);
1562}
1563
1564//===----------------------------------------------------------------------===//
1565// Type Conversion
1566//===----------------------------------------------------------------------===//
1567
1569 Region *region, const TypeConverter &converter,
1570 TypeConverter::SignatureConversion *entryConversion) {
1571 regionToConverter[region] = &converter;
1572 if (region->empty())
1573 return nullptr;
1574
1575 // Convert the arguments of each non-entry block within the region.
1576 for (Block &block :
1577 llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1578 // Compute the signature for the block with the provided converter.
1579 std::optional<TypeConverter::SignatureConversion> conversion =
1580 converter.convertBlockSignature(&block);
1581 if (!conversion)
1582 return failure();
1583 // Convert the block with the computed signature.
1584 applySignatureConversion(&block, &converter, *conversion);
1585 }
1586
1587 // Convert the entry block. If an entry signature conversion was provided,
1588 // use that one. Otherwise, compute the signature with the type converter.
1589 if (entryConversion)
1590 return applySignatureConversion(&region->front(), &converter,
1591 *entryConversion);
1592 std::optional<TypeConverter::SignatureConversion> conversion =
1593 converter.convertBlockSignature(&region->front());
1594 if (!conversion)
1595 return failure();
1596 return applySignatureConversion(&region->front(), &converter, *conversion);
1597}
1598
1600 Block *block, const TypeConverter *converter,
1601 TypeConverter::SignatureConversion &signatureConversion) {
1602#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1603 // A block cannot be converted multiple times.
1604 if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1605 llvm::reportFatalInternalError("block was already converted");
1606#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1607
1609
1610 // If no arguments are being changed or added, there is nothing to do.
1611 unsigned origArgCount = block->getNumArguments();
1612 auto convertedTypes = signatureConversion.getConvertedTypes();
1613 if (llvm::equal(block->getArgumentTypes(), convertedTypes))
1614 return block;
1615
1616 // Compute the locations of all block arguments in the new block.
1617 SmallVector<Location> newLocs(convertedTypes.size(),
1618 rewriter.getUnknownLoc());
1619 for (unsigned i = 0; i < origArgCount; ++i) {
1620 auto inputMap = signatureConversion.getInputMapping(i);
1621 if (!inputMap || inputMap->replacedWithValues())
1622 continue;
1623 Location origLoc = block->getArgument(i).getLoc();
1624 for (unsigned j = 0; j < inputMap->size; ++j)
1625 newLocs[inputMap->inputNo + j] = origLoc;
1626 }
1627
1628 // Insert a new block with the converted block argument types and move all ops
1629 // from the old block to the new block.
1630 Block *newBlock =
1631 rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
1632 convertedTypes, newLocs);
1633
1634 // If a listener is attached to the dialect conversion, ops cannot be moved
1635 // to the destination block in bulk ("fast path"). This is because at the time
1636 // the notifications are sent, it is unknown which ops were moved. Instead,
1637 // ops should be moved one-by-one ("slow path"), so that a separate
1638 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
1639 // a bit more efficient, so we try to do that when possible.
1640 bool fastPath = !config.listener;
1641 if (fastPath) {
1642 if (config.allowPatternRollback)
1643 appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
1644 newBlock->getOperations().splice(newBlock->end(), block->getOperations());
1645 } else {
1646 while (!block->empty())
1647 rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
1648 }
1649
1650 // Replace all uses of the old block with the new block.
1651 block->replaceAllUsesWith(newBlock);
1652
1653 for (unsigned i = 0; i != origArgCount; ++i) {
1654 BlockArgument origArg = block->getArgument(i);
1655 Type origArgType = origArg.getType();
1656
1657 std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1658 signatureConversion.getInputMapping(i);
1659 if (!inputMap) {
1660 // This block argument was dropped and no replacement value was provided.
1661 // Materialize a replacement value "out of thin air".
1662 // Note: Materialization must be built here because we cannot find a
1663 // valid insertion point in the new block. (Will point to the old block.)
1664 Value mat =
1666 MaterializationKind::Source,
1667 OpBuilder::InsertPoint(newBlock, newBlock->begin()),
1668 origArg.getLoc(),
1669 /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1670 /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
1671 /*isPureTypeConversion=*/false)
1672 .front();
1673 replaceValueUses(origArg, mat, converter);
1674 continue;
1675 }
1676
1677 if (inputMap->replacedWithValues()) {
1678 // This block argument was dropped and replacement values were provided.
1679 assert(inputMap->size == 0 &&
1680 "invalid to provide a replacement value when the argument isn't "
1681 "dropped");
1682 replaceValueUses(origArg, inputMap->replacementValues, converter);
1683 continue;
1684 }
1685
1686 // This is a 1->1+ mapping.
1687 auto replArgs =
1688 newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1689 replaceValueUses(origArg, replArgs, converter);
1690 }
1691
1692 if (config.allowPatternRollback)
1693 appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
1694
1695 // Erase the old block. (It is just unlinked for now and will be erased during
1696 // cleanup.)
1697 rewriter.eraseBlock(block);
1698
1699 return newBlock;
1700}
1701
1702//===----------------------------------------------------------------------===//
1703// Materializations
1704//===----------------------------------------------------------------------===//
1705
1706/// Build an unresolved materialization operation given an output type and set
1707/// of input operands.
1709 MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1710 ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
1711 Type originalType, const TypeConverter *converter,
1712 bool isPureTypeConversion) {
1713 assert((!originalType || kind == MaterializationKind::Target) &&
1714 "original type is valid only for target materializations");
1715 assert(TypeRange(inputs) != outputTypes &&
1716 "materialization is not necessary");
1717
1718 // Create an unresolved materialization. We use a new OpBuilder to avoid
1719 // tracking the materialization like we do for other operations.
1720 OpBuilder builder(outputTypes.front().getContext());
1721 builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
1722 UnrealizedConversionCastOp convertOp =
1723 UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
1724 if (config.attachDebugMaterializationKind) {
1725 StringRef kindStr =
1726 kind == MaterializationKind::Source ? "source" : "target";
1727 convertOp->setAttr("__kind__", builder.getStringAttr(kindStr));
1728 }
1730 convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
1731
1732 // Register the materialization.
1733 unresolvedMaterializations[convertOp] =
1734 UnresolvedMaterializationInfo(converter, kind, originalType);
1735 if (config.allowPatternRollback) {
1736 if (!valuesToMap.empty())
1737 mapping.map(valuesToMap, convertOp.getResults());
1739 std::move(valuesToMap));
1740 } else {
1741 patternMaterializations.insert(convertOp);
1742 }
1743 return convertOp.getResults();
1744}
1745
1747 Value value, const TypeConverter *converter) {
1748 assert(config.allowPatternRollback &&
1749 "this code path is valid only in rollback mode");
1750
1751 // Try to find a replacement value with the same type in the conversion value
1752 // mapping. This includes cached materializations. We try to reuse those
1753 // instead of generating duplicate IR.
1754 ValueVector repl = lookupOrNull(value, value.getType());
1755 if (!repl.empty())
1756 return repl.front();
1757
1758 // Check if the value is dead. No replacement value is needed in that case.
1759 // This is an approximate check that may have false negatives but does not
1760 // require computing and traversing an inverse mapping. (We may end up
1761 // building source materializations that are never used and that fold away.)
1762 if (llvm::all_of(value.getUsers(),
1763 [&](Operation *op) { return replacedOps.contains(op); }) &&
1764 !mapping.isMappedTo(value))
1765 return Value();
1766
1767 // No replacement value was found. Get the latest replacement value
1768 // (regardless of the type) and build a source materialization to the
1769 // original type.
1770 repl = lookupOrNull(value);
1771
1772 // Compute the insertion point of the materialization.
1774 if (repl.empty()) {
1775 // The source materialization has no inputs. Insert it right before the
1776 // value that it is replacing.
1777 ip = computeInsertPoint(value);
1778 } else {
1779 // Compute the "earliest" insertion point at which all values in `repl` are
1780 // defined. It is important to emit the materialization at that location
1781 // because the same materialization may be reused in a different context.
1782 // (That's because materializations are cached in the conversion value
1783 // mapping.) The insertion point of the materialization must be valid for
1784 // all future users that may be created later in the conversion process.
1785 ip = computeInsertPoint(repl);
1786 }
1788 MaterializationKind::Source, ip, value.getLoc(),
1789 /*valuesToMap=*/repl, /*inputs=*/repl,
1790 /*outputTypes=*/value.getType(),
1791 /*originalType=*/Type(), converter,
1792 /*isPureTypeConversion=*/!repl.empty())
1793 .front();
1794 return castValue;
1795}
1796
1797//===----------------------------------------------------------------------===//
1798// Rewriter Notification Hooks
1799//===----------------------------------------------------------------------===//
1800
1802 Operation *op, OpBuilder::InsertPoint previous) {
1803 // If no previous insertion point is provided, the op used to be detached.
1804 bool wasDetached = !previous.isSet();
1805 LLVM_DEBUG({
1806 logger.startLine() << "** Insert : '" << op->getName() << "' (" << op
1807 << ")";
1808 if (wasDetached)
1809 logger.getOStream() << " (was detached)";
1810 logger.getOStream() << "\n";
1811 });
1812
1813 // In rollback mode, it is easier to misuse the API, so perform extra error
1814 // checking.
1815 assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
1816 "attempting to insert into a block within a replaced/erased op");
1817
1818 // In "no rollback" mode, the listener is always notified immediately.
1819 if (!config.allowPatternRollback && config.listener)
1820 config.listener->notifyOperationInserted(op, previous);
1821
1822 if (wasDetached) {
1823 // If the op was detached, it is most likely a newly created op. Add it the
1824 // set of newly created ops, so that it will be legalized. If this op is
1825 // not a newly created op, it will be legalized a second time, which is
1826 // inefficient but harmless.
1827 patternNewOps.insert(op);
1828
1829 if (config.allowPatternRollback) {
1830 // TODO: If the same op is inserted multiple times from a detached
1831 // state, the rollback mechanism may erase the same op multiple times.
1832 // This is a bug in the rollback-based dialect conversion driver.
1834 } else {
1835 // In "no rollback" mode, there is an extra data structure for tracking
1836 // erased operations that must be kept up to date.
1837 erasedOps.erase(op);
1838 }
1839 return;
1840 }
1841
1842 // The op was moved from one place to another.
1843 if (config.allowPatternRollback)
1845}
1846
1847/// Given that `fromRange` is about to be replaced with `toRange`, compute
1848/// replacement values with the types of `fromRange`.
1849static SmallVector<Value>
1851 const SmallVector<SmallVector<Value>> &toRange,
1852 const TypeConverter *converter) {
1853 assert(!impl.config.allowPatternRollback &&
1854 "this code path is valid only in 'no rollback' mode");
1855 SmallVector<Value> repls;
1856 for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
1857 if (from.use_empty()) {
1858 // The replaced value is dead. No replacement value is needed.
1859 repls.push_back(Value());
1860 continue;
1861 }
1862
1863 if (to.empty()) {
1864 // The replaced value is dropped. Materialize a replacement value "out of
1865 // thin air".
1866 Value srcMat = impl.buildUnresolvedMaterialization(
1867 MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
1868 /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1869 /*outputTypes=*/from.getType(), /*originalType=*/Type(),
1870 converter)[0];
1871 repls.push_back(srcMat);
1872 continue;
1873 }
1874
1875 if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
1876 // The replacement value already has the correct type. Use it directly.
1877 repls.push_back(to[0]);
1878 continue;
1879 }
1880
1881 // The replacement value has the wrong type. Build a source materialization
1882 // to the original type.
1883 // TODO: This is a bit inefficient. We should try to reuse existing
1884 // materializations if possible. This would require an extension of the
1885 // `lookupOrDefault` API.
1886 Value srcMat = impl.buildUnresolvedMaterialization(
1887 MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
1888 /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
1889 /*originalType=*/Type(), converter)[0];
1890 repls.push_back(srcMat);
1891 }
1892
1893 return repls;
1894}
1895
1897 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
1898 assert(newValues.size() == op->getNumResults() &&
1899 "incorrect number of replacement values");
1900 LLVM_DEBUG({
1901 logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
1902 << ")\n";
1904 // If the user-provided replacement types are different from the
1905 // legalized types, as per the current type converter, print a note.
1906 // In most cases, the replacement types are expected to match the types
1907 // produced by the type converter, so this could indicate a bug in the
1908 // user code.
1909 for (auto [result, repls] :
1910 llvm::zip_equal(op->getResults(), newValues)) {
1911 Type resultType = result.getType();
1912 auto logProlog = [&, repls = repls]() {
1913 logger.startLine() << " Note: Replacing op result of type "
1914 << resultType << " with value(s) of type (";
1915 llvm::interleaveComma(repls, logger.getOStream(), [&](Value v) {
1916 logger.getOStream() << v.getType();
1917 });
1918 logger.getOStream() << ")";
1919 };
1920 SmallVector<Type> convertedTypes;
1921 if (failed(currentTypeConverter->convertTypes(resultType,
1922 convertedTypes))) {
1923 logProlog();
1924 logger.getOStream() << ", but the type converter failed to legalize "
1925 "the original type.\n";
1926 continue;
1927 }
1928 if (TypeRange(convertedTypes) != TypeRange(ValueRange(repls))) {
1929 logProlog();
1930 logger.getOStream() << ", but the legalized type(s) is/are (";
1931 llvm::interleaveComma(convertedTypes, logger.getOStream(),
1932 [&](Type t) { logger.getOStream() << t; });
1933 logger.getOStream() << ")\n";
1934 }
1935 }
1936 }
1937 });
1938
1939 if (!config.allowPatternRollback) {
1940 // Pattern rollback is not allowed: materialize all IR changes immediately.
1942 *this, op->getResults(), newValues, currentTypeConverter);
1943 // Update internal data structures, so that there are no dangling pointers
1944 // to erased IR.
1945 op->walk([&](Operation *op) {
1946 erasedOps.insert(op);
1947 ignoredOps.remove(op);
1948 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1949 unresolvedMaterializations.erase(castOp);
1950 patternMaterializations.erase(castOp);
1951 }
1952 // The original op will be erased, so remove it from the set of
1953 // unlegalized ops.
1954 if (config.unlegalizedOps)
1955 config.unlegalizedOps->erase(op);
1956 });
1957 op->walk([&](Block *block) { erasedBlocks.insert(block); });
1958 // Replace the op with the replacement values and notify the listener.
1959 notifyingRewriter.replaceOp(op, repls);
1960 return;
1961 }
1962
1963 assert(!ignoredOps.contains(op) && "operation was already replaced");
1964#ifndef NDEBUG
1965 for (Value v : op->getResults())
1966 assert(!replacedValues.contains(v) &&
1967 "attempting to replace a value that was already replaced");
1968#endif // NDEBUG
1969
1970 // Check if replaced op is an unresolved materialization, i.e., an
1971 // unrealized_conversion_cast op that was created by the conversion driver.
1972 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1973 // Make sure that the user does not mess with unresolved materializations
1974 // that were inserted by the conversion driver. We keep track of these
1975 // ops in internal data structures.
1976 assert(!unresolvedMaterializations.contains(castOp) &&
1977 "attempting to replace/erase an unresolved materialization");
1978 }
1979
1980 // Create mappings for each of the new result values.
1981 for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults()))
1982 mapping.map(static_cast<Value>(result), std::move(repl));
1983
1985 // Mark this operation and all nested ops as replaced.
1986 op->walk([&](Operation *op) { replacedOps.insert(op); });
1987}
1988
1990 Value from, ValueRange to, const TypeConverter *converter,
1991 function_ref<bool(OpOperand &)> functor) {
1992 LLVM_DEBUG({
1993 logger.startLine() << "** Replace Value : '" << from << "'";
1994 if (auto blockArg = dyn_cast<BlockArgument>(from)) {
1995 if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
1996 logger.getOStream() << " (in region of '" << parentOp->getName()
1997 << "' (" << parentOp << ")";
1998 } else {
1999 logger.getOStream() << " (unlinked block)";
2000 }
2001 }
2002 if (functor) {
2003 logger.getOStream() << ", conditional replacement";
2004 }
2005 });
2006
2007 if (!config.allowPatternRollback) {
2008 SmallVector<Value> toConv = llvm::to_vector(to);
2009 SmallVector<Value> repls =
2010 getReplacementValues(*this, from, {toConv}, converter);
2011 IRRewriter r(from.getContext());
2012 Value repl = repls.front();
2013 if (!repl)
2014 return;
2015
2016 performReplaceValue(r, from, repl, functor);
2017 return;
2018 }
2019
2020#ifndef NDEBUG
2021 // Make sure that a value is not replaced multiple times. In rollback mode,
2022 // `replaceAllUsesWith` replaces not only all current uses of the given value,
2023 // but also all future uses that may be introduced by future pattern
2024 // applications. Therefore, it does not make sense to call
2025 // `replaceAllUsesWith` multiple times with the same value. Doing so would
2026 // overwrite the mapping and mess with the internal state of the dialect
2027 // conversion driver.
2028 assert(!replacedValues.contains(from) &&
2029 "attempting to replace a value that was already replaced");
2030 assert(!wasOpReplaced(from.getDefiningOp()) &&
2031 "attempting to replace a op result that was already replaced");
2032 replacedValues.insert(from);
2033#endif // NDEBUG
2034
2035 if (functor)
2036 llvm::reportFatalInternalError(
2037 "conditional value replacement is not supported in rollback mode");
2038 mapping.map(from, to);
2039 appendRewrite<ReplaceValueRewrite>(from, converter);
2040}
2041
2043 if (!config.allowPatternRollback) {
2044 // Pattern rollback is not allowed: materialize all IR changes immediately.
2045 // Update internal data structures, so that there are no dangling pointers
2046 // to erased IR.
2047 block->walk([&](Operation *op) {
2048 erasedOps.insert(op);
2049 ignoredOps.remove(op);
2050 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
2051 unresolvedMaterializations.erase(castOp);
2052 patternMaterializations.erase(castOp);
2053 }
2054 // The original op will be erased, so remove it from the set of
2055 // unlegalized ops.
2056 if (config.unlegalizedOps)
2057 config.unlegalizedOps->erase(op);
2058 });
2059 block->walk([&](Block *block) { erasedBlocks.insert(block); });
2060 // Erase the block and notify the listener.
2061 notifyingRewriter.eraseBlock(block);
2062 return;
2063 }
2064
2065 assert(!wasOpReplaced(block->getParentOp()) &&
2066 "attempting to erase a block within a replaced/erased op");
2068
2069 // Unlink the block from its parent region. The block is kept in the rewrite
2070 // object and will be actually destroyed when rewrites are applied. This
2071 // allows us to keep the operations in the block live and undo the removal by
2072 // re-inserting the block.
2073 block->getParent()->getBlocks().remove(block);
2074
2075 // Mark all nested ops as erased.
2076 block->walk([&](Operation *op) { replacedOps.insert(op); });
2077}
2078
2080 Block *block, Region *previous, Region::iterator previousIt) {
2081 // If no previous insertion point is provided, the block used to be detached.
2082 bool wasDetached = !previous;
2083 Operation *newParentOp = block->getParentOp();
2084 LLVM_DEBUG(
2085 {
2086 Operation *parent = newParentOp;
2087 if (parent) {
2088 logger.startLine() << "** Insert Block into : '" << parent->getName()
2089 << "' (" << parent << ")";
2090 } else {
2091 logger.startLine()
2092 << "** Insert Block into detached Region (nullptr parent op)";
2093 }
2094 if (wasDetached)
2095 logger.getOStream() << " (was detached)";
2096 logger.getOStream() << "\n";
2097 });
2098
2099 // In rollback mode, it is easier to misuse the API, so perform extra error
2100 // checking.
2101 assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) &&
2102 "attempting to insert into a region within a replaced/erased op");
2103 (void)newParentOp;
2104
2105 // In "no rollback" mode, the listener is always notified immediately.
2106 if (!config.allowPatternRollback && config.listener)
2107 config.listener->notifyBlockInserted(block, previous, previousIt);
2108
2109 if (wasDetached) {
2110 // If the block was detached, it is most likely a newly created block.
2111 if (config.allowPatternRollback) {
2112 // TODO: If the same block is inserted multiple times from a detached
2113 // state, the rollback mechanism may erase the same block multiple times.
2114 // This is a bug in the rollback-based dialect conversion driver.
2116 } else {
2117 // In "no rollback" mode, there is an extra data structure for tracking
2118 // erased blocks that must be kept up to date.
2119 erasedBlocks.erase(block);
2120 }
2121 return;
2122 }
2123
2124 // The block was moved from one place to another.
2125 if (config.allowPatternRollback)
2126 appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
2127}
2128
2130 Block *dest,
2131 Block::iterator before) {
2132 appendRewrite<InlineBlockRewrite>(dest, source, before);
2133}
2134
2136 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
2137 LLVM_DEBUG({
2139 reasonCallback(diag);
2140 logger.startLine() << "** Failure : " << diag.str() << "\n";
2141 if (config.notifyCallback)
2142 config.notifyCallback(diag);
2143 });
2144}
2145
2146//===----------------------------------------------------------------------===//
2147// ConversionPatternRewriter
2148//===----------------------------------------------------------------------===//
2149
2150ConversionPatternRewriter::ConversionPatternRewriter(
2151 MLIRContext *ctx, const ConversionConfig &config,
2152 OperationConverter &opConverter)
2154 *this, config, opConverter)) {
2155 setListener(impl.get());
2156}
2157
2158ConversionPatternRewriter::~ConversionPatternRewriter() = default;
2159
2160const ConversionConfig &ConversionPatternRewriter::getConfig() const {
2161 return impl->config;
2162}
2163
2164void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
2165 assert(op && newOp && "expected non-null op");
2166 replaceOp(op, newOp->getResults());
2167}
2168
2169void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
2170 assert(op->getNumResults() == newValues.size() &&
2171 "incorrect # of replacement values");
2172
2173 // If the current insertion point is before the erased operation, we adjust
2174 // the insertion point to be after the operation.
2175 if (getInsertionPoint() == op->getIterator())
2177
2178 SmallVector<SmallVector<Value>> newVals =
2179 llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
2180 return v ? SmallVector<Value>{v} : SmallVector<Value>();
2181 });
2182 impl->replaceOp(op, std::move(newVals));
2183}
2184
2185void ConversionPatternRewriter::replaceOpWithMultiple(
2186 Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
2187 assert(op->getNumResults() == newValues.size() &&
2188 "incorrect # of replacement values");
2189
2190 // If the current insertion point is before the erased operation, we adjust
2191 // the insertion point to be after the operation.
2192 if (getInsertionPoint() == op->getIterator())
2194
2195 impl->replaceOp(op, std::move(newValues));
2196}
2197
2198void ConversionPatternRewriter::eraseOp(Operation *op) {
2199 LLVM_DEBUG({
2200 impl->logger.startLine()
2201 << "** Erase : '" << op->getName() << "'(" << op << ")\n";
2202 });
2203
2204 // If the current insertion point is before the erased operation, we adjust
2205 // the insertion point to be after the operation.
2206 if (getInsertionPoint() == op->getIterator())
2208
2209 SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
2210 impl->replaceOp(op, std::move(nullRepls));
2211}
2212
2213void ConversionPatternRewriter::eraseBlock(Block *block) {
2214 impl->eraseBlock(block);
2215}
2216
2217Block *ConversionPatternRewriter::applySignatureConversion(
2218 Block *block, TypeConverter::SignatureConversion &conversion,
2219 const TypeConverter *converter) {
2220 assert(!impl->wasOpReplaced(block->getParentOp()) &&
2221 "attempting to apply a signature conversion to a block within a "
2222 "replaced/erased op");
2223 return impl->applySignatureConversion(block, converter, conversion);
2224}
2225
2226FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
2227 Region *region, const TypeConverter &converter,
2228 TypeConverter::SignatureConversion *entryConversion) {
2229 assert(!impl->wasOpReplaced(region->getParentOp()) &&
2230 "attempting to apply a signature conversion to a block within a "
2231 "replaced/erased op");
2232 return impl->convertRegionTypes(region, converter, entryConversion);
2233}
2234
2235void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
2236 impl->replaceValueUses(from, to, impl->currentTypeConverter);
2237}
2238
2239void ConversionPatternRewriter::replaceUsesWithIf(
2240 Value from, ValueRange to, function_ref<bool(OpOperand &)> functor,
2241 bool *allUsesReplaced) {
2242 assert(!allUsesReplaced &&
2243 "allUsesReplaced is not supported in a dialect conversion");
2244 impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
2245}
2246
2247Value ConversionPatternRewriter::getRemappedValue(Value key) {
2248 SmallVector<ValueVector> remappedValues;
2249 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
2250 remappedValues)))
2251 return nullptr;
2252 assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
2253 return remappedValues.front().front();
2254}
2255
2256LogicalResult
2257ConversionPatternRewriter::getRemappedValues(ValueRange keys,
2258 SmallVectorImpl<Value> &results) {
2259 if (keys.empty())
2260 return success();
2261 SmallVector<ValueVector> remapped;
2262 if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
2263 remapped)))
2264 return failure();
2265 for (const auto &values : remapped) {
2266 assert(values.size() == 1 && "1:N conversion not supported");
2267 results.push_back(values.front());
2268 }
2269 return success();
2270}
2271
2272void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
2273 Block::iterator before,
2274 ValueRange argValues) {
2275#ifndef NDEBUG
2276 assert(argValues.size() == source->getNumArguments() &&
2277 "incorrect # of argument replacement values");
2278 assert(!impl->wasOpReplaced(source->getParentOp()) &&
2279 "attempting to inline a block from a replaced/erased op");
2280 assert(!impl->wasOpReplaced(dest->getParentOp()) &&
2281 "attempting to inline a block into a replaced/erased op");
2282 auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
2283 // The source block will be deleted, so it should not have any users (i.e.,
2284 // there should be no predecessors).
2285 assert(llvm::all_of(source->getUsers(), opIgnored) &&
2286 "expected 'source' to have no predecessors");
2287#endif // NDEBUG
2288
2289 // If a listener is attached to the dialect conversion, ops cannot be moved
2290 // to the destination block in bulk ("fast path"). This is because at the time
2291 // the notifications are sent, it is unknown which ops were moved. Instead,
2292 // ops should be moved one-by-one ("slow path"), so that a separate
2293 // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
2294 // a bit more efficient, so we try to do that when possible.
2295 bool fastPath = !getConfig().listener;
2296
2297 if (fastPath && impl->config.allowPatternRollback)
2298 impl->inlineBlockBefore(source, dest, before);
2299
2300 // Replace all uses of block arguments.
2301 for (auto it : llvm::zip(source->getArguments(), argValues))
2302 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
2303
2304 if (fastPath) {
2305 // Move all ops at once.
2306 dest->getOperations().splice(before, source->getOperations());
2307 } else {
2308 // Move op by op.
2309 while (!source->empty())
2310 moveOpBefore(&source->front(), dest, before);
2311 }
2312
2313 // If the current insertion point is within the source block, adjust the
2314 // insertion point to the destination block.
2315 if (getInsertionBlock() == source)
2316 setInsertionPoint(dest, getInsertionPoint());
2317
2318 // Erase the source block.
2319 eraseBlock(source);
2320}
2321
2322void ConversionPatternRewriter::startOpModification(Operation *op) {
2323 if (!impl->config.allowPatternRollback) {
2324 // Pattern rollback is not allowed: no extra bookkeeping is needed.
2326 return;
2327 }
2328 assert(!impl->wasOpReplaced(op) &&
2329 "attempting to modify a replaced/erased op");
2330#ifndef NDEBUG
2331 impl->pendingRootUpdates.insert(op);
2332#endif
2333 impl->appendRewrite<ModifyOperationRewrite>(op);
2334}
2335
2336void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
2337 impl->patternModifiedOps.insert(op);
2338 if (!impl->config.allowPatternRollback) {
2340 if (getConfig().listener)
2341 getConfig().listener->notifyOperationModified(op);
2342 return;
2343 }
2344
2345 // There is nothing to do here, we only need to track the operation at the
2346 // start of the update.
2347#ifndef NDEBUG
2348 assert(!impl->wasOpReplaced(op) &&
2349 "attempting to modify a replaced/erased op");
2350 assert(impl->pendingRootUpdates.erase(op) &&
2351 "operation did not have a pending in-place update");
2352#endif
2353}
2354
2355void ConversionPatternRewriter::cancelOpModification(Operation *op) {
2356 if (!impl->config.allowPatternRollback) {
2358 return;
2359 }
2360#ifndef NDEBUG
2361 assert(impl->pendingRootUpdates.erase(op) &&
2362 "operation did not have a pending in-place update");
2363#endif
2364 // Erase the last update for this operation.
2365 auto it = llvm::find_if(
2366 llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
2367 auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
2368 return modifyRewrite && modifyRewrite->getOperation() == op;
2369 });
2370 assert(it != impl->rewrites.rend() && "no root update started on op");
2371 (*it)->rollback();
2372 int updateIdx = std::prev(impl->rewrites.rend()) - it;
2373 impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
2374}
2375
2376detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
2377 return *impl;
2378}
2379
2380//===----------------------------------------------------------------------===//
2381// ConversionPattern
2382//===----------------------------------------------------------------------===//
2383
2384FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
2385 ArrayRef<ValueRange> operands) const {
2386 SmallVector<Value> oneToOneOperands;
2387 oneToOneOperands.reserve(operands.size());
2388 for (ValueRange operand : operands) {
2389 if (operand.size() != 1)
2390 return failure();
2391
2392 oneToOneOperands.push_back(operand.front());
2393 }
2394 return std::move(oneToOneOperands);
2395}
2396
2397LogicalResult
2398ConversionPattern::matchAndRewrite(Operation *op,
2399 PatternRewriter &rewriter) const {
2400 auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
2401 auto &rewriterImpl = dialectRewriter.getImpl();
2402
2403 // Track the current conversion pattern type converter in the rewriter.
2404 llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
2405 getTypeConverter());
2406
2407 // Remap the operands of the operation.
2408 SmallVector<ValueVector> remapped;
2409 if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
2410 op->getOperands(), remapped))) {
2411 return failure();
2412 }
2413 SmallVector<ValueRange> remappedAsRange =
2414 llvm::to_vector_of<ValueRange>(remapped);
2415 return matchAndRewrite(op, remappedAsRange, dialectRewriter);
2416}
2417
2418//===----------------------------------------------------------------------===//
2419// OperationLegalizer
2420//===----------------------------------------------------------------------===//
2421
2422namespace {
2423/// A set of rewrite patterns that can be used to legalize a given operation.
2424using LegalizationPatterns = SmallVector<const Pattern *, 1>;
2425
2426/// This class defines a recursive operation legalizer.
2427class OperationLegalizer {
2428public:
2429 using LegalizationAction = ConversionTarget::LegalizationAction;
2430
2431 OperationLegalizer(ConversionPatternRewriter &rewriter,
2432 const ConversionTarget &targetInfo,
2433 const FrozenRewritePatternSet &patterns);
2434
2435 /// Returns true if the given operation is known to be illegal on the target.
2436 bool isIllegal(Operation *op) const;
2437
2438 /// Attempt to legalize the given operation. Returns success if the operation
2439 /// was legalized, failure otherwise.
2440 LogicalResult legalize(Operation *op);
2441
2442 /// Returns the conversion target in use by the legalizer.
2443 const ConversionTarget &getTarget() { return target; }
2444
2445private:
2446 /// Attempt to legalize the given operation by folding it.
2447 LogicalResult legalizeWithFold(Operation *op);
2448
2449 /// Attempt to legalize the given operation by applying a pattern. Returns
2450 /// success if the operation was legalized, failure otherwise.
2451 LogicalResult legalizeWithPattern(Operation *op);
2452
2453 /// Return true if the given pattern may be applied to the given operation,
2454 /// false otherwise.
2455 bool canApplyPattern(Operation *op, const Pattern &pattern);
2456
2457 /// Legalize the resultant IR after successfully applying the given pattern.
2458 LogicalResult
2459 legalizePatternResult(Operation *op, const Pattern &pattern,
2460 const RewriterState &curState,
2461 const SetVector<Operation *> &newOps,
2462 const SetVector<Operation *> &modifiedOps);
2463
2464 LogicalResult
2465 legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
2466 LogicalResult
2467 legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
2468
2469 //===--------------------------------------------------------------------===//
2470 // Cost Model
2471 //===--------------------------------------------------------------------===//
2472
2473 /// Build an optimistic legalization graph given the provided patterns. This
2474 /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
2475 /// patterns for operations that are not directly legal, but may be
2476 /// transitively legal for the current target given the provided patterns.
2477 void buildLegalizationGraph(
2478 LegalizationPatterns &anyOpLegalizerPatterns,
2480
2481 /// Compute the benefit of each node within the computed legalization graph.
2482 /// This orders the patterns within 'legalizerPatterns' based upon two
2483 /// criteria:
2484 /// 1) Prefer patterns that have the lowest legalization depth, i.e.
2485 /// represent the more direct mapping to the target.
2486 /// 2) When comparing patterns with the same legalization depth, prefer the
2487 /// pattern with the highest PatternBenefit. This allows for users to
2488 /// prefer specific legalizations over others.
2489 void computeLegalizationGraphBenefit(
2490 LegalizationPatterns &anyOpLegalizerPatterns,
2492
2493 /// Compute the legalization depth when legalizing an operation of the given
2494 /// type.
2495 unsigned computeOpLegalizationDepth(
2496 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
2498
2499 /// Apply the conversion cost model to the given set of patterns, and return
2500 /// the smallest legalization depth of any of the patterns. See
2501 /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
2502 unsigned applyCostModelToPatterns(
2503 LegalizationPatterns &patterns,
2504 DenseMap<OperationName, unsigned> &minOpPatternDepth,
2506
2507 /// The current set of patterns that have been applied.
2508 SmallPtrSet<const Pattern *, 8> appliedPatterns;
2509
2510 /// The rewriter to use when converting operations.
2511 ConversionPatternRewriter &rewriter;
2512
2513 /// The legalization information provided by the target.
2514 const ConversionTarget &target;
2515
2516 /// The pattern applicator to use for conversions.
2517 PatternApplicator applicator;
2518};
2519} // namespace
2520
2521OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
2522 const ConversionTarget &targetInfo,
2523 const FrozenRewritePatternSet &patterns)
2524 : rewriter(rewriter), target(targetInfo), applicator(patterns) {
2525 // The set of patterns that can be applied to illegal operations to transform
2526 // them into legal ones.
2528 LegalizationPatterns anyOpLegalizerPatterns;
2529
2530 buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
2531 computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
2532}
2533
2534bool OperationLegalizer::isIllegal(Operation *op) const {
2535 return target.isIllegal(op);
2536}
2537
2538LogicalResult OperationLegalizer::legalize(Operation *op) {
2539#ifndef NDEBUG
2540 const char *logLineComment =
2541 "//===-------------------------------------------===//\n";
2542
2543 auto &logger = rewriter.getImpl().logger;
2544#endif
2545
2546 // Check to see if the operation is ignored and doesn't need to be converted.
2547 bool isIgnored = rewriter.getImpl().isOpIgnored(op);
2548
2549 LLVM_DEBUG({
2550 logger.getOStream() << "\n";
2551 logger.startLine() << logLineComment;
2552 logger.startLine() << "Legalizing operation : ";
2553 // Do not print the operation name if the operation is ignored. Ignored ops
2554 // may have been erased and should not be accessed. The pointer can be
2555 // printed safely.
2556 if (!isIgnored)
2557 logger.getOStream() << "'" << op->getName() << "' ";
2558 logger.getOStream() << "(" << op << ") {\n";
2559 logger.indent();
2560
2561 // If the operation has no regions, just print it here.
2562 if (!isIgnored && op->getNumRegions() == 0) {
2563 logger.startLine() << OpWithFlags(op,
2564 OpPrintingFlags().printGenericOpForm())
2565 << "\n";
2566 }
2567 });
2568
2569 if (isIgnored) {
2570 LLVM_DEBUG({
2571 logSuccess(logger, "operation marked 'ignored' during conversion");
2572 logger.startLine() << logLineComment;
2573 });
2574 return success();
2575 }
2576
2577 // Check if this operation is legal on the target.
2578 if (auto legalityInfo = target.isLegal(op)) {
2579 LLVM_DEBUG({
2580 logSuccess(
2581 logger, "operation marked legal by the target{0}",
2582 legalityInfo->isRecursivelyLegal
2583 ? "; NOTE: operation is recursively legal; skipping internals"
2584 : "");
2585 logger.startLine() << logLineComment;
2586 });
2587
2588 // If this operation is recursively legal, mark its children as ignored so
2589 // that we don't consider them for legalization.
2590 if (legalityInfo->isRecursivelyLegal) {
2591 op->walk([&](Operation *nested) {
2592 if (op != nested)
2593 rewriter.getImpl().ignoredOps.insert(nested);
2594 });
2595 }
2596
2597 return success();
2598 }
2599
2600 // If the operation is not legal, try to fold it in-place if the folding mode
2601 // is 'BeforePatterns'. 'Never' will skip this.
2602 const ConversionConfig &config = rewriter.getConfig();
2603 if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
2604 if (succeeded(legalizeWithFold(op))) {
2605 LLVM_DEBUG({
2606 logSuccess(logger, "operation was folded");
2607 logger.startLine() << logLineComment;
2608 });
2609 return success();
2610 }
2611 }
2612
2613 // Otherwise, we need to apply a legalization pattern to this operation.
2614 if (succeeded(legalizeWithPattern(op))) {
2615 LLVM_DEBUG({
2616 logSuccess(logger, "");
2617 logger.startLine() << logLineComment;
2618 });
2619 return success();
2620 }
2621
2622 // If the operation can't be legalized via patterns, try to fold it in-place
2623 // if the folding mode is 'AfterPatterns'.
2624 if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
2625 if (succeeded(legalizeWithFold(op))) {
2626 LLVM_DEBUG({
2627 logSuccess(logger, "operation was folded");
2628 logger.startLine() << logLineComment;
2629 });
2630 return success();
2631 }
2632 }
2633
2634 LLVM_DEBUG({
2635 logFailure(logger, "no matched legalization pattern");
2636 logger.startLine() << logLineComment;
2637 });
2638 return failure();
2639}
2640
2641/// Helper function that moves and returns the given object. Also resets the
2642/// original object, so that it is in a valid, empty state again.
2643template <typename T>
2644static T moveAndReset(T &obj) {
2645 T result = std::move(obj);
2646 obj = T();
2647 return result;
2648}
2649
2650LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
2651 auto &rewriterImpl = rewriter.getImpl();
2652 LLVM_DEBUG({
2653 rewriterImpl.logger.startLine() << "* Fold {\n";
2654 rewriterImpl.logger.indent();
2655 });
2656
2657 // Clear pattern state, so that the next pattern application starts with a
2658 // clean slate. (The op/block sets are populated by listener notifications.)
2659 llvm::scope_exit cleanup([&]() {
2660 rewriterImpl.patternNewOps.clear();
2661 rewriterImpl.patternModifiedOps.clear();
2662 });
2663
2664 // Upon failure, undo all changes made by the folder.
2665 RewriterState curState = rewriterImpl.getCurrentState();
2666
2667 // Try to fold the operation.
2668 StringRef opName = op->getName().getStringRef();
2669 SmallVector<Value, 2> replacementValues;
2670 SmallVector<Operation *, 2> newOps;
2671 rewriter.setInsertionPoint(op);
2672 rewriter.startOpModification(op);
2673 if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
2674 LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
2675 rewriter.cancelOpModification(op);
2676 return failure();
2677 }
2678 rewriter.finalizeOpModification(op);
2679
2680 // An empty list of replacement values indicates that the fold was in-place.
2681 // As the operation changed, a new legalization needs to be attempted.
2682 if (replacementValues.empty())
2683 return legalize(op);
2684
2685 // Insert a replacement for 'op' with the folded replacement values.
2686 rewriter.replaceOp(op, replacementValues);
2687
2688 // Recursively legalize any new constant operations.
2689 for (Operation *newOp : newOps) {
2690 if (failed(legalize(newOp))) {
2691 LLVM_DEBUG(logFailure(rewriterImpl.logger,
2692 "failed to legalize generated constant '{0}'",
2693 newOp->getName()));
2694 if (!rewriter.getConfig().allowPatternRollback) {
2695 // Rolling back a folder is like rolling back a pattern.
2696 llvm::reportFatalInternalError(
2697 "op '" + opName +
2698 "' folder rollback of IR modifications requested");
2699 }
2700 rewriterImpl.resetState(
2701 curState, std::string(op->getName().getStringRef()) + " folder");
2702 return failure();
2703 }
2704 }
2705
2706 LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
2707 return success();
2708}
2709
2710/// Report a fatal error indicating that newly produced or modified IR could
2711/// not be legalized.
2712static void
2714 const SetVector<Operation *> &newOps,
2715 const SetVector<Operation *> &modifiedOps) {
2716 auto newOpNames = llvm::map_range(
2717 newOps, [](Operation *op) { return op->getName().getStringRef(); });
2718 auto modifiedOpNames = llvm::map_range(
2719 modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
2720 llvm::reportFatalInternalError("pattern '" + pattern.getDebugName() +
2721 "' produced IR that could not be legalized. " +
2722 "new ops: {" + llvm::join(newOpNames, ", ") +
2723 "}, " + "modified ops: {" +
2724 llvm::join(modifiedOpNames, ", ") + "}");
2725}
2726
2727LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
2728 auto &rewriterImpl = rewriter.getImpl();
2729 const ConversionConfig &config = rewriter.getConfig();
2730
2731#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2732 Operation *checkOp;
2733 std::optional<OperationFingerPrint> topLevelFingerPrint;
2734 if (!rewriterImpl.config.allowPatternRollback) {
2735 // The op may be getting erased, so we have to check the parent op.
2736 // (In rare cases, a pattern may even erase the parent op, which will cause
2737 // a crash here. Expensive checks are "best effort".) Skip the check if the
2738 // op does not have a parent op.
2739 if ((checkOp = op->getParentOp())) {
2740 if (!op->getContext()->isMultithreadingEnabled()) {
2741 topLevelFingerPrint = OperationFingerPrint(checkOp);
2742 } else {
2743 // Another thread may be modifying a sibling operation. Therefore, the
2744 // fingerprinting mechanism of the parent op works only in
2745 // single-threaded mode.
2746 LLVM_DEBUG({
2747 rewriterImpl.logger.startLine()
2748 << "WARNING: Multi-threadeding is enabled. Some dialect "
2749 "conversion expensive checks are skipped in multithreading "
2750 "mode!\n";
2751 });
2752 }
2753 }
2754 }
2755#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2756
2757 // Functor that returns if the given pattern may be applied.
2758 auto canApply = [&](const Pattern &pattern) {
2759 bool canApply = canApplyPattern(op, pattern);
2760 if (canApply && config.listener)
2761 config.listener->notifyPatternBegin(pattern, op);
2762 return canApply;
2763 };
2764
2765 // Functor that cleans up the rewriter state after a pattern failed to match.
2766 RewriterState curState = rewriterImpl.getCurrentState();
2767 auto onFailure = [&](const Pattern &pattern) {
2768 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2769 if (!rewriterImpl.config.allowPatternRollback) {
2770 // Erase all unresolved materializations.
2771 for (auto op : rewriterImpl.patternMaterializations) {
2772 rewriterImpl.unresolvedMaterializations.erase(op);
2773 op.erase();
2774 }
2775 rewriterImpl.patternMaterializations.clear();
2776#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2777 // Expensive pattern check that can detect API violations.
2778 if (checkOp && topLevelFingerPrint) {
2779 OperationFingerPrint fingerPrintAfterPattern(checkOp);
2780 if (fingerPrintAfterPattern != *topLevelFingerPrint)
2781 llvm::reportFatalInternalError(
2782 "pattern '" + pattern.getDebugName() +
2783 "' returned failure but IR did change");
2784 }
2785#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2786 }
2787 rewriterImpl.patternNewOps.clear();
2788 rewriterImpl.patternModifiedOps.clear();
2789 LLVM_DEBUG({
2790 logFailure(rewriterImpl.logger, "pattern failed to match");
2791 if (rewriterImpl.config.notifyCallback) {
2792 Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
2793 diag << "Failed to apply pattern \"" << pattern.getDebugName()
2794 << "\" on op:\n"
2795 << *op;
2796 rewriterImpl.config.notifyCallback(diag);
2797 }
2798 });
2799 if (config.listener)
2800 config.listener->notifyPatternEnd(pattern, failure());
2801 rewriterImpl.resetState(curState, pattern.getDebugName());
2802 appliedPatterns.erase(&pattern);
2803 };
2804
2805 // Functor that performs additional legalization when a pattern is
2806 // successfully applied.
2807 auto onSuccess = [&](const Pattern &pattern) {
2808 assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2809 if (!rewriterImpl.config.allowPatternRollback) {
2810 // Eagerly erase unused materializations.
2811 for (auto op : rewriterImpl.patternMaterializations) {
2812 if (op->use_empty()) {
2813 rewriterImpl.unresolvedMaterializations.erase(op);
2814 op.erase();
2815 }
2816 }
2817 rewriterImpl.patternMaterializations.clear();
2818 }
2819 SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
2820 SetVector<Operation *> modifiedOps =
2821 moveAndReset(rewriterImpl.patternModifiedOps);
2822 auto result =
2823 legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
2824 appliedPatterns.erase(&pattern);
2825 if (failed(result)) {
2826 if (!rewriterImpl.config.allowPatternRollback)
2827 reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);
2828 rewriterImpl.resetState(curState, pattern.getDebugName());
2829 }
2830 if (config.listener)
2831 config.listener->notifyPatternEnd(pattern, result);
2832 return result;
2833 };
2834
2835 // Try to match and rewrite a pattern on this operation.
2836 return applicator.matchAndRewrite(op, rewriter, canApply, onFailure,
2837 onSuccess);
2838}
2839
2840bool OperationLegalizer::canApplyPattern(Operation *op,
2841 const Pattern &pattern) {
2842 LLVM_DEBUG({
2843 auto &os = rewriter.getImpl().logger;
2844 os.getOStream() << "\n";
2845 os.startLine() << "* Pattern : '" << op->getName() << " -> (";
2846 llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
2847 os.getOStream() << ")' {\n";
2848 os.indent();
2849 });
2850
2851 // Ensure that we don't cycle by not allowing the same pattern to be
2852 // applied twice in the same recursion stack if it is not known to be safe.
2853 if (!pattern.hasBoundedRewriteRecursion() &&
2854 !appliedPatterns.insert(&pattern).second) {
2855 LLVM_DEBUG(
2856 logFailure(rewriter.getImpl().logger, "pattern was already applied"));
2857 return false;
2858 }
2859 return true;
2860}
2861
2862LogicalResult OperationLegalizer::legalizePatternResult(
2863 Operation *op, const Pattern &pattern, const RewriterState &curState,
2864 const SetVector<Operation *> &newOps,
2865 const SetVector<Operation *> &modifiedOps) {
2866 [[maybe_unused]] auto &impl = rewriter.getImpl();
2867 assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2868
2869#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2870 if (impl.config.allowPatternRollback) {
2871 // Check that the root was either replaced or updated in place.
2872 auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
2873 auto replacedRoot = [&] {
2874 return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
2875 };
2876 auto updatedRootInPlace = [&] {
2877 return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
2878 };
2879 if (!replacedRoot() && !updatedRootInPlace())
2880 llvm::reportFatalInternalError(
2881 "expected pattern to replace the root operation "
2882 "or modify it in place");
2883 }
2884#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2885
2886 // Legalize each of the actions registered during application.
2887 if (failed(legalizePatternRootUpdates(modifiedOps)) ||
2888 failed(legalizePatternCreatedOperations(newOps))) {
2889 return failure();
2890 }
2891
2892 LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
2893 return success();
2894}
2895
2896LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2897 const SetVector<Operation *> &newOps) {
2898 for (Operation *op : newOps) {
2899 if (failed(legalize(op))) {
2900 LLVM_DEBUG(logFailure(rewriter.getImpl().logger,
2901 "failed to legalize generated operation '{0}'({1})",
2902 op->getName(), op));
2903 return failure();
2904 }
2905 }
2906 return success();
2907}
2908
2909LogicalResult OperationLegalizer::legalizePatternRootUpdates(
2910 const SetVector<Operation *> &modifiedOps) {
2911 for (Operation *op : modifiedOps) {
2912 if (failed(legalize(op))) {
2913 LLVM_DEBUG(
2914 logFailure(rewriter.getImpl().logger,
2915 "failed to legalize operation updated in-place '{0}'",
2916 op->getName()));
2917 return failure();
2918 }
2919 }
2920 return success();
2921}
2922
2923//===----------------------------------------------------------------------===//
2924// Cost Model
2925//===----------------------------------------------------------------------===//
2926
2927void OperationLegalizer::buildLegalizationGraph(
2928 LegalizationPatterns &anyOpLegalizerPatterns,
2930 // A mapping between an operation and a set of operations that can be used to
2931 // generate it.
2933 // A mapping between an operation and any currently invalid patterns it has.
2935 // A worklist of patterns to consider for legality.
2936 SetVector<const Pattern *> patternWorklist;
2937
2938 // Build the mapping from operations to the parent ops that may generate them.
2939 applicator.walkAllPatterns([&](const Pattern &pattern) {
2940 std::optional<OperationName> root = pattern.getRootKind();
2941
2942 // If the pattern has no specific root, we can't analyze the relationship
2943 // between the root op and generated operations. Given that, add all such
2944 // patterns to the legalization set.
2945 if (!root) {
2946 anyOpLegalizerPatterns.push_back(&pattern);
2947 return;
2948 }
2949
2950 // Skip operations that are always known to be legal.
2951 if (target.getOpAction(*root) == LegalizationAction::Legal)
2952 return;
2953
2954 // Add this pattern to the invalid set for the root op and record this root
2955 // as a parent for any generated operations.
2956 invalidPatterns[*root].insert(&pattern);
2957 for (auto op : pattern.getGeneratedOps())
2958 parentOps[op].insert(*root);
2959
2960 // Add this pattern to the worklist.
2961 patternWorklist.insert(&pattern);
2962 });
2963
2964 // If there are any patterns that don't have a specific root kind, we can't
2965 // make direct assumptions about what operations will never be legalized.
2966 // Note: Technically we could, but it would require an analysis that may
2967 // recurse into itself. It would be better to perform this kind of filtering
2968 // at a higher level than here anyways.
2969 if (!anyOpLegalizerPatterns.empty()) {
2970 for (const Pattern *pattern : patternWorklist)
2971 legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2972 return;
2973 }
2974
2975 while (!patternWorklist.empty()) {
2976 auto *pattern = patternWorklist.pop_back_val();
2977
2978 // Check to see if any of the generated operations are invalid.
2979 if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
2980 std::optional<LegalizationAction> action = target.getOpAction(op);
2981 return !legalizerPatterns.count(op) &&
2982 (!action || action == LegalizationAction::Illegal);
2983 }))
2984 continue;
2985
2986 // Otherwise, if all of the generated operation are valid, this op is now
2987 // legal so add all of the child patterns to the worklist.
2988 legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
2989 invalidPatterns[*pattern->getRootKind()].erase(pattern);
2990
2991 // Add any invalid patterns of the parent operations to see if they have now
2992 // become legal.
2993 for (auto op : parentOps[*pattern->getRootKind()])
2994 patternWorklist.set_union(invalidPatterns[op]);
2995 }
2996}
2997
2998void OperationLegalizer::computeLegalizationGraphBenefit(
2999 LegalizationPatterns &anyOpLegalizerPatterns,
3001 // The smallest pattern depth, when legalizing an operation.
3002 DenseMap<OperationName, unsigned> minOpPatternDepth;
3003
3004 // For each operation that is transitively legal, compute a cost for it.
3005 for (auto &opIt : legalizerPatterns)
3006 if (!minOpPatternDepth.count(opIt.first))
3007 computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
3008 legalizerPatterns);
3009
3010 // Apply the cost model to the patterns that can match any operation. Those
3011 // with a specific operation type are already resolved when computing the op
3012 // legalization depth.
3013 if (!anyOpLegalizerPatterns.empty())
3014 applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
3015 legalizerPatterns);
3016
3017 // Apply a cost model to the pattern applicator. We order patterns first by
3018 // depth then benefit. `legalizerPatterns` contains per-op patterns by
3019 // decreasing benefit.
3020 applicator.applyCostModel([&](const Pattern &pattern) {
3021 ArrayRef<const Pattern *> orderedPatternList;
3022 if (std::optional<OperationName> rootName = pattern.getRootKind())
3023 orderedPatternList = legalizerPatterns[*rootName];
3024 else
3025 orderedPatternList = anyOpLegalizerPatterns;
3026
3027 // If the pattern is not found, then it was removed and cannot be matched.
3028 auto *it = llvm::find(orderedPatternList, &pattern);
3029 if (it == orderedPatternList.end())
3031
3032 // Patterns found earlier in the list have higher benefit.
3033 return PatternBenefit(std::distance(it, orderedPatternList.end()));
3034 });
3035}
3036
3037unsigned OperationLegalizer::computeOpLegalizationDepth(
3038 OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
3040 // Check for existing depth.
3041 auto depthIt = minOpPatternDepth.find(op);
3042 if (depthIt != minOpPatternDepth.end())
3043 return depthIt->second;
3044
3045 // If a mapping for this operation does not exist, then this operation
3046 // is always legal. Return 0 as the depth for a directly legal operation.
3047 auto opPatternsIt = legalizerPatterns.find(op);
3048 if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
3049 return 0u;
3050
3051 // Record this initial depth in case we encounter this op again when
3052 // recursively computing the depth.
3053 minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
3054
3055 // Apply the cost model to the operation patterns, and update the minimum
3056 // depth.
3057 unsigned minDepth = applyCostModelToPatterns(
3058 opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
3059 minOpPatternDepth[op] = minDepth;
3060 return minDepth;
3061}
3062
3063unsigned OperationLegalizer::applyCostModelToPatterns(
3064 LegalizationPatterns &patterns,
3065 DenseMap<OperationName, unsigned> &minOpPatternDepth,
3067 unsigned minDepth = std::numeric_limits<unsigned>::max();
3068
3069 // Compute the depth for each pattern within the set.
3070 SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
3071 patternsByDepth.reserve(patterns.size());
3072 for (const Pattern *pattern : patterns) {
3073 unsigned depth = 1;
3074 for (auto generatedOp : pattern->getGeneratedOps()) {
3075 unsigned generatedOpDepth = computeOpLegalizationDepth(
3076 generatedOp, minOpPatternDepth, legalizerPatterns);
3077 depth = std::max(depth, generatedOpDepth + 1);
3078 }
3079 patternsByDepth.emplace_back(pattern, depth);
3080
3081 // Update the minimum depth of the pattern list.
3082 minDepth = std::min(minDepth, depth);
3083 }
3084
3085 // If the operation only has one legalization pattern, there is no need to
3086 // sort them.
3087 if (patternsByDepth.size() == 1)
3088 return minDepth;
3089
3090 // Sort the patterns by those likely to be the most beneficial.
3091 llvm::stable_sort(patternsByDepth,
3092 [](const std::pair<const Pattern *, unsigned> &lhs,
3093 const std::pair<const Pattern *, unsigned> &rhs) {
3094 // First sort by the smaller pattern legalization
3095 // depth.
3096 if (lhs.second != rhs.second)
3097 return lhs.second < rhs.second;
3098
3099 // Then sort by the larger pattern benefit.
3100 auto lhsBenefit = lhs.first->getBenefit();
3101 auto rhsBenefit = rhs.first->getBenefit();
3102 return lhsBenefit > rhsBenefit;
3103 });
3104
3105 // Update the legalization pattern to use the new sorted list.
3106 patterns.clear();
3107 for (auto &patternIt : patternsByDepth)
3108 patterns.push_back(patternIt.first);
3109 return minDepth;
3110}
3111
3112//===----------------------------------------------------------------------===//
3113// Reconcile Unrealized Casts
3114//===----------------------------------------------------------------------===//
3115
3116/// Try to reconcile all given UnrealizedConversionCastOps and store the
3117/// left-over ops in `remainingCastOps` (if provided). See documentation in
3118/// DialectConversion.h for more details.
3119/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the
3120/// algorithm may visit an operand (or user) which is a cast op, but will not
3121/// try to reconcile it if not in the filtered set.
3122template <typename RangeT>
3124 RangeT castOps,
3125 function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3127 // A worklist of cast ops to process.
3128 SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
3129
3130 // Helper function that return the unrealized_conversion_cast op that
3131 // defines all inputs of the given op (in the same order). Return "nullptr"
3132 // if there is no such op.
3133 auto getInputCast =
3134 [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
3135 if (castOp.getInputs().empty())
3136 return {};
3137 auto inputCastOp =
3138 castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
3139 if (!inputCastOp)
3140 return {};
3141 if (inputCastOp.getOutputs() != castOp.getInputs())
3142 return {};
3143 return inputCastOp;
3144 };
3145
3146 // Process ops in the worklist bottom-to-top.
3147 while (!worklist.empty()) {
3148 UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3149
3150 // Traverse the chain of input cast ops to see if an op with the same
3151 // input types can be found.
3152 UnrealizedConversionCastOp nextCast = castOp;
3153 while (nextCast) {
3154 if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3155 if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
3156 return v.getDefiningOp() == castOp;
3157 })) {
3158 // Ran into a cycle.
3159 break;
3160 }
3161
3162 // Found a cast where the input types match the output types of the
3163 // matched op. We can directly use those inputs.
3164 castOp.replaceAllUsesWith(nextCast.getInputs());
3165 break;
3166 }
3167 nextCast = getInputCast(nextCast);
3168 }
3169 }
3170
3171 // A set of all alive cast ops. I.e., ops whose results are (transitively)
3172 // used by an op that is not a cast op.
3173 DenseSet<Operation *> liveOps;
3174
3175 // Helper function that marks the given op and transitively reachable input
3176 // cast ops as alive.
3177 auto markOpLive = [&](Operation *rootOp) {
3178 SmallVector<Operation *> worklist;
3179 worklist.push_back(rootOp);
3180 while (!worklist.empty()) {
3181 Operation *op = worklist.pop_back_val();
3182 if (liveOps.insert(op).second) {
3183 // Successfully inserted: process reachable input cast ops.
3184 for (Value v : op->getOperands())
3185 if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
3186 if (isCastOpOfInterestFn(castOp))
3187 worklist.push_back(castOp);
3188 }
3189 }
3190 };
3191
3192 // Find all alive cast ops.
3193 for (UnrealizedConversionCastOp op : castOps) {
3194 // The op may have been marked live already as being an operand of another
3195 // live cast op.
3196 if (liveOps.contains(op.getOperation()))
3197 continue;
3198 // If any of the users is not a cast op, mark the current op (and its
3199 // input ops) as live.
3200 if (llvm::any_of(op->getUsers(), [&](Operation *user) {
3201 auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3202 return !castOp || !isCastOpOfInterestFn(castOp);
3203 }))
3204 markOpLive(op);
3205 }
3206
3207 // Erase all dead cast ops.
3208 for (UnrealizedConversionCastOp op : castOps) {
3209 if (liveOps.contains(op)) {
3210 // Op is alive and was not erased. Add it to the remaining cast ops.
3211 if (remainingCastOps)
3212 remainingCastOps->push_back(op);
3213 continue;
3214 }
3215
3216 // Op is dead. Erase it.
3217 op->dropAllUses();
3218 op->erase();
3219 }
3220}
3221
3223 ArrayRef<UnrealizedConversionCastOp> castOps,
3224 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3225 // Set of all cast ops for faster lookups.
3226 DenseSet<UnrealizedConversionCastOp> castOpSet;
3227 for (UnrealizedConversionCastOp op : castOps)
3228 castOpSet.insert(op);
3229 reconcileUnrealizedCasts(castOpSet, remainingCastOps);
3230}
3231
3233 const DenseSet<UnrealizedConversionCastOp> &castOps,
3234 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3236 llvm::make_range(castOps.begin(), castOps.end()),
3237 [&](UnrealizedConversionCastOp castOp) {
3238 return castOps.contains(castOp);
3239 },
3240 remainingCastOps);
3241}
3242
3243namespace mlir {
3245 const llvm::MapVector<UnrealizedConversionCastOp,
3246 UnresolvedMaterializationInfo> &castOps,
3249 castOps.keys(),
3250 [&](UnrealizedConversionCastOp castOp) {
3251 return castOps.contains(castOp);
3252 },
3253 remainingCastOps);
3254}
3255} // namespace mlir
3256
3257//===----------------------------------------------------------------------===//
3258// OperationConverter
3259//===----------------------------------------------------------------------===//
3260
3261namespace mlir {
3262// This class converts operations to a given conversion target via a set of
3263// rewrite patterns. The conversion behaves differently depending on the
3264// conversion mode.
3267 const FrozenRewritePatternSet &patterns,
3268 const ConversionConfig &config,
3269 OpConversionMode mode)
3270 : rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
3271 mode(mode) {}
3272
3273 /// Applies the conversion to the given operations (and their nested
3274 /// operations).
3275 LogicalResult applyConversion(ArrayRef<Operation *> ops);
3276
3277 /// Legalizes the given operations (and their nested operations) to the
3278 /// conversion target.
3279 template <typename Fn>
3280 LogicalResult legalizeOperations(ArrayRef<Operation *> ops, Fn onFailure,
3281 bool isRecursiveLegalization = false);
3283 bool isRecursiveLegalization = false) {
3284 return legalizeOperations(
3285 ops, /*onFailure=*/[&]() {}, isRecursiveLegalization);
3286 }
3287
3288 /// Converts a single operation. If `isRecursiveLegalization` is "true", the
3289 /// conversion is a recursive legalization request, triggered from within a
3290 /// pattern. In that case, do not emit errors because there will be another
3291 /// attempt at legalizing the operation later (via the regular pre-order
3292 /// legalization mechanism).
3293 LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);
3294
3295 const ConversionTarget &getTarget() { return opLegalizer.getTarget(); }
3296
3297private:
3298 /// The rewriter to use when converting operations.
3299 ConversionPatternRewriter rewriter;
3300
3301 /// The legalizer to use when converting operations.
3302 OperationLegalizer opLegalizer;
3303
3304 /// The conversion mode to use when legalizing operations.
3305 OpConversionMode mode;
3306};
3307} // namespace mlir
3308
3310 bool isRecursiveLegalization) {
3311 const ConversionConfig &config = rewriter.getConfig();
3312 auto emitFailedToLegalizeDiag = [&](bool wasExplicitlyIllegal) {
3314 << "failed to legalize operation '"
3315 << op->getName() << "'";
3316 if (wasExplicitlyIllegal)
3317 diag << " that was explicitly marked illegal";
3318 diag << ": " << OpWithFlags(op, OpPrintingFlags().skipRegions());
3319 };
3320
3321 // Legalize the given operation.
3322 if (failed(opLegalizer.legalize(op))) {
3323 // Handle the case of a failed conversion for each of the different modes.
3324 // Full conversions expect all operations to be converted.
3325 if (mode == OpConversionMode::Full) {
3326 if (!isRecursiveLegalization)
3327 emitFailedToLegalizeDiag(/*wasExplicitlyIllegal=*/false);
3328 return failure();
3329 }
3330 // Partial conversions allow conversions to fail iff the operation was not
3331 // explicitly marked as illegal. If the user provided a `unlegalizedOps`
3332 // set, non-legalizable ops are added to that set.
3333 if (mode == OpConversionMode::Partial) {
3334 if (opLegalizer.isIllegal(op)) {
3335 if (!isRecursiveLegalization)
3336 emitFailedToLegalizeDiag(/*wasExplicitlyIllegal=*/true);
3337 return failure();
3338 }
3339 if (config.unlegalizedOps && !isRecursiveLegalization)
3340 config.unlegalizedOps->insert(op);
3341 }
3342 } else if (mode == OpConversionMode::Analysis) {
3343 // Analysis conversions don't fail if any operations fail to legalize,
3344 // they are only interested in the operations that were successfully
3345 // legalized.
3346 if (config.legalizableOps && !isRecursiveLegalization)
3347 config.legalizableOps->insert(op);
3348 }
3349 return success();
3350}
3351
/// Legalizes one unresolved materialization. `op` is an
/// unrealized_conversion_cast that still has uses (asserted below).
/// If `info` records a type converter, the target or source materialization
/// callbacks (selected by `info.getMaterializationKind()`) are asked to build
/// a replacement right before `op`. On success `op` is replaced through
/// `rewriter` and success is returned.
/// If no converter is recorded, or the callbacks produce no value, an error is
/// emitted on `op` (with a note at its first user) and failure is returned.
/// NOTE(review): the line carrying the function name and the `rewriter`
/// parameter (listing line 3353) is missing from this listing.
3352static LogicalResult
3354 UnrealizedConversionCastOp op,
3355 const UnresolvedMaterializationInfo &info) {
3356 assert(!op.use_empty() &&
3357 "expected that dead materializations have already been DCE'd");
3358 Operation::operand_range inputOperands = op.getOperands();
3359
3360 // Try to materialize the conversion.
3361 if (const TypeConverter *converter = info.getConverter()) {
// New IR is built immediately before the cast that it replaces.
3362 rewriter.setInsertionPoint(op);
3363 SmallVector<Value> newMaterialization;
3364 switch (info.getMaterializationKind()) {
// Target materializations may produce one value per cast result.
3365 case MaterializationKind::Target:
3366 newMaterialization = converter->materializeTargetConversion(
3367 rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
3368 info.getOriginalType());
3369 break;
// Source materializations always produce exactly one value.
3370 case MaterializationKind::Source:
3371 assert(op->getNumResults() == 1 && "expected single result");
3372 Value sourceMat = converter->materializeSourceConversion(
3373 rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
3374 if (sourceMat)
3375 newMaterialization.push_back(sourceMat);
3376 break;
3377 }
// An empty vector means no callback could build the materialization.
3378 if (!newMaterialization.empty()) {
3379#ifndef NDEBUG
3380 ValueRange newMaterializationRange(newMaterialization);
3381 assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
3382 "materialization callback produced value of incorrect type");
3383#endif // NDEBUG
3384 rewriter.replaceOp(op, newMaterialization);
3385 return success();
3386 }
3387 }
3388
// No materialization was possible: report the cast, plus one live user,
// because it remained live after the conversion.
3389 InFlightDiagnostic diag = op->emitError()
3390 << "failed to legalize unresolved materialization "
3391 "from ("
3392 << inputOperands.getTypes() << ") to ("
3393 << op.getResultTypes()
3394 << ") that remained live after conversion";
3395 diag.attachNote(op->getUsers().begin()->getLoc())
3396 << "see existing live user here: " << *op->getUsers().begin();
3397 return failure();
3398}
3399
/// Builds a worklist from each op in `ops`. A walk callback pushes every
/// visited op and skips the nested ops of any op that the target reports as
/// recursively legal.
/// The collected ops are then converted in collection order with `convert`.
/// On the first conversion failure, `onFailure()` is invoked and failure is
/// returned.
/// NOTE(review): two lines are missing from this listing: the line naming
/// this definition and its leading parameters (listing line 3402), and the
/// walk call itself (listing line 3409). The use of `WalkResult::skip()`
/// suggests a pre-order walk.
3400template <typename Fn>
3401LogicalResult
3403 bool isRecursiveLegalization) {
3404 const ConversionTarget &target = opLegalizer.getTarget();
3405
3406 // Compute the list of operations to convert.
3407 SmallVector<Operation *> toConvert;
3408 for (Operation *op : ops) {
3410 [&](Operation *op) {
3411 toConvert.push_back(op);
3412 // Don't check this operation's children for conversion if the
3413 // operation is recursively legal.
3414 auto legalityInfo = target.isLegal(op);
3415 if (legalityInfo && legalityInfo->isRecursivelyLegal)
3416 return WalkResult::skip();
3417 return WalkResult::advance();
3418 });
3419 }
// Convert in collection order; stop at the first failure.
3420 for (Operation *op : toConvert) {
3421 if (failed(convert(op, isRecursiveLegalization))) {
3422 // Failed to convert an operation.
3423 onFailure();
3424 return failure();
3425 }
3426 }
3427 return success();
3428}
3429
3430LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
3431 return impl->opConverter.legalizeOperations(op,
3432 /*isRecursiveLegalization=*/true);
3433}
3434
/// Legalizes every op in region `r`, including nested ops, from within the
/// currently running conversion pattern.
/// If that pattern has a type converter, the entry-block signature is
/// converted first. Failing to convert that signature fails the call.
/// NOTE(review): listing line 3443 is missing here; presumably it declares
/// the `ops` vector that the loop below fills.
3435LogicalResult ConversionPatternRewriter::legalize(Region *r) {
3436 // Fast path: If the region is empty, there is nothing to legalize.
3437 if (r->empty())
3438 return success();
3439
3440 // Gather a list of all operations to legalize. This is done before
3441 // converting the entry block signature because unrealized_conversion_cast
3442 // ops should not be included.
3444 for (Block &b : *r)
3445 for (Operation &op : b)
3446 ops.push_back(&op);
3447
3448 // If the current pattern runs with a type converter, convert the entry block
3449 // signature.
3450 if (const TypeConverter *converter = impl->currentTypeConverter) {
3451 std::optional<TypeConverter::SignatureConversion> conversion =
3452 converter->convertBlockSignature(&r->front());
3453 if (!conversion)
3454 return failure();
3455 applySignatureConversion(&r->front(), *conversion, converter);
3456 }
3457
3458 // Legalize all operations in the region. This includes all nested
3459 // operations.
3460 return impl->opConverter.legalizeOperations(ops,
3461 /*isRecursiveLegalization=*/true);
3462}
3463
// NOTE(review): the function header (listing line 3464) is missing from this
// listing. Presumably this is the body of the converter's
// `applyConversion(ops)`, which the static `applyConversion` driver calls.
3465 // Convert each operation. On failure, roll back all rewrites if rollback is
 // allowed; otherwise apply the modifications performed so far.
3466 ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
3467
3468 LogicalResult status = legalizeOperations(ops, /*onFailure=*/[&]() {
3469 // Dialect conversion failed.
3470 if (rewriterImpl.config.allowPatternRollback) {
3471 // Rollback is allowed: restore the original IR.
3472 rewriterImpl.undoRewrites();
3473 } else {
3474 // Rollback is not allowed: apply all modifications that have been
3475 // performed so far.
3476 rewriterImpl.applyRewrites();
3477 }
3478 });
3479 if (failed(status))
3480 return failure();
3481
3482 // After a successful conversion, apply rewrites.
3483 rewriterImpl.applyRewrites();
3484
3485 // Reconcile all UnrealizedConversionCastOps that were inserted by the
3486 // dialect conversion frameworks. (Not the ones that were inserted by
3487 // patterns.)
3488 const llvm::MapVector<UnrealizedConversionCastOp,
3489 UnresolvedMaterializationInfo> &materializations =
3490 rewriterImpl.unresolvedMaterializations;
// NOTE(review): listing line 3491 is missing here; presumably it declares
// `remainingCastOps`, which receives the casts that survive reconciliation.
3492 reconcileUnrealizedCasts(materializations, &remainingCastOps);
3493
3494 // Drop markers.
3495 for (UnrealizedConversionCastOp castOp : remainingCastOps)
3496 castOp->removeAttr(kPureTypeConversionMarker);
3497
3498 // Try to legalize all unresolved materializations.
3499 if (rewriter.getConfig().buildMaterializations) {
3500 // Use a new rewriter, so the modifications are not tracked for rollback
3501 // purposes etc.
3502 IRRewriter irRewriter(rewriterImpl.rewriter.getContext(),
3503 rewriter.getConfig().listener);
3504 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
3505 auto it = materializations.find(castOp);
3506 assert(it != materializations.end() && "inconsistent state");
3507 if (failed(legalizeUnresolvedMaterialization(irRewriter, castOp,
3508 it->second)))
3509 return failure();
3510 }
3511 }
3512
3513 return success();
3514}
3515
3516//===----------------------------------------------------------------------===//
3517// Type Conversion
3518//===----------------------------------------------------------------------===//
3519
3520void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
3521 ArrayRef<Type> types) {
3522 assert(!types.empty() && "expected valid types");
3523 remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
3524 addInputs(types);
3525}
3526
3527void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
3528 assert(!types.empty() &&
3529 "1->0 type remappings don't need to be added explicitly");
3530 argTypes.append(types.begin(), types.end());
3531}
3532
3533void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
3534 unsigned newInputNo,
3535 unsigned newInputCount) {
3536 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3537 assert(newInputCount != 0 && "expected valid input count");
3538 remappedInputs[origInputNo] =
3539 InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
3540}
3541
3542void TypeConverter::SignatureConversion::remapInput(
3543 unsigned origInputNo, ArrayRef<Value> replacements) {
3544 assert(!remappedInputs[origInputNo] && "input has already been remapped");
3545 remappedInputs[origInputNo] = InputMapping{
3546 origInputNo, /*size=*/0,
3547 SmallVector<Value, 1>(replacements.begin(), replacements.end())};
3548}
3549
3550/// Internal implementation of the type conversion.
3551/// This is used with either a Type or a Value as the first argument.
3552/// - we can cache the context-free conversions until the last registered
3553/// context-aware conversion.
3554/// - we can't cache the result of type conversion happening after context-aware
3555/// conversions, because the type converter may return different results for the
3556/// same input type.
3557LogicalResult
3558TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
3559 SmallVectorImpl<Type> &results) const {
3560 assert(typeOrValue && "expected non-null type");
3561 Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
3562 : cast<Type>(typeOrValue);
3563 {
3564 std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
3565 std::defer_lock);
3567 cacheReadLock.lock();
3568 auto existingIt = cachedDirectConversions.find(t);
3569 if (existingIt != cachedDirectConversions.end()) {
3570 if (existingIt->second)
3571 results.push_back(existingIt->second);
3572 return success(existingIt->second != nullptr);
3573 }
3574 auto multiIt = cachedMultiConversions.find(t);
3575 if (multiIt != cachedMultiConversions.end()) {
3576 results.append(multiIt->second.begin(), multiIt->second.end());
3577 return success();
3578 }
3579 }
3580 // Walk the added converters in reverse order to apply the most recently
3581 // registered first.
3582 size_t currentCount = results.size();
3583
3584 // We can cache the context-free conversions until the last registered
3585 // context-aware conversion. But only if we're processing a Value right now.
3586 auto isCacheable = [&](int index) {
3587 int numberOfConversionsUntilContextAware =
3588 conversions.size() - 1 - contextAwareTypeConversionsIndex;
3589 return index < numberOfConversionsUntilContextAware;
3590 };
3591
3592 std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
3593 std::defer_lock);
3594
3595 for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
3596 const ConversionCallbackFn &converter = indexedConverter.value();
3597 std::optional<LogicalResult> result = converter(typeOrValue, results);
3598 if (!result) {
3599 assert(results.size() == currentCount &&
3600 "failed type conversion should not change results");
3601 continue;
3602 }
3603 if (!isCacheable(indexedConverter.index()))
3604 return success();
3606 cacheWriteLock.lock();
3607 if (!succeeded(*result)) {
3608 assert(results.size() == currentCount &&
3609 "failed type conversion should not change results");
3610 cachedDirectConversions.try_emplace(t, nullptr);
3611 return failure();
3612 }
3613 auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
3614 if (newTypes.size() == 1)
3615 cachedDirectConversions.try_emplace(t, newTypes.front());
3616 else
3617 cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
3618 return success();
3619 }
3620 return failure();
3621}
3622
3623LogicalResult TypeConverter::convertType(Type t,
3624 SmallVectorImpl<Type> &results) const {
3625 return convertTypeImpl(t, results);
3626}
3627
3628LogicalResult TypeConverter::convertType(Value v,
3629 SmallVectorImpl<Type> &results) const {
3630 return convertTypeImpl(v, results);
3631}
3632
3633Type TypeConverter::convertType(Type t) const {
3634 // Use the multi-type result version to convert the type.
3635 SmallVector<Type, 1> results;
3636 if (failed(convertType(t, results)))
3637 return nullptr;
3638
3639 // Check to ensure that only one type was produced.
3640 return results.size() == 1 ? results.front() : nullptr;
3641}
3642
3643Type TypeConverter::convertType(Value v) const {
3644 // Use the multi-type result version to convert the type.
3645 SmallVector<Type, 1> results;
3646 if (failed(convertType(v, results)))
3647 return nullptr;
3648
3649 // Check to ensure that only one type was produced.
3650 return results.size() == 1 ? results.front() : nullptr;
3651}
3652
3653LogicalResult
3654TypeConverter::convertTypes(TypeRange types,
3655 SmallVectorImpl<Type> &results) const {
3656 for (Type type : types)
3657 if (failed(convertType(type, results)))
3658 return failure();
3659 return success();
3660}
3661
3662LogicalResult
3663TypeConverter::convertTypes(ValueRange values,
3664 SmallVectorImpl<Type> &results) const {
3665 for (Value value : values)
3666 if (failed(convertType(value, results)))
3667 return failure();
3668 return success();
3669}
3670
3671bool TypeConverter::isLegal(Type type) const {
3672 return convertType(type) == type;
3673}
3674
3675bool TypeConverter::isLegal(Value value) const {
3676 return convertType(value) == value.getType();
3677}
3678
3679bool TypeConverter::isLegal(Operation *op) const {
3680 return isLegal(op->getOperands()) && isLegal(op->getResults());
3681}
3682
3683bool TypeConverter::isLegal(Region *region) const {
3684 return llvm::all_of(
3685 *region, [this](Block &block) { return isLegal(block.getArguments()); });
3686}
3687
3688bool TypeConverter::isSignatureLegal(FunctionType ty) const {
3689 if (!isLegal(ty.getInputs()))
3690 return false;
3691 if (!isLegal(ty.getResults()))
3692 return false;
3693 return true;
3694}
3695
3696LogicalResult
3697TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
3698 SignatureConversion &result) const {
3699 // Try to convert the given input type.
3700 SmallVector<Type, 1> convertedTypes;
3701 if (failed(convertType(type, convertedTypes)))
3702 return failure();
3703
3704 // If this argument is being dropped, there is nothing left to do.
3705 if (convertedTypes.empty())
3706 return success();
3707
3708 // Otherwise, add the new inputs.
3709 result.addInputs(inputNo, convertedTypes);
3710 return success();
3711}
3712LogicalResult
3713TypeConverter::convertSignatureArgs(TypeRange types,
3714 SignatureConversion &result,
3715 unsigned origInputOffset) const {
3716 for (unsigned i = 0, e = types.size(); i != e; ++i)
3717 if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
3718 return failure();
3719 return success();
3720}
3721LogicalResult
3722TypeConverter::convertSignatureArg(unsigned inputNo, Value value,
3723 SignatureConversion &result) const {
3724 // Try to convert the given input type.
3725 SmallVector<Type, 1> convertedTypes;
3726 if (failed(convertType(value, convertedTypes)))
3727 return failure();
3728
3729 // If this argument is being dropped, there is nothing left to do.
3730 if (convertedTypes.empty())
3731 return success();
3732
3733 // Otherwise, add the new inputs.
3734 result.addInputs(inputNo, convertedTypes);
3735 return success();
3736}
3737LogicalResult
3738TypeConverter::convertSignatureArgs(ValueRange values,
3739 SignatureConversion &result,
3740 unsigned origInputOffset) const {
3741 for (unsigned i = 0, e = values.size(); i != e; ++i)
3742 if (failed(convertSignatureArg(origInputOffset + i, values[i], result)))
3743 return failure();
3744 return success();
3745}
3746
3747Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
3748 Location loc, Type resultType,
3749 ValueRange inputs) const {
3750 for (const SourceMaterializationCallbackFn &fn :
3751 llvm::reverse(sourceMaterializations))
3752 if (Value result = fn(builder, resultType, inputs, loc))
3753 return result;
3754 return nullptr;
3755}
3756
3757Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
3758 Location loc, Type resultType,
3759 ValueRange inputs,
3760 Type originalType) const {
3761 SmallVector<Value> result = materializeTargetConversion(
3762 builder, loc, TypeRange(resultType), inputs, originalType);
3763 if (result.empty())
3764 return nullptr;
3765 assert(result.size() == 1 && "expected single result");
3766 return result.front();
3767}
3768
3769SmallVector<Value> TypeConverter::materializeTargetConversion(
3770 OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
3771 Type originalType) const {
3772 for (const TargetMaterializationCallbackFn &fn :
3773 llvm::reverse(targetMaterializations)) {
3774 SmallVector<Value> result =
3775 fn(builder, resultTypes, inputs, loc, originalType);
3776 if (result.empty())
3777 continue;
3778 assert(TypeRange(ValueRange(result)) == resultTypes &&
3779 "callback produced incorrect number of values or values with "
3780 "incorrect types");
3781 return result;
3782 }
3783 return {};
3784}
3785
3786std::optional<TypeConverter::SignatureConversion>
3787TypeConverter::convertBlockSignature(Block *block) const {
3788 SignatureConversion conversion(block->getNumArguments());
3789 if (failed(convertSignatureArgs(block->getArguments(), conversion)))
3790 return std::nullopt;
3791 return conversion;
3792}
3793
3794//===----------------------------------------------------------------------===//
3795// Type attribute conversion
3796//===----------------------------------------------------------------------===//
3797TypeConverter::AttributeConversionResult
3798TypeConverter::AttributeConversionResult::result(Attribute attr) {
3799 return AttributeConversionResult(attr, resultTag);
3800}
3801
3802TypeConverter::AttributeConversionResult
3803TypeConverter::AttributeConversionResult::na() {
3804 return AttributeConversionResult(nullptr, naTag);
3805}
3806
3807TypeConverter::AttributeConversionResult
3808TypeConverter::AttributeConversionResult::abort() {
3809 return AttributeConversionResult(nullptr, abortTag);
3810}
3811
3812bool TypeConverter::AttributeConversionResult::hasResult() const {
3813 return impl.getInt() == resultTag;
3814}
3815
3816bool TypeConverter::AttributeConversionResult::isNa() const {
3817 return impl.getInt() == naTag;
3818}
3819
3820bool TypeConverter::AttributeConversionResult::isAbort() const {
3821 return impl.getInt() == abortTag;
3822}
3823
3824Attribute TypeConverter::AttributeConversionResult::getResult() const {
3825 assert(hasResult() && "Cannot get result from N/A or abort");
3826 return impl.getPointer();
3827}
3828
3829std::optional<Attribute>
3830TypeConverter::convertTypeAttribute(Type type, Attribute attr) const {
3831 for (const TypeAttributeConversionCallbackFn &fn :
3832 llvm::reverse(typeAttributeConversions)) {
3833 AttributeConversionResult res = fn(type, attr);
3834 if (res.hasResult())
3835 return res.getResult();
3836 if (res.isAbort())
3837 return std::nullopt;
3838 }
3839 return std::nullopt;
3840}
3841
3842//===----------------------------------------------------------------------===//
3843// FunctionOpInterfaceSignatureConversion
3844//===----------------------------------------------------------------------===//
3845
3846static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
3847 const TypeConverter &typeConverter,
3848 ConversionPatternRewriter &rewriter) {
3849 FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
3850 if (!type)
3851 return failure();
3852
3853 // Convert the function signature (inputs and results).
3854 TypeConverter::SignatureConversion funcConversion(type.getNumInputs());
3855 SmallVector<Type, 1> newResults;
3856 if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
3857 funcConversion)) ||
3858 failed(typeConverter.convertTypes(type.getResults(), newResults)))
3859 return failure();
3860
3861 // If the function has a body, apply a separate signature conversion to the
3862 // entry block. Some function ops (e.g., gpu.func) have extra block arguments
3863 // beyond the function type inputs (e.g., workgroup memory arguments that are
3864 // not part of the public signature). Use a distinct conversion sized for all
3865 // entry block arguments so that applySignatureConversion does not access
3866 // out-of-bounds mappings.
3867 if (!funcOp.getFunctionBody().empty()) {
3868 Block *entryBlock = &funcOp.getFunctionBody().front();
3869 unsigned numEntryBlockArgs = entryBlock->getNumArguments();
3870 unsigned numFuncTypeInputs = type.getNumInputs();
3871 TypeConverter::SignatureConversion blockConversion(numEntryBlockArgs);
3872 // Convert the function-type inputs the same way as for the function type.
3873 if (failed(typeConverter.convertSignatureArgs(type.getInputs(),
3874 blockConversion)))
3875 return failure();
3876 // Add identity mappings for extra block args beyond the function type
3877 // inputs. These arguments are preserved as-is.
3878 for (unsigned i = numFuncTypeInputs; i < numEntryBlockArgs; ++i)
3879 blockConversion.addInputs(i, entryBlock->getArgument(i).getType());
3880 rewriter.applySignatureConversion(entryBlock, blockConversion,
3881 &typeConverter);
3882 }
3883
3884 auto newType = FunctionType::get(
3885 rewriter.getContext(), funcConversion.getConvertedTypes(), newResults);
3886
3887 rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
3888
3889 return success();
3890}
3891
3892/// Create a default conversion pattern that rewrites the type signature of a
3893/// FunctionOpInterface op. This only supports ops which use FunctionType to
3894/// represent their type.
3895namespace {
3896struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
3897 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
3898 MLIRContext *ctx,
3899 const TypeConverter &converter,
3900 PatternBenefit benefit)
3901 : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
3902
3903 LogicalResult
3904 matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
3905 ConversionPatternRewriter &rewriter) const override {
3906 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
3907 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3908 }
3909};
3910
3911struct AnyFunctionOpInterfaceSignatureConversion
3912 : public OpInterfaceConversionPattern<FunctionOpInterface> {
3913 using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
3914
3915 LogicalResult
3916 matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
3917 ConversionPatternRewriter &rewriter) const override {
3918 return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
3919 }
3920};
3921} // namespace
3922
3923FailureOr<Operation *>
3924mlir::convertOpResultTypes(Operation *op, ValueRange operands,
3925 const TypeConverter &converter,
3926 ConversionPatternRewriter &rewriter) {
3927 assert(op && "Invalid op");
3928 Location loc = op->getLoc();
3929 if (converter.isLegal(op))
3930 return rewriter.notifyMatchFailure(loc, "op already legal");
3931
3932 OperationState newOp(loc, op->getName());
3933 newOp.addOperands(operands);
3934
3935 SmallVector<Type> newResultTypes;
3936 if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
3937 return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3938
3939 newOp.addTypes(newResultTypes);
3940 newOp.addAttributes(op->getAttrs());
3941 return rewriter.create(newOp);
3942}
3943
3944void mlir::populateFunctionOpInterfaceTypeConversionPattern(
3945 StringRef functionLikeOpName, RewritePatternSet &patterns,
3946 const TypeConverter &converter, PatternBenefit benefit) {
3947 patterns.add<FunctionOpInterfaceSignatureConversion>(
3948 functionLikeOpName, patterns.getContext(), converter, benefit);
3949}
3950
3951void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
3952 RewritePatternSet &patterns, const TypeConverter &converter,
3953 PatternBenefit benefit) {
3954 patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
3955 converter, patterns.getContext(), benefit);
3956}
3957
3958//===----------------------------------------------------------------------===//
3959// ConversionTarget
3960//===----------------------------------------------------------------------===//
3961
3962void ConversionTarget::setOpAction(OperationName op,
3963 LegalizationAction action) {
3964 legalOperations[op].action = action;
3965}
3966
3967void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
3968 LegalizationAction action) {
3969 for (StringRef dialect : dialectNames)
3970 legalDialects[dialect] = action;
3971}
3972
3973auto ConversionTarget::getOpAction(OperationName op) const
3974 -> std::optional<LegalizationAction> {
3975 std::optional<LegalizationInfo> info = getOpInfo(op);
3976 return info ? info->action : std::optional<LegalizationAction>();
3977}
3978
3979auto ConversionTarget::isLegal(Operation *op) const
3980 -> std::optional<LegalOpDetails> {
3981 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
3982 if (!info)
3983 return std::nullopt;
3984
3985 // Returns true if this operation instance is known to be legal.
3986 auto isOpLegal = [&] {
3987 // Handle dynamic legality either with the provided legality function.
3988 if (info->action == LegalizationAction::Dynamic) {
3989 std::optional<bool> result = info->legalityFn(op);
3990 if (result)
3991 return *result;
3992 }
3993
3994 // Otherwise, the operation is only legal if it was marked 'Legal'.
3995 return info->action == LegalizationAction::Legal;
3996 };
3997 if (!isOpLegal())
3998 return std::nullopt;
3999
4000 // This operation is legal, compute any additional legality information.
4001 LegalOpDetails legalityDetails;
4002 if (info->isRecursivelyLegal) {
4003 auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
4004 if (legalityFnIt != opRecursiveLegalityFns.end()) {
4005 legalityDetails.isRecursivelyLegal =
4006 legalityFnIt->second(op).value_or(true);
4007 } else {
4008 legalityDetails.isRecursivelyLegal = true;
4009 }
4010 }
4011 return legalityDetails;
4012}
4013
4014bool ConversionTarget::isIllegal(Operation *op) const {
4015 std::optional<LegalizationInfo> info = getOpInfo(op->getName());
4016 if (!info)
4017 return false;
4018
4019 if (info->action == LegalizationAction::Dynamic) {
4020 std::optional<bool> result = info->legalityFn(op);
4021 if (!result)
4022 return false;
4023
4024 return !(*result);
4025 }
4026
4027 return info->action == LegalizationAction::Illegal;
4028}
4029
4030static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
4031 ConversionTarget::DynamicLegalityCallbackFn oldCallback,
4032 ConversionTarget::DynamicLegalityCallbackFn newCallback) {
4033 if (!oldCallback)
4034 return newCallback;
4035
4036 auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
4037 Operation *op) -> std::optional<bool> {
4038 if (std::optional<bool> result = newCl(op))
4039 return *result;
4040
4041 return oldCl(op);
4042 };
4043 return chain;
4044}
4045
4046void ConversionTarget::setLegalityCallback(
4047 OperationName name, const DynamicLegalityCallbackFn &callback) {
4048 assert(callback && "expected valid legality callback");
4049 auto *infoIt = legalOperations.find(name);
4050 assert(infoIt != legalOperations.end() &&
4051 infoIt->second.action == LegalizationAction::Dynamic &&
4052 "expected operation to already be marked as dynamically legal");
4053 infoIt->second.legalityFn =
4054 composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
4055}
4056
4057void ConversionTarget::markOpRecursivelyLegal(
4058 OperationName name, const DynamicLegalityCallbackFn &callback) {
4059 auto *infoIt = legalOperations.find(name);
4060 assert(infoIt != legalOperations.end() &&
4061 infoIt->second.action != LegalizationAction::Illegal &&
4062 "expected operation to already be marked as legal");
4063 infoIt->second.isRecursivelyLegal = true;
4064 if (callback)
4065 opRecursiveLegalityFns[name] = composeLegalityCallbacks(
4066 std::move(opRecursiveLegalityFns[name]), callback);
4067 else
4068 opRecursiveLegalityFns.erase(name);
4069}
4070
4071void ConversionTarget::setLegalityCallback(
4072 ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
4073 assert(callback && "expected valid legality callback");
4074 for (StringRef dialect : dialects)
4075 dialectLegalityFns[dialect] = composeLegalityCallbacks(
4076 std::move(dialectLegalityFns[dialect]), callback);
4077}
4078
4079void ConversionTarget::setLegalityCallback(
4080 const DynamicLegalityCallbackFn &callback) {
4081 assert(callback && "expected valid legality callback");
4082 unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
4083}
4084
4085auto ConversionTarget::getOpInfo(OperationName op) const
4086 -> std::optional<LegalizationInfo> {
4087 // Check for info for this specific operation.
4088 const auto *it = legalOperations.find(op);
4089 if (it != legalOperations.end())
4090 return it->second;
4091 // Check for info for the parent dialect.
4092 auto dialectIt = legalDialects.find(op.getDialectNamespace());
4093 if (dialectIt != legalDialects.end()) {
4094 DynamicLegalityCallbackFn callback;
4095 auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
4096 if (dialectFn != dialectLegalityFns.end())
4097 callback = dialectFn->second;
4098 return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
4099 callback};
4100 }
4101 // Otherwise, check if we mark unknown operations as dynamic.
4102 if (unknownLegalityFn)
4103 return LegalizationInfo{LegalizationAction::Dynamic,
4104 /*isRecursivelyLegal=*/false, unknownLegalityFn};
4105 return std::nullopt;
4106}
4107
4108#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
4109//===----------------------------------------------------------------------===//
4110// PDL Configuration
4111//===----------------------------------------------------------------------===//
4112
4113void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
4114 auto &rewriterImpl =
4115 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4116 rewriterImpl.currentTypeConverter = getTypeConverter();
4117}
4118
4119void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
4120 auto &rewriterImpl =
4121 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4122 rewriterImpl.currentTypeConverter = nullptr;
4123}
4124
4125/// Remap the given value using the rewriter and the type converter in the
4126/// provided config.
4127static FailureOr<SmallVector<Value>>
4128pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) {
4129 SmallVector<Value> mappedValues;
4130 if (failed(rewriter.getRemappedValues(values, mappedValues)))
4131 return failure();
4132 return std::move(mappedValues);
4133}
4134
4135void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
4136 patterns.getPDLPatterns().registerRewriteFunction(
4137 "convertValue",
4138 [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
4139 auto results = pdllConvertValues(
4140 static_cast<ConversionPatternRewriter &>(rewriter), value);
4141 if (failed(results))
4142 return failure();
4143 return results->front();
4144 });
4145 patterns.getPDLPatterns().registerRewriteFunction(
4146 "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
4147 return pdllConvertValues(
4148 static_cast<ConversionPatternRewriter &>(rewriter), values);
4149 });
4150 patterns.getPDLPatterns().registerRewriteFunction(
4151 "convertType",
4152 [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
4153 auto &rewriterImpl =
4154 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4155 if (const TypeConverter *converter =
4156 rewriterImpl.currentTypeConverter) {
4157 if (Type newType = converter->convertType(type))
4158 return newType;
4159 return failure();
4160 }
4161 return type;
4162 });
4163 patterns.getPDLPatterns().registerRewriteFunction(
4164 "convertTypes",
4165 [](PatternRewriter &rewriter,
4166 TypeRange types) -> FailureOr<SmallVector<Type>> {
4167 auto &rewriterImpl =
4168 static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
4169 const TypeConverter *converter = rewriterImpl.currentTypeConverter;
4170 if (!converter)
4171 return SmallVector<Type>(types);
4172
4173 SmallVector<Type> remappedTypes;
4174 if (failed(converter->convertTypes(types, remappedTypes)))
4175 return failure();
4176 return std::move(remappedTypes);
4177 });
4178}
4179#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
4180
4181//===----------------------------------------------------------------------===//
4182// Op Conversion Entry Points
4183//===----------------------------------------------------------------------===//
4184
4185/// This is the type of Action that is dispatched when a conversion is applied.
/// NOTE(review): lines are missing from this listing: the class-head line
/// (listing line 4186) and the two declarations after `public:` (listing
/// lines 4189-4190). The static `applyConversion` driver constructs this
/// action from a list of IR units.
4187 : public tracing::ActionImpl<ApplyConversionAction> {
4188public:
/// Tag identifying this action; it is also what print() emits.
4191 static constexpr StringLiteral tag = "apply-conversion";
/// Human-readable description of this action.
4192 static constexpr StringLiteral desc =
4193 "Encapsulate the application of a dialect conversion";
4194
/// Prints the action's tag.
4195 void print(raw_ostream &os) const override { os << tag; }
4196};
4197
4199 const ConversionTarget &target,
4200 const FrozenRewritePatternSet &patterns,
4201 ConversionConfig config,
4202 OpConversionMode mode) {
// Shared driver for the partial/full conversion entry points. It runs an
// OperationConverter over `ops` in the given `mode`, inside an
// ApplyConversionAction whose IR units are `ops`, and returns the result.
// NOTE(review): two lines are missing from this listing: the signature head,
// which names the function and its `ops` parameter (listing line 4198), and
// the call that dispatches the action with the lambda below and `irUnits`
// (listing line 4208). `ctx` is presumably used on that missing line.
// An empty op list is trivially converted.
4203 if (ops.empty())
4204 return success();
4205 MLIRContext *ctx = ops.front()->getContext();
4206 LogicalResult status = success();
4207 SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
4209 [&] {
4210 OperationConverter opConverter(ops.front()->getContext(), target,
4211 patterns, config, mode);
4212 status = opConverter.applyConversion(ops);
4213 },
4214 irUnits);
4215 return status;
4216}
4217
4218//===----------------------------------------------------------------------===//
4219// Partial Conversion
4220//===----------------------------------------------------------------------===//
4221
4222LogicalResult mlir::applyPartialConversion(
4223 ArrayRef<Operation *> ops, const ConversionTarget &target,
4224 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4225 return applyConversion(ops, target, patterns, config,
4226 OpConversionMode::Partial);
4227}
4228LogicalResult
4229mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
4230 const FrozenRewritePatternSet &patterns,
4231 ConversionConfig config) {
4232 return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
4233}
4234
4235//===----------------------------------------------------------------------===//
4236// Full Conversion
4237//===----------------------------------------------------------------------===//
4238
4239LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
4240 const ConversionTarget &target,
4241 const FrozenRewritePatternSet &patterns,
4242 ConversionConfig config) {
4243 return applyConversion(ops, target, patterns, config, OpConversionMode::Full);
4244}
4245LogicalResult mlir::applyFullConversion(Operation *op,
4246 const ConversionTarget &target,
4247 const FrozenRewritePatternSet &patterns,
4248 ConversionConfig config) {
4249 return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
4250}
4251
4252//===----------------------------------------------------------------------===//
4253// Analysis Conversion
4254//===----------------------------------------------------------------------===//
4255
4256/// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
4257/// op is a top-level module op (which is expected to be isolated from above),
4258/// return that op.
4260 // Check if there is a top-level operation within `ops`. If so, return that
4261 // op.
4262 for (Operation *op : ops) {
4263 if (!op->getParentOp()) {
4264#ifndef NDEBUG
4265 assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
4266 "expected top-level op to be isolated from above");
4267 for (Operation *other : ops)
4268 assert(op->isAncestor(other) &&
4269 "expected ops to have a common ancestor");
4270#endif // NDEBUG
4271 return op;
4272 }
4273 }
4274
4275 // No top-level op. Find a common ancestor.
4276 Operation *commonAncestor =
4277 ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
4278 for (Operation *op : ops.drop_front()) {
4279 while (!commonAncestor->isProperAncestor(op)) {
4280 commonAncestor =
4282 assert(commonAncestor &&
4283 "expected to find a common isolated from above ancestor");
4284 }
4285 }
4286
4287 return commonAncestor;
4288}
4289
4290LogicalResult mlir::applyAnalysisConversion(
4291 ArrayRef<Operation *> ops, ConversionTarget &target,
4292 const FrozenRewritePatternSet &patterns, ConversionConfig config) {
4293#ifndef NDEBUG
4294 if (config.legalizableOps)
4295 assert(config.legalizableOps->empty() && "expected empty set");
4296#endif // NDEBUG
4297
4298 // Clone closted common ancestor that is isolated from above.
4299 Operation *commonAncestor = findCommonAncestor(ops);
4300 IRMapping mapping;
4301 Operation *clonedAncestor = commonAncestor->clone(mapping);
4302 // Compute inverse IR mapping.
4303 DenseMap<Operation *, Operation *> inverseOperationMap;
4304 for (auto &it : mapping.getOperationMap())
4305 inverseOperationMap[it.second] = it.first;
4306
4307 // Convert the cloned operations. The original IR will remain unchanged.
4308 SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
4309 ops, [&](Operation *op) { return mapping.lookup(op); });
4310 LogicalResult status = applyConversion(opsToConvert, target, patterns, config,
4311 OpConversionMode::Analysis);
4312
4313 // Remap `legalizableOps`, so that they point to the original ops and not the
4314 // cloned ops.
4315 if (config.legalizableOps) {
4316 DenseSet<Operation *> originalLegalizableOps;
4317 for (Operation *op : *config.legalizableOps)
4318 originalLegalizableOps.insert(inverseOperationMap[op]);
4319 *config.legalizableOps = std::move(originalLegalizableOps);
4320 }
4321
4322 // Erase the cloned IR.
4323 clonedAncestor->erase();
4324 return status;
4325}
4326
4327LogicalResult
4328mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
4329 const FrozenRewritePatternSet &patterns,
4330 ConversionConfig config) {
4331 return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
4332}
return success()
static void setInsertionPointAfter(OpBuilder &b, Value value)
static Operation * getCommonDefiningOp(const ValueVector &values)
Return the operation that defines all values in the vector.
static LogicalResult applyConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config, OpConversionMode mode)
static T moveAndReset(T &obj)
Helper function that moves and returns the given object.
SmallVector< Value, 2 > ValueVector
A vector of SSA values, optimized for the most common case of one or two values.
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(ConversionTarget::DynamicLegalityCallbackFn oldCallback, ConversionTarget::DynamicLegalityCallbackFn newCallback)
static bool isPureTypeConversion(const ValueVector &values)
A vector of values is a pure type conversion if all values are defined by the same operation and the ...
static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, UnrealizedConversionCastOp op, const UnresolvedMaterializationInfo &info)
static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a failure result for the given reason.
static void reconcileUnrealizedCastsImpl(RangeT castOps, function_ref< bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
Try to reconcile all given UnrealizedConversionCastOps and store the left-over ops in remainingCastOp...
static Operation * findCommonAncestor(ArrayRef< Operation * > ops)
Find a common IsolatedFromAbove ancestor of the given ops.
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args)
A utility function to log a successful result for the given reason.
static void performReplaceValue(RewriterBase &rewriter, Value from, Value repl, function_ref< bool(OpOperand &)> functor=nullptr)
Replace all uses of from with repl.
static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector< Operation * > &newOps, const SetVector< Operation * > &modifiedOps)
Report a fatal error indicating that newly produced or modified IR could not be legalized.
static OpBuilder::InsertPoint computeInsertPoint(Value value)
Helper function that computes an insertion point where the given value is defined and can be used wit...
static const StringRef kPureTypeConversionMarker
Marker attribute for pure type conversions.
static SmallVector< Value > getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, const SmallVector< SmallVector< Value > > &toRange, const TypeConverter *converter)
Given that fromRange is about to be replaced with toRange, compute replacement values with the types ...
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
#define DEBUG_TYPE
b getContext())
static std::string diag(const llvm::Value &value)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition SCCP.cpp:67
This is the type of Action that is dispatched when a conversion is applied.
tracing::ActionImpl< ApplyConversionAction > Base
static constexpr StringLiteral desc
ApplyConversionAction(ArrayRef< IRUnit > irUnits)
void print(raw_ostream &os) const override
static constexpr StringLiteral tag
This class represents an argument of a Block.
Definition Value.h:306
Location getLoc() const
Return the location for this argument.
Definition Value.h:321
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition Block.h:318
BlockArgListType getArguments()
Definition Block.h:97
iterator end()
Definition Block.h:154
iterator begin()
Definition Block.h:153
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
UnitAttr getUnitAttr()
Definition Builders.cpp:102
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:267
MLIRContext * context
Definition Builders.h:204
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
Definition Dominance.h:143
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:161
This class represents a frozen set of patterns that can be processed by a pattern applicator.
const DenseMap< Operation *, Operation * > & getOperationMap() const
Return the held operation mapping.
Definition IRMapping.h:88
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
user_range getUsers() const
Returns a range of all users.
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
This class represents a saved insertion point.
Definition Builders.h:329
Block::iterator getPoint() const
Definition Builders.h:342
bool isSet() const
Returns true if this insert point is set.
Definition Builders.h:339
Block * getBlock() const
Definition Builders.h:341
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:322
LogicalResult tryFold(Operation *op, SmallVectorImpl< Value > &results, SmallVectorImpl< Operation * > *materializedConstants=nullptr)
Attempts to fold the given operation and places new results within results.
Definition Builders.cpp:478
Operation * insert(Operation *op)
Insert the given operation at the current insertion point and return it.
Definition Builders.cpp:426
This class represents an operand of an operation.
Definition Value.h:254
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition Value.h:454
This class provides the API for ops that are known to be isolated from above.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1142
type_range getTypes() const
void destroyOpProperties(PropertyRef properties) const
This hooks destroy the op properties.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
TypeID getOpPropertiesTypeID() const
Return the TypeID of the op properties.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
PropertyRef getPropertiesStorage()
Return a generic (but typed) reference to the property type storage.
Definition Operation.h:926
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition Operation.h:243
void copyProperties(PropertyRef rhs)
Copy properties from an existing other properties object.
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:877
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:774
void dropAllUses()
Drop all uses of results of this operation.
Definition Operation.h:859
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:585
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:537
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:230
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:273
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:699
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:251
OperandRange operand_range
Definition Operation.h:396
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:115
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:702
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
void setSuccessor(Block *block, unsigned index)
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:288
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:822
result_range getResults()
Definition Operation.h:440
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:233
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
static PatternBenefit impossibleToMatch()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
std::optional< OperationName > getRootKind() const
Return the root node that this pattern matches.
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
StringRef getDebugName() const
Return a readable name for this pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
BlockListType::iterator iterator
Definition Region.h:52
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void cancelOpModification(Operation *op)
This method cancels a pending in-place modification.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
CRTP Implementation of an action.
Definition Action.h:76
ArrayRef< IRUnit > irUnits
Set of IR units (operations, regions, blocks, values) that are associated with this action.
Definition Action.h:66
AttrTypeReplacer.
Kind
An enumeration of the kinds of predicates.
Definition Predicate.h:44
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
static void reconcileUnrealizedCasts(const llvm::MapVector< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
This iterator enumerates elements according to their dominance relationship.
Definition Iterators.h:52
virtual void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt)
Notify the listener that the specified block was inserted.
Definition Builders.h:310
virtual void notifyOperationInserted(Operation *op, InsertPoint previous)
Notify the listener that the specified operation was inserted.
Definition Builders.h:300
OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode)
const ConversionTarget & getTarget()
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, bool isRecursiveLegalization=false)
LogicalResult convert(Operation *op, bool isRecursiveLegalization=false)
Converts a single operation.
LogicalResult legalizeOperations(ArrayRef< Operation * > ops, Fn onFailure, bool isRecursiveLegalization=false)
Legalizes the given operations (and their nested operations) to the conversion target.
LogicalResult applyConversion(ArrayRef< Operation * > ops)
Applies the conversion to the given operations (and their nested operations).
virtual void notifyOperationModified(Operation *op)
Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationErased(Operation *op)
Notify the listener that the specified operation is about to be erased.
virtual void notifyOperationReplaced(Operation *op, Operation *replacement)
Notify the listener that all uses of the specified operation's results are about to be replaced with ...
virtual void notifyBlockErased(Block *block)
Notify the listener that the specified block is about to be erased.
A rewriter that keeps track of erased ops and blocks.
SingleEraseRewriter(MLIRContext *context, std::function< void(Operation *)> opErasedCallback=nullptr)
void eraseOp(Operation *op) override
Erase the given op (unless it was already erased).
void notifyBlockErased(Block *block) override
Notify the listener that the specified block is about to be erased.
void notifyOperationErased(Operation *op) override
Notify the listener that the specified operation is about to be erased.
void eraseBlock(Block *block) override
Erase the given block (unless it was already erased).
llvm::impl::raw_ldbg_ostream os
A raw output stream used to prefix the debug log.
void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override
Notify the listener that the specified operation was inserted.
Value findOrBuildReplacementValue(Value value, const TypeConverter *converter)
Find a replacement value for the given SSA value in the conversion value mapping.
SetVector< Operation * > patternNewOps
A set of operations that were created by the current pattern.
void replaceValueUses(Value from, ValueRange to, const TypeConverter *converter, function_ref< bool(OpOperand &)> functor=nullptr)
Replace the uses of the given value with the given values.
DenseSet< Block * > erasedBlocks
A set of erased blocks.
llvm::MapVector< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > unresolvedMaterializations
A mapping for looking up metadata of unresolved materializations.
DenseMap< Region *, const TypeConverter * > regionToConverter
A mapping of regions to type converters that should be used when converting the arguments of blocks w...
bool wasOpReplaced(Operation *op) const
Return "true" if the given operation was replaced or erased.
ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config, OperationConverter &opConverter)
void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override
Notifies that a block was inserted.
void undoRewrites(unsigned numRewritesToKeep=0, StringRef patternName="")
Undo the rewrites (motions, splits) one by one in reverse order until "numRewritesToKeep" rewrites re...
LogicalResult remapValues(StringRef valueDiagTag, std::optional< Location > inputLoc, ValueRange values, SmallVector< ValueVector > &remapped)
Remap the given values to those with potentially different types.
ValueRange buildUnresolvedMaterialization(MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, bool isPureTypeConversion=true)
Build an unresolved materialization operation given a range of output types and a list of input opera...
DenseSet< UnrealizedConversionCastOp > patternMaterializations
A list of unresolved materializations that were created by the current pattern.
void resetState(RewriterState state, StringRef patternName="")
Reset the state of the rewriter to a previously saved point.
void applyRewrites()
Apply all requested operation rewrites.
Block * applySignatureConversion(Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion)
Apply the given signature conversion on the given block.
void inlineBlockBefore(Block *source, Block *dest, Block::iterator before)
Inline the source block into the destination block before the given iterator.
void replaceOp(Operation *op, SmallVector< SmallVector< Value > > &&newValues)
Replace the results of the given operation with the given values and erase the operation.
RewriterState getCurrentState()
Return the current state of the rewriter.
llvm::ScopedPrinter logger
A logger used to emit diagnostics during the conversion process.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes={}) const
Lookup the given value within the map, or return an empty vector if the value is not mapped.
void appendRewrite(Args &&...args)
Append a rewrite.
SmallPtrSet< Operation *, 1 > pendingRootUpdates
A set of operations that have pending updates.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notifies that a pattern match failed for the given reason.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes={}, bool skipPureTypeConversions=false) const
Lookup the most recently mapped values with the desired types in the mapping, taking into account onl...
bool isOpIgnored(Operation *op) const
Return "true" if the given operation is ignored, and does not need to be converted.
IRRewriter notifyingRewriter
A rewriter that notifies the listener (if any) about all IR modifications.
OperationConverter & opConverter
The operation converter to use for recursive legalization.
DenseSet< Value > replacedValues
A set of replaced values.
DenseSet< Operation * > erasedOps
A set of erased operations.
void eraseBlock(Block *block)
Erase the given block and its contents.
SetVector< Operation * > ignoredOps
A set of operations that should no longer be considered for legalization.
SmallVector< std::unique_ptr< IRRewrite > > rewrites
Ordered list of block operations (creations, splits, motions).
SetVector< Operation * > patternModifiedOps
A set of operations that were modified by the current pattern.
const ConversionConfig & config
Dialect conversion configuration.
SetVector< Operation * > replacedOps
A set of operations that were replaced/erased.
ConversionPatternRewriter & rewriter
The rewriter that is used to perform the conversion.
const TypeConverter * currentTypeConverter
The current type converter, or nullptr if no type converter is currently active.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion)
Convert the types of block arguments within the given region.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.